diff --git a/src/main/java/io/github/ollama4j/OllamaAPI.java b/src/main/java/io/github/ollama4j/OllamaAPI.java
index cbde59e..77d6e62 100644
--- a/src/main/java/io/github/ollama4j/OllamaAPI.java
+++ b/src/main/java/io/github/ollama4j/OllamaAPI.java
@@ -11,6 +11,7 @@ import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel;
import io.github.ollama4j.models.generate.OllamaGenerateRequest;
import io.github.ollama4j.models.generate.OllamaStreamHandler;
+import io.github.ollama4j.models.generate.OllamaTokenHandler;
import io.github.ollama4j.models.ps.ModelsProcessResponse;
import io.github.ollama4j.models.request.*;
import io.github.ollama4j.models.response.*;
@@ -785,15 +786,34 @@ public class OllamaAPI {
* @throws InterruptedException if the operation is interrupted
*/
public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
+ return chatStreaming(request, new OllamaChatStreamObserver(streamHandler));
+ }
+
+ /**
+ * Ask a question to a model using an {@link OllamaChatRequest}. This can be constructed using an {@link OllamaChatRequestBuilder}.
+ *
+ * Hint: the OllamaChatRequestModel#getStream() property is not implemented.
+ *
+ * @param request request object to be sent to the server
+ * @param tokenHandler callback handler invoked with every streamed response token ({@link OllamaChatResponseModel}); may be {@code null} to disable streaming and perform a synchronous call
+ * @return {@link OllamaChatResult}
+ * @throws OllamaBaseException if the response indicates an error status
+ *                             (any response code other than 200)
+ * @throws IOException if an I/O error occurs during the HTTP request
+ *                     (e.g. the response stream cannot be read)
+ * @throws InterruptedException if the operation is interrupted
+ *                              (e.g. the server is unreachable or network issues occur)
+ */
+ public OllamaChatResult chatStreaming(OllamaChatRequest request, OllamaTokenHandler tokenHandler) throws OllamaBaseException, IOException, InterruptedException {
OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
OllamaChatResult result;
// add all registered tools to Request
request.setTools(toolRegistry.getRegisteredSpecs().stream().map(Tools.ToolSpecification::getToolPrompt).collect(Collectors.toList()));
- if (streamHandler != null) {
+ if (tokenHandler != null) {
request.setStream(true);
- result = requestCaller.call(request, streamHandler);
+ result = requestCaller.call(request, tokenHandler);
} else {
result = requestCaller.callSync(request);
}
@@ -810,8 +830,8 @@ public class OllamaAPI {
request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL,"[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() +") : " + res + "[/TOOL_RESULTS]"));
}
- if (streamHandler != null) {
- result = requestCaller.call(request, streamHandler);
+ if (tokenHandler != null) {
+ result = requestCaller.call(request, tokenHandler);
} else {
result = requestCaller.callSync(request);
}
diff --git a/src/main/java/io/github/ollama4j/models/chat/OllamaChatStreamObserver.java b/src/main/java/io/github/ollama4j/models/chat/OllamaChatStreamObserver.java
index 9f1bf7f..af181da 100644
--- a/src/main/java/io/github/ollama4j/models/chat/OllamaChatStreamObserver.java
+++ b/src/main/java/io/github/ollama4j/models/chat/OllamaChatStreamObserver.java
@@ -1,31 +1,19 @@
package io.github.ollama4j.models.chat;
import io.github.ollama4j.models.generate.OllamaStreamHandler;
+import io.github.ollama4j.models.generate.OllamaTokenHandler;
+import lombok.RequiredArgsConstructor;
-import java.util.ArrayList;
-import java.util.List;
-
-public class OllamaChatStreamObserver {
-
- private OllamaStreamHandler streamHandler;
-
- private List<OllamaChatResponseModel> responseParts = new ArrayList<>();
-
+@RequiredArgsConstructor
+public class OllamaChatStreamObserver implements OllamaTokenHandler {
+ private final OllamaStreamHandler streamHandler;
private String message = "";
- public OllamaChatStreamObserver(OllamaStreamHandler streamHandler) {
- this.streamHandler = streamHandler;
+ @Override
+ public void accept(OllamaChatResponseModel token) {
+ if (streamHandler != null) {
+ message += token.getMessage().getContent();
+ streamHandler.accept(message);
+ }
}
-
- public void notify(OllamaChatResponseModel currentResponsePart) {
- responseParts.add(currentResponsePart);
- handleCurrentResponsePart(currentResponsePart);
- }
-
- protected void handleCurrentResponsePart(OllamaChatResponseModel currentResponsePart) {
- message = message + currentResponsePart.getMessage().getContent();
- streamHandler.accept(message);
- }
-
-
}
diff --git a/src/main/java/io/github/ollama4j/models/generate/OllamaTokenHandler.java b/src/main/java/io/github/ollama4j/models/generate/OllamaTokenHandler.java
new file mode 100644
index 0000000..a0aed8c
--- /dev/null
+++ b/src/main/java/io/github/ollama4j/models/generate/OllamaTokenHandler.java
@@ -0,0 +1,8 @@
+package io.github.ollama4j.models.generate;
+
+import io.github.ollama4j.models.chat.OllamaChatResponseModel;
+
+import java.util.function.Consumer;
+
+public interface OllamaTokenHandler extends Consumer<OllamaChatResponseModel> {
+}
diff --git a/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java b/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java
index 57c9ee3..a1a6216 100644
--- a/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java
+++ b/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java
@@ -4,9 +4,8 @@ import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import io.github.ollama4j.exceptions.OllamaBaseException;
import io.github.ollama4j.models.chat.*;
+import io.github.ollama4j.models.generate.OllamaTokenHandler;
import io.github.ollama4j.models.response.OllamaErrorResponse;
-import io.github.ollama4j.models.generate.OllamaStreamHandler;
-import io.github.ollama4j.tools.Tools;
import io.github.ollama4j.utils.Utils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -29,7 +28,7 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
private static final Logger LOG = LoggerFactory.getLogger(OllamaChatEndpointCaller.class);
- private OllamaChatStreamObserver streamObserver;
+ private OllamaTokenHandler tokenHandler;
public OllamaChatEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) {
super(host, basicAuth, requestTimeoutSeconds, verbose);
@@ -60,8 +59,8 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
OllamaChatMessage message = ollamaResponseModel.getMessage();
if(message != null) {
responseBuffer.append(message.getContent());
- if (streamObserver != null) {
- streamObserver.notify(ollamaResponseModel);
+ if (tokenHandler != null) {
+ tokenHandler.accept(ollamaResponseModel);
}
}
return ollamaResponseModel.isDone();
@@ -71,9 +70,9 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
}
}
- public OllamaChatResult call(OllamaChatRequest body, OllamaStreamHandler streamHandler)
+ public OllamaChatResult call(OllamaChatRequest body, OllamaTokenHandler tokenHandler)
throws OllamaBaseException, IOException, InterruptedException {
- streamObserver = new OllamaChatStreamObserver(streamHandler);
+ this.tokenHandler = tokenHandler;
return callSync(body);
}
@@ -86,7 +85,7 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
.POST(
body.getBodyPublisher());
HttpRequest request = requestBuilder.build();
- if (isVerbose()) LOG.info("Asking model: " + body.toString());
+ if (isVerbose()) LOG.info("Asking model: " + body);
HttpResponse<InputStream> response =
httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
diff --git a/src/test/java/io/github/ollama4j/integrationtests/TestRealAPIs.java b/src/test/java/io/github/ollama4j/integrationtests/TestRealAPIs.java
index 835fa76..a64fb70 100644
--- a/src/test/java/io/github/ollama4j/integrationtests/TestRealAPIs.java
+++ b/src/test/java/io/github/ollama4j/integrationtests/TestRealAPIs.java
@@ -321,7 +321,7 @@ class TestRealAPIs {
assertEquals(1, function.getArguments().size());
Object noOfDigits = function.getArguments().get("noOfDigits");
assertNotNull(noOfDigits);
- assertEquals("5",noOfDigits);
+ assertEquals("5", noOfDigits.toString());
assertTrue(chatResult.getChatHistory().size()>2);
List finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls();
assertNull(finalToolCalls);