From cf4e7a96e8090d3288fbb686aaf259156f6f7234 Mon Sep 17 00:00:00 2001 From: Markus Klenke Date: Fri, 16 Feb 2024 16:31:39 +0000 Subject: [PATCH 1/5] Optimizes ChatStreamObserver to use only the last message instead of parsing all messages again --- .../core/models/chat/OllamaChatStreamObserver.java | 7 ++----- .../ollama4j/integrationtests/TestRealAPIs.java | 2 +- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatStreamObserver.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatStreamObserver.java index 6a782f4..ea4b4d8 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatStreamObserver.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatStreamObserver.java @@ -2,10 +2,8 @@ package io.github.amithkoujalgi.ollama4j.core.models.chat; import java.util.ArrayList; import java.util.List; -import java.util.stream.Collectors; import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler; -import lombok.NonNull; public class OllamaChatStreamObserver { @@ -13,7 +11,7 @@ public class OllamaChatStreamObserver { private List responseParts = new ArrayList<>(); - private String message; + private String message = ""; public OllamaChatStreamObserver(OllamaStreamHandler streamHandler) { this.streamHandler = streamHandler; @@ -25,8 +23,7 @@ public class OllamaChatStreamObserver { } protected void handleCurrentResponsePart(OllamaChatResponseModel currentResponsePart){ - List<@NonNull String> allResponsePartsByNow = responseParts.stream().map(r -> r.getMessage().getContent()).collect(Collectors.toList()); - message = String.join("", allResponsePartsByNow); + message = message + currentResponsePart.getMessage().getContent(); streamHandler.accept(message); } diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java 
b/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java index 870e17f..b4a7e1d 100644 --- a/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java +++ b/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java @@ -183,7 +183,7 @@ class TestRealAPIs { OllamaChatResult chatResult = ollamaAPI.chat(requestModel,(s) -> { LOG.info(s); - String substring = s.substring(sb.toString().length(), s.length()-1); + String substring = s.substring(sb.toString().length(), s.length()); LOG.info(substring); sb.append(substring); }); From abb76ad8671a293fe87e443ac7099489e9506540 Mon Sep 17 00:00:00 2001 From: Markus Klenke Date: Fri, 16 Feb 2024 17:03:15 +0000 Subject: [PATCH 2/5] Adds streaming feature to Generate APIs --- .../ollama4j/core/OllamaAPI.java | 62 ++++++++++++++++--- .../OllamaGenerateStreamObserver.java | 31 ++++++++++ .../request/OllamaGenerateEndpointCaller.java | 18 +++++- .../integrationtests/TestRealAPIs.java | 50 +++++++++++++++ 4 files changed, 150 insertions(+), 11 deletions(-) create mode 100644 src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateStreamObserver.java diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java index 20e9d3e..ec772f1 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java @@ -342,13 +342,24 @@ public class OllamaAPI { * @param options the Options object - More * details on the options + * @param streamHandler optional callback consumer that will be applied every time a streamed response is received. If not set, the stream parameter of the request is set to false. 
* @return OllamaResult that includes response text and time taken for response */ - public OllamaResult generate(String model, String prompt, Options options) + public OllamaResult generate(String model, String prompt, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt); ollamaRequestModel.setOptions(options.getOptionsMap()); - return generateSyncForOllamaRequestModel(ollamaRequestModel); + return generateSyncForOllamaRequestModel(ollamaRequestModel,streamHandler); + } + + /** + * Convenience method to call Ollama API without streaming responses. + * + * Uses {@link #generate(String, String, Options, OllamaStreamHandler)} + */ + public OllamaResult generate(String model, String prompt, Options options) + throws OllamaBaseException, IOException, InterruptedException { + return generate(model, prompt, options,null); } /** @@ -381,10 +392,11 @@ public class OllamaAPI { * @param options the Options object - More * details on the options + * @param streamHandler optional callback consumer that will be applied every time a streamed response is received. If not set, the stream parameter of the request is set to false. 
* @return OllamaResult that includes response text and time taken for response */ public OllamaResult generateWithImageFiles( - String model, String prompt, List imageFiles, Options options) + String model, String prompt, List imageFiles, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { List images = new ArrayList<>(); for (File imageFile : imageFiles) { @@ -392,9 +404,20 @@ public class OllamaAPI { } OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt, images); ollamaRequestModel.setOptions(options.getOptionsMap()); - return generateSyncForOllamaRequestModel(ollamaRequestModel); + return generateSyncForOllamaRequestModel(ollamaRequestModel,streamHandler); } + /** + * Convenience method to call Ollama API without streaming responses. + * + * Uses {@link #generateWithImageFiles(String, String, List, Options, OllamaStreamHandler)} + */ + public OllamaResult generateWithImageFiles( + String model, String prompt, List imageFiles, Options options) + throws OllamaBaseException, IOException, InterruptedException{ + return generateWithImageFiles(model, prompt, imageFiles, options, null); +} + /** * With one or more image URLs, ask a question to a model running on Ollama server. This is a * sync/blocking call. @@ -405,10 +428,11 @@ public class OllamaAPI { * @param options the Options object - More * details on the options + * @param streamHandler optional callback consumer that will be applied every time a streamed response is received. If not set, the stream parameter of the request is set to false. 
* @return OllamaResult that includes response text and time taken for response */ public OllamaResult generateWithImageURLs( - String model, String prompt, List imageURLs, Options options) + String model, String prompt, List imageURLs, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { List images = new ArrayList<>(); for (String imageURL : imageURLs) { @@ -416,7 +440,18 @@ public class OllamaAPI { } OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt, images); ollamaRequestModel.setOptions(options.getOptionsMap()); - return generateSyncForOllamaRequestModel(ollamaRequestModel); + return generateSyncForOllamaRequestModel(ollamaRequestModel,streamHandler); + } + + /** + * Convenience method to call Ollama API without streaming responses. + * + * Uses {@link #generateWithImageURLs(String, String, List, Options, OllamaStreamHandler)} + */ + public OllamaResult generateWithImageURLs(String model, String prompt, List imageURLs, + Options options) + throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { + return generateWithImageURLs(model, prompt, imageURLs, options, null); } @@ -487,10 +522,19 @@ public class OllamaAPI { return Base64.getEncoder().encodeToString(bytes); } - private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequestModel ollamaRequestModel) + private OllamaResult generateSyncForOllamaRequestModel( + OllamaGenerateRequestModel ollamaRequestModel, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { - OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose); - return requestCaller.callSync(ollamaRequestModel); + OllamaGenerateEndpointCaller requestCaller = + new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose); + OllamaResult result; + if 
(streamHandler != null) { + ollamaRequestModel.setStream(true); + result = requestCaller.call(ollamaRequestModel, streamHandler); + } else { + result = requestCaller.callSync(ollamaRequestModel); + } + return result; } /** diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateStreamObserver.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateStreamObserver.java new file mode 100644 index 0000000..a166bac --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateStreamObserver.java @@ -0,0 +1,31 @@ +package io.github.amithkoujalgi.ollama4j.core.models.generate; + +import java.util.ArrayList; +import java.util.List; + +import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler; + +public class OllamaGenerateStreamObserver { + + private OllamaStreamHandler streamHandler; + + private List responseParts = new ArrayList<>(); + + private String message = ""; + + public OllamaGenerateStreamObserver(OllamaStreamHandler streamHandler) { + this.streamHandler = streamHandler; + } + + public void notify(OllamaGenerateResponseModel currentResponsePart){ + responseParts.add(currentResponsePart); + handleCurrentResponsePart(currentResponsePart); + } + + protected void handleCurrentResponsePart(OllamaGenerateResponseModel currentResponsePart){ + message = message + currentResponsePart.getResponse(); + streamHandler.accept(message); + } + + +} diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java index ba55159..fe7fbec 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java @@ -1,18 +1,25 @@ package 
io.github.amithkoujalgi.ollama4j.core.models.request; +import java.io.IOException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.fasterxml.jackson.core.JsonProcessingException; - +import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler; +import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth; +import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult; import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateResponseModel; +import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateStreamObserver; +import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody; import io.github.amithkoujalgi.ollama4j.core.utils.Utils; public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{ private static final Logger LOG = LoggerFactory.getLogger(OllamaGenerateEndpointCaller.class); + private OllamaGenerateStreamObserver streamObserver; + public OllamaGenerateEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) { super(host, basicAuth, requestTimeoutSeconds, verbose); } @@ -27,6 +34,9 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{ try { OllamaGenerateResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class); responseBuffer.append(ollamaResponseModel.getResponse()); + if(streamObserver != null) { + streamObserver.notify(ollamaResponseModel); + } return ollamaResponseModel.isDone(); } catch (JsonProcessingException e) { LOG.error("Error parsing the Ollama chat response!",e); @@ -34,7 +44,11 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{ } } - + public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler) + throws OllamaBaseException, IOException, InterruptedException { + streamObserver = new OllamaGenerateStreamObserver(streamHandler); + 
return super.callSync(body); + } } diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java b/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java index b4a7e1d..c124f42 100644 --- a/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java +++ b/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java @@ -109,6 +109,32 @@ class TestRealAPIs { } } + @Test + @Order(3) + void testAskModelWithDefaultOptionsStreamed() { + testEndpointReachability(); + try { + + StringBuffer sb = new StringBuffer(""); + + OllamaResult result = ollamaAPI.generate(config.getModel(), + "What is the capital of France? And what's France's connection with Mona Lisa?", + new OptionsBuilder().build(), (s) -> { + LOG.info(s); + String substring = s.substring(sb.toString().length(), s.length()); + LOG.info(substring); + sb.append(substring); + }); + + assertNotNull(result); + assertNotNull(result.getResponse()); + assertFalse(result.getResponse().isEmpty()); + assertEquals(sb.toString().trim(), result.getResponse().trim()); + } catch (IOException | OllamaBaseException | InterruptedException e) { + throw new RuntimeException(e); + } + } + @Test @Order(3) void testAskModelWithOptions() { @@ -262,6 +288,30 @@ class TestRealAPIs { } } + @Test + @Order(3) + void testAskModelWithOptionsAndImageFilesStreamed() { + testEndpointReachability(); + File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg"); + try { + StringBuffer sb = new StringBuffer(""); + + OllamaResult result = ollamaAPI.generateWithImageFiles(config.getImageModel(), + "What is in this image?", List.of(imageFile), new OptionsBuilder().build(), (s) -> { + LOG.info(s); + String substring = s.substring(sb.toString().length(), s.length()); + LOG.info(substring); + sb.append(substring); + }); + assertNotNull(result); + assertNotNull(result.getResponse()); + assertFalse(result.getResponse().isEmpty()); + 
assertEquals(sb.toString().trim(), result.getResponse().trim()); + } catch (IOException | OllamaBaseException | InterruptedException e) { + throw new RuntimeException(e); + } + } + @Test @Order(3) void testAskModelWithOptionsAndImageURLs() { From 1e66bdb07ff8cad48c800d5123d4d8c2d53e8648 Mon Sep 17 00:00:00 2001 From: Markus Klenke Date: Sun, 18 Feb 2024 21:41:35 +0000 Subject: [PATCH 3/5] Adds documentation for streamed generate API call --- docs/docs/apis-ask/ask.md | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/docs/docs/apis-ask/ask.md b/docs/docs/apis-ask/ask.md index f62437c..6d1cf66 100644 --- a/docs/docs/apis-ask/ask.md +++ b/docs/docs/apis-ask/ask.md @@ -41,6 +41,41 @@ You will get a response similar to: > require > natural language understanding and generation capabilities. +## Try asking a question, receiving the answer streamed + +```java +public class Main { + + public static void main(String[] args) { + + String host = "http://localhost:11434/"; + + OllamaAPI ollamaAPI = new OllamaAPI(host); + // define a stream handler (Consumer) + OllamaStreamHandler streamHandler = (s) -> { + System.out.println(s); + }; + + // Should be called using separate thread to gain non-blocking streaming effect. + OllamaResult result = ollamaAPI.generate(config.getModel(), + "What is the capital of France? And what's France's connection with Mona Lisa?", + new OptionsBuilder().build(), streamHandler); + + System.out.println("Full response: " +result.getResponse()); + } +} +``` +You will get a response similar to: + +> The +> The capital +> The capital of +> The capital of France +> The capital of France is +> The capital of France is Paris +> The capital of France is Paris. +> Full response: The capital of France is Paris. + ## Try asking a question from general topics.
```java From 09442d37a317b871a199dca3c8bb664736cf9fbf Mon Sep 17 00:00:00 2001 From: Markus Klenke Date: Sun, 18 Feb 2024 22:53:01 +0000 Subject: [PATCH 4/5] Fixes unmarshalling exception on ModelDetail --- .../ollama4j/core/models/ModelDetail.java | 14 ++++++++++++-- .../ollama4j/integrationtests/TestRealAPIs.java | 14 ++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/ModelDetail.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/ModelDetail.java index fa557c6..e81a20e 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/ModelDetail.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/ModelDetail.java @@ -2,7 +2,8 @@ package io.github.amithkoujalgi.ollama4j.core.models; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; -import java.util.Map; +import com.fasterxml.jackson.core.JsonProcessingException; +import io.github.amithkoujalgi.ollama4j.core.utils.Utils; import lombok.Data; @Data @@ -16,5 +17,14 @@ public class ModelDetail { private String parameters; private String template; private String system; - private Map details; + private ModelMeta details; + + @Override + public String toString() { + try { + return Utils.getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } } diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java b/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java index c124f42..dc91287 100644 --- a/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java +++ b/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java @@ -4,6 +4,7 @@ import static org.junit.jupiter.api.Assertions.*; import 
io.github.amithkoujalgi.ollama4j.core.OllamaAPI; import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; +import io.github.amithkoujalgi.ollama4j.core.models.ModelDetail; import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessageRole; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder; @@ -91,6 +92,19 @@ class TestRealAPIs { } } + @Test + @Order(3) + void testListDtails() { + testEndpointReachability(); + try { + ModelDetail modelDetails = ollamaAPI.getModelDetails(config.getModel()); + assertNotNull(modelDetails); + System.out.println(modelDetails); + } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { + throw new RuntimeException(e); + } + } + @Test @Order(3) void testAskModelWithDefaultOptions() { From 13b7111a425135b6f039531204dbd2f2fc600302 Mon Sep 17 00:00:00 2001 From: Markus Klenke Date: Sun, 18 Feb 2024 22:53:34 +0000 Subject: [PATCH 5/5] Adds toString implementation for Model and ModelMeta to be json represented --- .../amithkoujalgi/ollama4j/core/models/Model.java | 11 +++++++++++ .../amithkoujalgi/ollama4j/core/models/ModelMeta.java | 11 +++++++++++ 2 files changed, 22 insertions(+) diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/Model.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/Model.java index 44a3b87..27fd3e5 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/Model.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/Model.java @@ -1,6 +1,8 @@ package io.github.amithkoujalgi.ollama4j.core.models; import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.core.JsonProcessingException; +import io.github.amithkoujalgi.ollama4j.core.utils.Utils; import lombok.Data; @Data @@ -34,4 +36,13 @@ public class Model { return name.split(":")[1]; } + @Override + public String toString() 
{ + try { + return Utils.getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + } diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/ModelMeta.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/ModelMeta.java index eff4609..e534832 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/ModelMeta.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/ModelMeta.java @@ -2,6 +2,8 @@ package io.github.amithkoujalgi.ollama4j.core.models; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.core.JsonProcessingException; +import io.github.amithkoujalgi.ollama4j.core.utils.Utils; import lombok.Data; @Data @@ -21,4 +23,13 @@ public class ModelMeta { @JsonProperty("quantization_level") private String quantizationLevel; + + @Override + public String toString() { + try { + return Utils.getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } }