From 6678cd3f6986442901f2a36df1873997b282cd8f Mon Sep 17 00:00:00 2001 From: Amith Koujalgi Date: Thu, 9 Nov 2023 12:56:45 +0530 Subject: [PATCH] - replaced GSON with Jackson - Updated readme - general cleanup --- README.md | 22 +++++++++++++++ pom.xml | 6 ++--- .../ollama4j/core/OllamaAPI.java | 13 ++++----- .../{Models.java => ListModelsResponse.java} | 2 +- .../ollama4j/core/models/ModelDetail.java | 12 +++++++-- .../models/OllamaAsyncResultCallback.java | 27 +++++++++++++------ .../core/models/OllamaRequestModel.java | 13 +++++++-- 7 files changed, 73 insertions(+), 22 deletions(-) rename src/main/java/io/github/amithkoujalgi/ollama4j/core/models/{Models.java => ListModelsResponse.java} (88%) diff --git a/README.md b/README.md index 0d28330..325d232 100644 --- a/README.md +++ b/README.md @@ -279,6 +279,28 @@ FROM sales GROUP BY customers.name; ``` +#### Async API with streaming response + +```java +public class Main { + public static void main(String[] args) throws Exception { + String host = "http://localhost:11434/"; + OllamaAPI ollamaAPI = new OllamaAPI(host); + + String prompt = "List all cricket world cup teams of 2019."; + OllamaAsyncResultCallback callback = ollamaAPI.askAsync(OllamaModelType.LLAMA2, prompt); + while (!callback.isComplete() || !callback.getStream().isEmpty()) { + // poll for data from the response stream + String response = callback.getStream().poll(); + if (response != null) { + System.out.print(response); + } + Thread.sleep(1000); + } + } +} +``` + #### API Spec Find the full `Javadoc` (API specifications) [here](https://amithkoujalgi.github.io/ollama4j/). diff --git a/pom.xml b/pom.xml index 320f865..b1ec51b 100644 --- a/pom.xml +++ b/pom.xml @@ -102,9 +102,9 @@ - com.google.code.gson - gson - 2.10.1 + com.fasterxml.jackson.core + jackson-databind + 2.15.3 ch.qos.logback diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java index 5b1c4c5..695c52e 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java @@ -1,6 +1,6 @@ package io.github.amithkoujalgi.ollama4j.core; -import com.google.gson.Gson; +import com.fasterxml.jackson.databind.ObjectMapper; import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; import io.github.amithkoujalgi.ollama4j.core.models.*; import org.apache.hc.client5.http.classic.methods.HttpDelete; @@ -30,10 +30,13 @@ import java.util.stream.Collectors; */ @SuppressWarnings({"DuplicatedCode", "ExtractMethodRecommender"}) public class OllamaAPI { + private static final Logger logger = LoggerFactory.getLogger(OllamaAPI.class); private final String host; private boolean verbose = false; + private final ObjectMapper objectMapper = new ObjectMapper(); + /** * Instantiates the Ollama API. * @@ -76,8 +79,7 @@ public class OllamaAPI { responseString = EntityUtils.toString(responseEntity, "UTF-8"); } if (statusCode == 200) { - Models m = new Gson().fromJson(responseString, Models.class); - return m.getModels(); + return objectMapper.readValue(responseString, ListModelsResponse.class).getModels(); } else { throw new OllamaBaseException(statusCode + " - " + responseString); } @@ -109,7 +111,7 @@ public class OllamaAPI { responseString = EntityUtils.toString(responseEntity, "UTF-8"); } if (statusCode == 200) { - return new Gson().fromJson(responseString, ModelDetail.class); + return objectMapper.readValue(responseString, ModelDetail.class); } else { throw new OllamaBaseException(statusCode + " - " + responseString); } @@ -234,7 +236,6 @@ public class OllamaAPI { * @throws IOException */ public String ask(String ollamaModelType, String promptText) throws OllamaBaseException, IOException { - Gson gson = new Gson(); OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(ollamaModelType, promptText); URL obj = new URL(this.host + "/api/generate"); HttpURLConnection con = (HttpURLConnection) obj.openConnection(); @@ -250,7 +251,7 @@ public class OllamaAPI { String inputLine; StringBuilder response = new StringBuilder(); while ((inputLine = in.readLine()) != null) { - OllamaResponseModel ollamaResponseModel = gson.fromJson(inputLine, OllamaResponseModel.class); + OllamaResponseModel ollamaResponseModel = objectMapper.readValue(inputLine, OllamaResponseModel.class); if (!ollamaResponseModel.getDone()) { response.append(ollamaResponseModel.getResponse()); } diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/Models.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/ListModelsResponse.java similarity index 88% rename from src/main/java/io/github/amithkoujalgi/ollama4j/core/models/Models.java rename to src/main/java/io/github/amithkoujalgi/ollama4j/core/models/ListModelsResponse.java index 8520d90..15f8495 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/Models.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/ListModelsResponse.java @@ -2,7 +2,7 @@ package io.github.amithkoujalgi.ollama4j.core.models; import java.util.List; -public class Models { +public class ListModelsResponse { private List models; public List getModels() { diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/ModelDetail.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/ModelDetail.java index 2408012..e778f67 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/ModelDetail.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/ModelDetail.java @@ -1,6 +1,7 @@ package io.github.amithkoujalgi.ollama4j.core.models; -import com.google.gson.GsonBuilder; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; public class ModelDetail { private String license, modelfile, parameters, template; @@ -39,6 +40,13 @@ public class ModelDetail { @Override public String toString() { - return new GsonBuilder().setPrettyPrinting().create().toJson(this); + try { + return new ObjectMapper() + .writer() + .withDefaultPrettyPrinter() + .writeValueAsString(this); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } } } diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java index 536dbe7..3412a7b 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java @@ -1,38 +1,42 @@ package io.github.amithkoujalgi.ollama4j.core.models; -import com.google.gson.Gson; +import com.fasterxml.jackson.databind.ObjectMapper; import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; import java.io.BufferedReader; import java.io.IOException; import java.io.InputStreamReader; import java.net.HttpURLConnection; +import java.util.LinkedList; +import java.util.Queue; +@SuppressWarnings("DuplicatedCode") public class OllamaAsyncResultCallback extends Thread { private final HttpURLConnection connection; private String result; private boolean isDone; + private final ObjectMapper objectMapper = new ObjectMapper(); + private final Queue queue = new LinkedList<>(); public OllamaAsyncResultCallback(HttpURLConnection connection) { this.connection = connection; this.isDone = false; this.result = ""; + this.queue.add(""); } @Override public void run() { - Gson gson = new Gson(); int responseCode = 0; try { responseCode = this.connection.getResponseCode(); if (responseCode == HttpURLConnection.HTTP_OK) { - try (BufferedReader in = - new BufferedReader(new InputStreamReader(this.connection.getInputStream()))) { + try (BufferedReader in = new BufferedReader(new InputStreamReader(this.connection.getInputStream()))) { String inputLine; StringBuilder response = new StringBuilder(); while ((inputLine = in.readLine()) != null) { - OllamaResponseModel ollamaResponseModel = - gson.fromJson(inputLine, OllamaResponseModel.class); + OllamaResponseModel ollamaResponseModel = objectMapper.readValue(inputLine, OllamaResponseModel.class); + queue.add(ollamaResponseModel.getResponse()); if (!ollamaResponseModel.getDone()) { response.append(ollamaResponseModel.getResponse()); } @@ -42,8 +46,7 @@ public class OllamaAsyncResultCallback extends Thread { this.result = response.toString(); } } else { - throw new OllamaBaseException( - connection.getResponseCode() + " - " + connection.getResponseMessage()); + throw new OllamaBaseException(connection.getResponseCode() + " - " + connection.getResponseMessage()); } } catch (IOException | OllamaBaseException e) { this.isDone = true; @@ -55,7 +58,15 @@ public class OllamaAsyncResultCallback extends Thread { return isDone; } + /** + * Returns the final response when the execution completes. Does not return intermediate results. + * @return response text + */ public String getResponse() { return result; } + + public Queue getStream() { + return queue; + } } diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaRequestModel.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaRequestModel.java index 0f1454f..1c4bcb4 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaRequestModel.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaRequestModel.java @@ -1,6 +1,8 @@ package io.github.amithkoujalgi.ollama4j.core.models; -import com.google.gson.Gson; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; public class OllamaRequestModel { private String model; @@ -29,6 +31,13 @@ public class OllamaRequestModel { @Override public String toString() { - return new Gson().toJson(this); + try { + return new ObjectMapper() + .writer() + .withDefaultPrettyPrinter() + .writeValueAsString(this); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } } }