From abb76ad8671a293fe87e443ac7099489e9506540 Mon Sep 17 00:00:00 2001 From: Markus Klenke Date: Fri, 16 Feb 2024 17:03:15 +0000 Subject: [PATCH] Adds streaming feature to Generate APIs --- .../ollama4j/core/OllamaAPI.java | 62 ++++++++++++++++--- .../OllamaGenerateStreamObserver.java | 31 ++++++++++ .../request/OllamaGenerateEndpointCaller.java | 18 +++++- .../integrationtests/TestRealAPIs.java | 50 +++++++++++++++ 4 files changed, 150 insertions(+), 11 deletions(-) create mode 100644 src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateStreamObserver.java diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java index 20e9d3e..ec772f1 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java @@ -342,13 +342,24 @@ public class OllamaAPI { * @param options the Options object - More * details on the options + * @param streamHandler optional callback consumer that will be applied every time a streamed response is received. If not set, the stream parameter of the request is set to false. * @return OllamaResult that includes response text and time taken for response */ - public OllamaResult generate(String model, String prompt, Options options) + public OllamaResult generate(String model, String prompt, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt); ollamaRequestModel.setOptions(options.getOptionsMap()); - return generateSyncForOllamaRequestModel(ollamaRequestModel); + return generateSyncForOllamaRequestModel(ollamaRequestModel,streamHandler); + } + + /** + * Convenience method to call Ollama API without streaming responses. + * + * Uses {@link #generate(String, String, Options, OllamaStreamHandler)} + */ + public OllamaResult generate(String model, String prompt, Options options) + throws OllamaBaseException, IOException, InterruptedException { + return generate(model, prompt, options,null); } /** @@ -381,10 +392,11 @@ public class OllamaAPI { * @param options the Options object - More * details on the options + * @param streamHandler optional callback consumer that will be applied every time a streamed response is received. If not set, the stream parameter of the request is set to false. * @return OllamaResult that includes response text and time taken for response */ public OllamaResult generateWithImageFiles( - String model, String prompt, List imageFiles, Options options) + String model, String prompt, List imageFiles, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { List images = new ArrayList<>(); for (File imageFile : imageFiles) { @@ -392,9 +404,20 @@ public class OllamaAPI { } OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt, images); ollamaRequestModel.setOptions(options.getOptionsMap()); - return generateSyncForOllamaRequestModel(ollamaRequestModel); + return generateSyncForOllamaRequestModel(ollamaRequestModel,streamHandler); } + /** + * Convenience method to call Ollama API without streaming responses. + * + * Uses {@link #generateWithImageFiles(String, String, List, Options, OllamaStreamHandler)} + */ + public OllamaResult generateWithImageFiles( + String model, String prompt, List imageFiles, Options options) + throws OllamaBaseException, IOException, InterruptedException{ + return generateWithImageFiles(model, prompt, imageFiles, options, null); +} + /** * With one or more image URLs, ask a question to a model running on Ollama server. This is a * sync/blocking call. @@ -405,10 +428,11 @@ public class OllamaAPI { * @param options the Options object - More * details on the options + * @param streamHandler optional callback consumer that will be applied every time a streamed response is received. If not set, the stream parameter of the request is set to false. * @return OllamaResult that includes response text and time taken for response */ public OllamaResult generateWithImageURLs( - String model, String prompt, List imageURLs, Options options) + String model, String prompt, List imageURLs, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { List images = new ArrayList<>(); for (String imageURL : imageURLs) { @@ -416,7 +440,18 @@ public class OllamaAPI { } OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt, images); ollamaRequestModel.setOptions(options.getOptionsMap()); - return generateSyncForOllamaRequestModel(ollamaRequestModel); + return generateSyncForOllamaRequestModel(ollamaRequestModel,streamHandler); + } + + /** + * Convenience method to call Ollama API without streaming responses. + * + * Uses {@link #generateWithImageURLs(String, String, List, Options, OllamaStreamHandler)} + */ + public OllamaResult generateWithImageURLs(String model, String prompt, List imageURLs, + Options options) + throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { + return generateWithImageURLs(model, prompt, imageURLs, options, null); } @@ -487,10 +522,19 @@ public class OllamaAPI { return Base64.getEncoder().encodeToString(bytes); } - private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequestModel ollamaRequestModel) + private OllamaResult generateSyncForOllamaRequestModel( + OllamaGenerateRequestModel ollamaRequestModel, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { - OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose); - return requestCaller.callSync(ollamaRequestModel); + OllamaGenerateEndpointCaller requestCaller = + new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose); + OllamaResult result; + if (streamHandler != null) { + ollamaRequestModel.setStream(true); + result = requestCaller.call(ollamaRequestModel, streamHandler); + } else { + result = requestCaller.callSync(ollamaRequestModel); + } + return result; } /** diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateStreamObserver.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateStreamObserver.java new file mode 100644 index 0000000..a166bac --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateStreamObserver.java @@ -0,0 +1,31 @@ +package io.github.amithkoujalgi.ollama4j.core.models.generate; + +import java.util.ArrayList; +import java.util.List; + +import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler; + +public class OllamaGenerateStreamObserver { + + private OllamaStreamHandler streamHandler; + + private List responseParts = new ArrayList<>(); + + private String message = ""; + + public OllamaGenerateStreamObserver(OllamaStreamHandler streamHandler) { + this.streamHandler = streamHandler; + } + + public void notify(OllamaGenerateResponseModel currentResponsePart){ + responseParts.add(currentResponsePart); + handleCurrentResponsePart(currentResponsePart); + } + + protected void handleCurrentResponsePart(OllamaGenerateResponseModel currentResponsePart){ + message = message + currentResponsePart.getResponse(); + streamHandler.accept(message); + } + + +} diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java index ba55159..fe7fbec 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java @@ -1,18 +1,25 @@ package io.github.amithkoujalgi.ollama4j.core.models.request; +import java.io.IOException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.fasterxml.jackson.core.JsonProcessingException; - +import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler; +import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth; +import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult; import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateResponseModel; +import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateStreamObserver; +import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody; import io.github.amithkoujalgi.ollama4j.core.utils.Utils; public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{ private static final Logger LOG = LoggerFactory.getLogger(OllamaGenerateEndpointCaller.class); + private OllamaGenerateStreamObserver streamObserver; + public OllamaGenerateEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) { super(host, basicAuth, requestTimeoutSeconds, verbose); } @@ -27,6 +34,9 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{ try { OllamaGenerateResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class); responseBuffer.append(ollamaResponseModel.getResponse()); + if(streamObserver != null) { + streamObserver.notify(ollamaResponseModel); + } return ollamaResponseModel.isDone(); } catch (JsonProcessingException e) { LOG.error("Error parsing the Ollama chat response!",e); @@ -34,7 +44,11 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{ } } - + public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler) + throws OllamaBaseException, IOException, InterruptedException { + streamObserver = new OllamaGenerateStreamObserver(streamHandler); + return super.callSync(body); + } } diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java b/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java index b4a7e1d..c124f42 100644 --- a/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java +++ b/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java @@ -109,6 +109,32 @@ class TestRealAPIs { } } + @Test + @Order(3) + void testAskModelWithDefaultOptionsStreamed() { + testEndpointReachability(); + try { + + StringBuffer sb = new StringBuffer(""); + + OllamaResult result = ollamaAPI.generate(config.getModel(), + "What is the capital of France? And what's France's connection with Mona Lisa?", + new OptionsBuilder().build(), (s) -> { + LOG.info(s); + String substring = s.substring(sb.toString().length(), s.length()); + LOG.info(substring); + sb.append(substring); + }); + + assertNotNull(result); + assertNotNull(result.getResponse()); + assertFalse(result.getResponse().isEmpty()); + assertEquals(sb.toString().trim(), result.getResponse().trim()); + } catch (IOException | OllamaBaseException | InterruptedException e) { + throw new RuntimeException(e); + } + } + @Test @Order(3) void testAskModelWithOptions() { @@ -262,6 +288,30 @@ class TestRealAPIs { } } + @Test + @Order(3) + void testAskModelWithOptionsAndImageFilesStreamed() { + testEndpointReachability(); + File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg"); + try { + StringBuffer sb = new StringBuffer(""); + + OllamaResult result = ollamaAPI.generateWithImageFiles(config.getImageModel(), + "What is in this image?", List.of(imageFile), new OptionsBuilder().build(), (s) -> { + LOG.info(s); + String substring = s.substring(sb.toString().length(), s.length()); + LOG.info(substring); + sb.append(substring); + }); + assertNotNull(result); + assertNotNull(result.getResponse()); + assertFalse(result.getResponse().isEmpty()); + assertEquals(sb.toString().trim(), result.getResponse().trim()); + } catch (IOException | OllamaBaseException | InterruptedException e) { + throw new RuntimeException(e); + } + } + @Test @Order(3) void testAskModelWithOptionsAndImageURLs() {