diff --git a/src/main/java/io/github/ollama4j/OllamaAPI.java b/src/main/java/io/github/ollama4j/OllamaAPI.java index cbde59e..77d6e62 100644 --- a/src/main/java/io/github/ollama4j/OllamaAPI.java +++ b/src/main/java/io/github/ollama4j/OllamaAPI.java @@ -11,6 +11,7 @@ import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel; import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel; import io.github.ollama4j.models.generate.OllamaGenerateRequest; import io.github.ollama4j.models.generate.OllamaStreamHandler; +import io.github.ollama4j.models.generate.OllamaTokenHandler; import io.github.ollama4j.models.ps.ModelsProcessResponse; import io.github.ollama4j.models.request.*; import io.github.ollama4j.models.response.*; @@ -785,15 +786,34 @@ public class OllamaAPI { * @throws InterruptedException if the operation is interrupted */ public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { + return chatStreaming(request, new OllamaChatStreamObserver(streamHandler)); + } + + /** + * Ask a question to a model using an {@link OllamaChatRequest}. This can be constructed using an {@link OllamaChatRequestBuilder}. + *

+ * Hint: the OllamaChatRequestModel#getStream() property is not implemented. + * + * @param request request object to be sent to the server + * @param tokenHandler callback handler to handle the last token from stream (caution: all previous messages from stream will be concatenated) + * @return {@link OllamaChatResult} + * @throws OllamaBaseException if the response indicates an error status (any response code other than 200 has been returned) + * @throws IOException if an I/O error occurs during the HTTP request or the response stream cannot be read + * @throws InterruptedException if the operation is interrupted, the server is not reachable, or network issues happen + * + * + * + */ + public OllamaChatResult chatStreaming(OllamaChatRequest request, OllamaTokenHandler tokenHandler) throws OllamaBaseException, IOException, InterruptedException { OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose); OllamaChatResult result; // add all registered tools to Request request.setTools(toolRegistry.getRegisteredSpecs().stream().map(Tools.ToolSpecification::getToolPrompt).collect(Collectors.toList())); - if (streamHandler != null) { + if (tokenHandler != null) { request.setStream(true); - result = requestCaller.call(request, streamHandler); + result = requestCaller.call(request, tokenHandler); } else { result = requestCaller.callSync(request); } @@ -810,8 +830,8 @@ public class OllamaAPI { request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL,"[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() +") : " + res + "[/TOOL_RESULTS]")); } - if (streamHandler != null) { - result = requestCaller.call(request, streamHandler); + if (tokenHandler != null) { + result = requestCaller.call(request, tokenHandler); } else { result = requestCaller.callSync(request); } diff --git 
a/src/main/java/io/github/ollama4j/models/chat/OllamaChatStreamObserver.java b/src/main/java/io/github/ollama4j/models/chat/OllamaChatStreamObserver.java index 9f1bf7f..af181da 100644 --- a/src/main/java/io/github/ollama4j/models/chat/OllamaChatStreamObserver.java +++ b/src/main/java/io/github/ollama4j/models/chat/OllamaChatStreamObserver.java @@ -1,31 +1,19 @@ package io.github.ollama4j.models.chat; import io.github.ollama4j.models.generate.OllamaStreamHandler; +import io.github.ollama4j.models.generate.OllamaTokenHandler; +import lombok.RequiredArgsConstructor; -import java.util.ArrayList; -import java.util.List; - -public class OllamaChatStreamObserver { - - private OllamaStreamHandler streamHandler; - - private List responseParts = new ArrayList<>(); - +@RequiredArgsConstructor +public class OllamaChatStreamObserver implements OllamaTokenHandler { + private final OllamaStreamHandler streamHandler; private String message = ""; - public OllamaChatStreamObserver(OllamaStreamHandler streamHandler) { - this.streamHandler = streamHandler; + @Override + public void accept(OllamaChatResponseModel token) { + if (streamHandler != null) { + message += token.getMessage().getContent(); + streamHandler.accept(message); + } } - - public void notify(OllamaChatResponseModel currentResponsePart) { - responseParts.add(currentResponsePart); - handleCurrentResponsePart(currentResponsePart); - } - - protected void handleCurrentResponsePart(OllamaChatResponseModel currentResponsePart) { - message = message + currentResponsePart.getMessage().getContent(); - streamHandler.accept(message); - } - - } diff --git a/src/main/java/io/github/ollama4j/models/generate/OllamaTokenHandler.java b/src/main/java/io/github/ollama4j/models/generate/OllamaTokenHandler.java new file mode 100644 index 0000000..a0aed8c --- /dev/null +++ b/src/main/java/io/github/ollama4j/models/generate/OllamaTokenHandler.java @@ -0,0 +1,8 @@ +package io.github.ollama4j.models.generate; + +import 
io.github.ollama4j.models.chat.OllamaChatResponseModel; + +import java.util.function.Consumer; + +public interface OllamaTokenHandler extends Consumer<OllamaChatResponseModel> { +} diff --git a/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java b/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java index 57c9ee3..a1a6216 100644 --- a/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java +++ b/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java @@ -4,9 +4,8 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; import io.github.ollama4j.exceptions.OllamaBaseException; import io.github.ollama4j.models.chat.*; +import io.github.ollama4j.models.generate.OllamaTokenHandler; import io.github.ollama4j.models.response.OllamaErrorResponse; -import io.github.ollama4j.models.generate.OllamaStreamHandler; -import io.github.ollama4j.tools.Tools; import io.github.ollama4j.utils.Utils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -29,7 +28,7 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller { private static final Logger LOG = LoggerFactory.getLogger(OllamaChatEndpointCaller.class); - private OllamaChatStreamObserver streamObserver; + private OllamaTokenHandler tokenHandler; public OllamaChatEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) { super(host, basicAuth, requestTimeoutSeconds, verbose); @@ -60,8 +59,8 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller { OllamaChatMessage message = ollamaResponseModel.getMessage(); if(message != null) { responseBuffer.append(message.getContent()); - if (streamObserver != null) { - streamObserver.notify(ollamaResponseModel); + if (tokenHandler != null) { + tokenHandler.accept(ollamaResponseModel); } } return ollamaResponseModel.isDone(); @@ -71,9 +70,9 @@ public class OllamaChatEndpointCaller extends 
OllamaEndpointCaller { } } - public OllamaChatResult call(OllamaChatRequest body, OllamaStreamHandler streamHandler) + public OllamaChatResult call(OllamaChatRequest body, OllamaTokenHandler tokenHandler) throws OllamaBaseException, IOException, InterruptedException { - streamObserver = new OllamaChatStreamObserver(streamHandler); + this.tokenHandler = tokenHandler; return callSync(body); } @@ -86,7 +85,7 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller { .POST( body.getBodyPublisher()); HttpRequest request = requestBuilder.build(); - if (isVerbose()) LOG.info("Asking model: " + body.toString()); + if (isVerbose()) LOG.info("Asking model: " + body); HttpResponse response = httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream()); diff --git a/src/test/java/io/github/ollama4j/integrationtests/TestRealAPIs.java b/src/test/java/io/github/ollama4j/integrationtests/TestRealAPIs.java index 835fa76..a64fb70 100644 --- a/src/test/java/io/github/ollama4j/integrationtests/TestRealAPIs.java +++ b/src/test/java/io/github/ollama4j/integrationtests/TestRealAPIs.java @@ -321,7 +321,7 @@ class TestRealAPIs { assertEquals(1, function.getArguments().size()); Object noOfDigits = function.getArguments().get("noOfDigits"); assertNotNull(noOfDigits); - assertEquals("5",noOfDigits); + assertEquals("5", noOfDigits.toString()); assertTrue(chatResult.getChatHistory().size()>2); List finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls(); assertNull(finalToolCalls);