From c89440cbca1c048892a6249cecce6b21c6f31cde Mon Sep 17 00:00:00 2001
From: Markus Klenke
Date: Tue, 13 Feb 2024 17:56:07 +0000
Subject: [PATCH 1/3] Adds OllamaStream handling

---
 .../ollama4j/core/OllamaStreamHandler.java    |  7 ++++
 .../models/chat/OllamaChatStreamObserver.java | 34 +++++++++++++++++++
 2 files changed, 41 insertions(+)
 create mode 100644 src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaStreamHandler.java
 create mode 100644 src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatStreamObserver.java

diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaStreamHandler.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaStreamHandler.java
new file mode 100644
index 0000000..803f393
--- /dev/null
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaStreamHandler.java
@@ -0,0 +1,7 @@
+package io.github.amithkoujalgi.ollama4j.core;
+
+import java.util.function.Consumer;
+
+public interface OllamaStreamHandler extends Consumer<String> {
+    void accept(String message);
+}
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatStreamObserver.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatStreamObserver.java
new file mode 100644
index 0000000..6a782f4
--- /dev/null
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatStreamObserver.java
@@ -0,0 +1,34 @@
+package io.github.amithkoujalgi.ollama4j.core.models.chat;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.stream.Collectors;
+
+import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler;
+import lombok.NonNull;
+
+public class OllamaChatStreamObserver {
+
+    private OllamaStreamHandler streamHandler;
+
+    private List<OllamaChatResponseModel> responseParts = new ArrayList<>();
+
+    private String message;
+
+    public OllamaChatStreamObserver(OllamaStreamHandler streamHandler) {
+        this.streamHandler = streamHandler;
+    }
+
+    public void notify(OllamaChatResponseModel currentResponsePart){
+        responseParts.add(currentResponsePart);
+        handleCurrentResponsePart(currentResponsePart);
+    }
+
+    protected void handleCurrentResponsePart(OllamaChatResponseModel currentResponsePart){
+        List<@NonNull String> allResponsePartsByNow = responseParts.stream().map(r -> r.getMessage().getContent()).collect(Collectors.toList());
+        message = String.join("", allResponsePartsByNow);
+        streamHandler.accept(message);
+    }
+
+
+}

From b41b62220c4cea8a68830fb4215bcf580b829fc7 Mon Sep 17 00:00:00 2001
From: Markus Klenke
Date: Tue, 13 Feb 2024 17:59:27 +0000
Subject: [PATCH 2/3] Adds chat with stream functionality in OllamaAPI

---
 .../ollama4j/core/OllamaAPI.java              | 29 ++++++++++++---
 .../request/OllamaChatEndpointCaller.java     | 36 +++++++++++++------
 .../models/request/OllamaEndpointCaller.java  |  4 +--
 3 files changed, 52 insertions(+), 17 deletions(-)

diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java
index 46fd5fb..e48add1 100644
--- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java
@@ -448,12 +448,31 @@ public class OllamaAPI {
      * @throws InterruptedException in case the server is not reachable or network issues happen
      */
     public OllamaChatResult chat(OllamaChatRequestModel request) throws OllamaBaseException, IOException, InterruptedException{
+        return chat(request, null);
+    }
+
+    /**
+     * Ask a question to a model using an {@link OllamaChatRequestModel}. This can be constructed using an {@link OllamaChatRequestBuilder}.
+     *
+     * Hint: the OllamaChatRequestModel#getStream() property is not implemented.
+     *
+     * @param request request object to be sent to the server
+     * @param streamHandler callback handler to handle the last message from stream (caution: all previous messages from stream will be concatenated)
+     * @return the {@link OllamaChatResult} holding the model response and the chat history
+     * @throws OllamaBaseException any response code other than 200 has been returned
+     * @throws IOException in case the responseStream can not be read
+     * @throws InterruptedException in case the server is not reachable or network issues happen
+     */
+    public OllamaChatResult chat(OllamaChatRequestModel request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException{
         OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
-        //TODO: implement async way
-        if(request.isStream()){
-            throw new UnsupportedOperationException("Streamed chat responses are not implemented yet");
+        OllamaResult result;
+        if(streamHandler != null){
+            request.setStream(true);
+            result = requestCaller.call(request, streamHandler);
+        }
+        else {
+            result = requestCaller.callSync(request);
         }
-        OllamaResult result = requestCaller.generateSync(request);
         return new OllamaChatResult(result.getResponse(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages());
     }
 
@@ -470,7 +489,7 @@ public class OllamaAPI {
     private OllamaResult generateSyncForOllamaRequestModel(OllamaRequestModel ollamaRequestModel) throws OllamaBaseException, IOException, InterruptedException {
         OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
-        return requestCaller.generateSync(ollamaRequestModel);
+        return requestCaller.callSync(ollamaRequestModel);
     }
 
     /**
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaChatEndpointCaller.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaChatEndpointCaller.java
index eb06c37..811ef11 100644
--- 
a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaChatEndpointCaller.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaChatEndpointCaller.java @@ -1,12 +1,19 @@ package io.github.amithkoujalgi.ollama4j.core.models.request; +import java.io.IOException; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.fasterxml.jackson.core.JsonProcessingException; +import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler; +import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth; +import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResponseModel; +import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatStreamObserver; +import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody; import io.github.amithkoujalgi.ollama4j.core.utils.Utils; /** @@ -16,6 +23,8 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller{ private static final Logger LOG = LoggerFactory.getLogger(OllamaChatEndpointCaller.class); + private OllamaChatStreamObserver streamObserver; + public OllamaChatEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) { super(host, basicAuth, requestTimeoutSeconds, verbose); } @@ -27,18 +36,25 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller{ @Override protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) { - try { - OllamaChatResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class); - responseBuffer.append(ollamaResponseModel.getMessage().getContent()); - return ollamaResponseModel.isDone(); - } catch (JsonProcessingException e) { - LOG.error("Error parsing the Ollama chat response!",e); - return true; - } + try { + OllamaChatResponseModel 
ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class); + responseBuffer.append(ollamaResponseModel.getMessage().getContent()); + if(streamObserver != null) { + streamObserver.notify(ollamaResponseModel); + } + return ollamaResponseModel.isDone(); + } catch (JsonProcessingException e) { + LOG.error("Error parsing the Ollama chat response!",e); + return true; + } } - + public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler) + throws OllamaBaseException, IOException, InterruptedException { + streamObserver = new OllamaChatStreamObserver(streamHandler); + return super.callSync(body); + } - + } diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaEndpointCaller.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaEndpointCaller.java index 93b2b2f..ad8d5bb 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaEndpointCaller.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaEndpointCaller.java @@ -46,7 +46,7 @@ public abstract class OllamaEndpointCaller { protected abstract boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer); - + /** * Calls the api server on the given host and endpoint suffix asynchronously, aka waiting for the response. 
* @@ -56,7 +56,7 @@ public abstract class OllamaEndpointCaller { * @throws IOException in case the responseStream can not be read * @throws InterruptedException in case the server is not reachable or network issues happen */ - public OllamaResult generateSync(OllamaRequestBody body) throws OllamaBaseException, IOException, InterruptedException{ + public OllamaResult callSync(OllamaRequestBody body) throws OllamaBaseException, IOException, InterruptedException{ // Create Request long startTime = System.currentTimeMillis(); From e9621f054dc56007a08a84faed7be243d4e8c4c5 Mon Sep 17 00:00:00 2001 From: Markus Klenke Date: Tue, 13 Feb 2024 18:11:59 +0000 Subject: [PATCH 3/3] Adds integration test for chat streaming API --- .../integrationtests/TestRealAPIs.java | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java b/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java index ed5c862..870e17f 100644 --- a/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java +++ b/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java @@ -23,8 +23,13 @@ import lombok.Data; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Order; import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; class TestRealAPIs { + + private static final Logger LOG = LoggerFactory.getLogger(TestRealAPIs.class); + OllamaAPI ollamaAPI; Config config; @@ -164,6 +169,31 @@ class TestRealAPIs { } } + @Test + @Order(3) + void testChatWithStream() { + testEndpointReachability(); + try { + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel()); + OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER, + "What is the capital of France? 
And what's France's connection with Mona Lisa?")
+                    .build();
+
+            StringBuffer sb = new StringBuffer("");
+
+            OllamaChatResult chatResult = ollamaAPI.chat(requestModel,(s) -> {
+                LOG.info(s);
+                String substring = s.substring(sb.toString().length());
+                LOG.info(substring);
+                sb.append(substring);
+            });
+            assertNotNull(chatResult);
+            assertEquals(sb.toString().trim(), chatResult.getResponse().trim());
+        } catch (IOException | OllamaBaseException | InterruptedException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
     @Test
     @Order(3)
     void testChatWithImageFromFileWithHistoryRecognition() {