diff --git a/src/main/java/io/github/ollama4j/OllamaAPI.java b/src/main/java/io/github/ollama4j/OllamaAPI.java
index 29cb467..f570992 100644
--- a/src/main/java/io/github/ollama4j/OllamaAPI.java
+++ b/src/main/java/io/github/ollama4j/OllamaAPI.java
@@ -11,6 +11,7 @@ import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
 import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel;
 import io.github.ollama4j.models.generate.OllamaGenerateRequest;
 import io.github.ollama4j.models.generate.OllamaStreamHandler;
+import io.github.ollama4j.models.generate.OllamaTokenHandler;
 import io.github.ollama4j.models.ps.ModelsProcessResponse;
 import io.github.ollama4j.models.request.*;
 import io.github.ollama4j.models.response.*;
@@ -785,15 +786,34 @@ public class OllamaAPI {
      * @throws InterruptedException if the operation is interrupted
      */
     public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
+        return chatStreaming(request, new OllamaChatStreamObserver(streamHandler));
+    }
+
+    /**
+     * Ask a question to a model using an {@link OllamaChatRequest}. This can be constructed using an {@link OllamaChatRequestBuilder}.
+     * <p>
+     * Hint: the OllamaChatRequestModel#getStream() property is not implemented.
+     *
+     * @param request      request object to be sent to the server
+     * @param tokenHandler callback handler invoked for every token received from the stream (caution: the handler receives each raw token; previous tokens from the stream are not concatenated)
+     * @return {@link OllamaChatResult}
+     * @throws OllamaBaseException  if the response indicates an error status (any response code other than 200)
+     * @throws IOException          if an I/O error occurs while reading the response stream
+     * @throws InterruptedException if the operation is interrupted, e.g. because the server is unreachable
+     */
+    public OllamaChatResult chatStreaming(OllamaChatRequest request, OllamaTokenHandler tokenHandler) throws OllamaBaseException, IOException, InterruptedException {
         OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
         OllamaChatResult result;
         // add all registered tools to Request
         request.setTools(toolRegistry.getRegisteredSpecs().stream().map(Tools.ToolSpecification::getToolPrompt).collect(Collectors.toList()));
-        if (streamHandler != null) {
+        if (tokenHandler != null) {
             request.setStream(true);
-            result = requestCaller.call(request, streamHandler);
+            result = requestCaller.call(request, tokenHandler);
         } else {
             result = requestCaller.callSync(request);
         }
@@ -810,8 +830,8 @@ public class OllamaAPI {
                 request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL, "[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() + ") : " + res + "[/TOOL_RESULTS]"));
             }
-            if (streamHandler != null) {
-                result = requestCaller.call(request, streamHandler);
+            if (tokenHandler != null) {
+                result = requestCaller.call(request, tokenHandler);
             } else {
                 result = requestCaller.callSync(request);
             }
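
A quick usage sketch of the new `chatStreaming` entry point (not part of the patch; the host URL, model name, and prompt are placeholders, and it assumes `OllamaTokenHandler` is a single-method callback over streamed `OllamaChatResponseModel` tokens, as the `requestCaller.call(request, tokenHandler)` signature above suggests):

```java
import io.github.ollama4j.OllamaAPI;
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
import io.github.ollama4j.models.chat.OllamaChatRequest;
import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
import io.github.ollama4j.models.chat.OllamaChatResult;

public class ChatStreamingExample {
    public static void main(String[] args) throws Exception {
        // Host and model are placeholders for illustration.
        OllamaAPI api = new OllamaAPI("http://localhost:11434");

        OllamaChatRequest request = OllamaChatRequestBuilder
                .getInstance("llama3")
                .withMessage(OllamaChatMessageRole.USER, "Why is the sky blue?")
                .build();

        // The token handler receives each raw token; unlike chat(request, streamHandler),
        // nothing is concatenated on the caller's behalf.
        OllamaChatResult result = api.chatStreaming(request,
                token -> System.out.print(token.getMessage().getContent()));
        System.out.println();
    }
}
```

Note that the old `chat(request, streamHandler)` overload keeps its accumulated-message behavior by wrapping the stream handler in the refactored `OllamaChatStreamObserver`, so existing callers are unaffected.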
diff --git a/src/main/java/io/github/ollama4j/models/chat/OllamaChatStreamObserver.java b/src/main/java/io/github/ollama4j/models/chat/OllamaChatStreamObserver.java
index 9f1bf7f..af181da 100644
--- a/src/main/java/io/github/ollama4j/models/chat/OllamaChatStreamObserver.java
+++ b/src/main/java/io/github/ollama4j/models/chat/OllamaChatStreamObserver.java
@@ -1,31 +1,19 @@
 package io.github.ollama4j.models.chat;
 
 import io.github.ollama4j.models.generate.OllamaStreamHandler;
+import io.github.ollama4j.models.generate.OllamaTokenHandler;
+import lombok.RequiredArgsConstructor;
-import java.util.ArrayList;
-import java.util.List;
-
-public class OllamaChatStreamObserver {
-
-    private OllamaStreamHandler streamHandler;
-
-    private List<OllamaChatResponseModel> responseParts = new ArrayList<>();
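
The hunk is truncated here; `@@ -1,31 +1,19 @@` promises more lines than survive above. Judging from the import changes (Lombok's `@RequiredArgsConstructor` and `OllamaTokenHandler` added, `ArrayList`/`List` removed), the observer plausibly becomes a thin adapter that accumulates token content for the legacy `OllamaStreamHandler`. A hedged reconstruction of the new class, assuming `OllamaTokenHandler` exposes a single `accept(OllamaChatResponseModel)` method, not the verbatim patch:

```java
package io.github.ollama4j.models.chat;

import io.github.ollama4j.models.generate.OllamaStreamHandler;
import io.github.ollama4j.models.generate.OllamaTokenHandler;
import lombok.RequiredArgsConstructor;

// Plausible shape of the refactored class: it implements the new per-token
// callback and adapts it to the legacy stream handler, which expects the
// message accumulated so far.
@RequiredArgsConstructor
public class OllamaChatStreamObserver implements OllamaTokenHandler {
    private final OllamaStreamHandler streamHandler;
    private String message = "";

    @Override
    public void accept(OllamaChatResponseModel token) {
        if (streamHandler != null) {
            message += token.getMessage().getContent();
            streamHandler.accept(message);
        }
    }
}
```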