From 00a3e51a93dd8f3482926b0679f393e739de0913 Mon Sep 17 00:00:00 2001 From: Markus Klenke Date: Fri, 9 Feb 2024 16:32:52 +0000 Subject: [PATCH] Extends OllamaAPI by Chat methods and refactors synchronous Generate API Methods --- .../ollama4j/core/OllamaAPI.java | 79 ++++++------------- 1 file changed, 25 insertions(+), 54 deletions(-) diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java index 17fddc0..831e09e 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java @@ -2,10 +2,15 @@ package io.github.amithkoujalgi.ollama4j.core; import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; import io.github.amithkoujalgi.ollama4j.core.models.*; +import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessage; +import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel; +import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder; import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFileContentsRequest; import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFilePathRequest; import io.github.amithkoujalgi.ollama4j.core.models.request.ModelEmbeddingsRequest; import io.github.amithkoujalgi.ollama4j.core.models.request.ModelRequest; +import io.github.amithkoujalgi.ollama4j.core.models.request.OllamaChatRequestCaller; +import io.github.amithkoujalgi.ollama4j.core.models.request.OllamaGenerateRequestCaller; import io.github.amithkoujalgi.ollama4j.core.utils.Options; import io.github.amithkoujalgi.ollama4j.core.utils.Utils; import java.io.BufferedReader; @@ -343,7 +348,7 @@ public class OllamaAPI { throws OllamaBaseException, IOException, InterruptedException { OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt); ollamaRequestModel.setOptions(options.getOptionsMap()); - return generateSync(ollamaRequestModel); + return generateSyncForOllamaRequestModel(ollamaRequestModel); } /** @@ -387,7 +392,7 @@ public class OllamaAPI { } OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt, images); ollamaRequestModel.setOptions(options.getOptionsMap()); - return generateSync(ollamaRequestModel); + return generateSyncForOllamaRequestModel(ollamaRequestModel); } /** @@ -411,9 +416,23 @@ public class OllamaAPI { } OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt, images); ollamaRequestModel.setOptions(options.getOptionsMap()); - return generateSync(ollamaRequestModel); + return generateSyncForOllamaRequestModel(ollamaRequestModel); } + + + public OllamaResult chat(String model, List messages) throws OllamaBaseException, IOException, InterruptedException, URISyntaxException{ + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(model); + return chat(builder.withMessages(messages).build()); + } + + public OllamaResult chat(OllamaChatRequestModel request) throws OllamaBaseException, IOException, InterruptedException, URISyntaxException{ + OllamaChatRequestCaller requestCaller = new OllamaChatRequestCaller(host, basicAuth, requestTimeoutSeconds, verbose); + return requestCaller.generateSync(request); + } + + // technical private methods // + private static String encodeFileToBase64(File file) throws IOException { return Base64.getEncoder().encodeToString(Files.readAllBytes(file.toPath())); } @@ -436,58 +455,10 @@ public class OllamaAPI { } } - private OllamaResult generateSync(OllamaRequestModel ollamaRequestModel) + private OllamaResult generateSyncForOllamaRequestModel(OllamaRequestModel ollamaRequestModel) throws OllamaBaseException, IOException, InterruptedException { - long startTime = System.currentTimeMillis(); - HttpClient httpClient = HttpClient.newHttpClient(); - URI uri = URI.create(this.host + "/api/generate"); - HttpRequest.Builder requestBuilder = - getRequestBuilderDefault(uri) - .POST( - HttpRequest.BodyPublishers.ofString( - Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))); - HttpRequest request = requestBuilder.build(); - if (verbose) logger.info("Asking model: " + ollamaRequestModel); - HttpResponse response = - httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream()); - int statusCode = response.statusCode(); - InputStream responseBodyStream = response.body(); - StringBuilder responseBuffer = new StringBuilder(); - try (BufferedReader reader = - new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { - String line; - while ((line = reader.readLine()) != null) { - if (statusCode == 404) { - logger.warn("Status code: 404 (Not Found)"); - OllamaErrorResponseModel ollamaResponseModel = - Utils.getObjectMapper().readValue(line, OllamaErrorResponseModel.class); - responseBuffer.append(ollamaResponseModel.getError()); - } else if (statusCode == 401) { - logger.warn("Status code: 401 (Unauthorized)"); - OllamaErrorResponseModel ollamaResponseModel = - Utils.getObjectMapper() - .readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponseModel.class); - responseBuffer.append(ollamaResponseModel.getError()); - } else { - OllamaResponseModel ollamaResponseModel = - Utils.getObjectMapper().readValue(line, OllamaResponseModel.class); - if (!ollamaResponseModel.isDone()) { - responseBuffer.append(ollamaResponseModel.getResponse()); - } - } - } - } - - if (statusCode != 200) { - logger.error("Status code " + statusCode); - throw new OllamaBaseException(responseBuffer.toString()); - } else { - long endTime = System.currentTimeMillis(); - OllamaResult ollamaResult = - new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode); - if (verbose) logger.info("Model response: " + ollamaResult); - return ollamaResult; - } + OllamaGenerateRequestCaller requestCaller = new OllamaGenerateRequestCaller(host, basicAuth, requestTimeoutSeconds, verbose); + return requestCaller.generateSync(ollamaRequestModel); } /**