diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java
index 46fd5fb..e48add1 100644
--- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java
@@ -448,12 +448,31 @@ public class OllamaAPI {
      * @throws InterruptedException in case the server is not reachable or network issues happen
      */
     public OllamaChatResult chat(OllamaChatRequestModel request) throws OllamaBaseException, IOException, InterruptedException{
+        return chat(request, null);
+    }
+
+    /**
+     * Ask a question to a model using an {@link OllamaChatRequestModel}. This can be constructed using an {@link OllamaChatRequestBuilder}.
+     *
+     * Hint: the {@link OllamaChatRequestModel#getStream()} property is ignored here; streaming is enabled whenever a streamHandler is provided.
+     *
+     * @param request request object to be sent to the server
+     * @param streamHandler callback invoked for each streamed response part (caution: it receives all content streamed so far, concatenated into one message)
+     * @return the {@link OllamaChatResult} containing the final response and the chat history
+     * @throws OllamaBaseException if any response code other than 200 has been returned
+     * @throws IOException in case the responseStream cannot be read
+     * @throws InterruptedException in case the server is not reachable or network issues happen
+     */
+    public OllamaChatResult chat(OllamaChatRequestModel request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException{
         OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
-        //TODO: implement async way
-        if(request.isStream()){
-            throw new UnsupportedOperationException("Streamed chat responses are not implemented yet");
+        OllamaResult result;
+        if(streamHandler != null){
+            request.setStream(true);
+            result = requestCaller.call(request, streamHandler);
+        }
+        else {
+            result = requestCaller.callSync(request);
         }
-        OllamaResult result = requestCaller.generateSync(request);
         return new OllamaChatResult(result.getResponse(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages());
     }
 
@@ -470,7 +489,7 @@ public class OllamaAPI {
     private OllamaResult generateSyncForOllamaRequestModel(OllamaRequestModel ollamaRequestModel) throws OllamaBaseException, IOException, InterruptedException {
         OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
-        return requestCaller.generateSync(ollamaRequestModel);
+        return requestCaller.callSync(ollamaRequestModel);
     }
 
     /**
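Usage sketch (not part of this patch): with the new overload, a streamed chat can be driven by a one-line lambda. The example below assumes the library's existing `OllamaChatRequestBuilder`/`OllamaChatMessageRole` API and an `OllamaStreamHandler` that accepts the concatenated message string, as the javadoc caution above suggests; the host and model name are illustrative.

```java
import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessageRole;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResult;

public class StreamedChatExample {
    public static void main(String[] args) throws Exception {
        // Host and model name are illustrative values.
        OllamaAPI ollamaAPI = new OllamaAPI("http://localhost:11434");

        OllamaChatRequestModel request = OllamaChatRequestBuilder
                .getInstance("llama2")
                .withMessage(OllamaChatMessageRole.USER, "Why is the sky blue?")
                .build();

        // The handler fires once per streamed response part; each call receives
        // everything streamed so far, concatenated into a single string.
        OllamaChatResult result = ollamaAPI.chat(request, s -> System.out.println(s));

        System.out.println("Final response: " + result.getResponse());
    }
}
```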
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaChatEndpointCaller.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaChatEndpointCaller.java
index eb06c37..811ef11 100644
--- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaChatEndpointCaller.java
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaChatEndpointCaller.java
@@ -1,12 +1,19 @@
 package io.github.amithkoujalgi.ollama4j.core.models.request;
 
+import java.io.IOException;
+
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import com.fasterxml.jackson.core.JsonProcessingException;
 
+import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler;
+import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
 import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth;
+import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
 import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResponseModel;
+import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatStreamObserver;
+import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
 import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
 
 /**
@@ -16,6 +23,8 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller{
 
     private static final Logger LOG = LoggerFactory.getLogger(OllamaChatEndpointCaller.class);
 
+    private OllamaChatStreamObserver streamObserver;
+
     public OllamaChatEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) {
         super(host, basicAuth, requestTimeoutSeconds, verbose);
     }
@@ -27,18 +36,25 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller{
 
     @Override
     protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) {
-        try {
-            OllamaChatResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class);
-            responseBuffer.append(ollamaResponseModel.getMessage().getContent());
-            return ollamaResponseModel.isDone();
-        } catch (JsonProcessingException e) {
-            LOG.error("Error parsing the Ollama chat response!", e);
-            return true;
-        }
+        try {
+            OllamaChatResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class);
+            responseBuffer.append(ollamaResponseModel.getMessage().getContent());
+            if (streamObserver != null) {
+                streamObserver.notify(ollamaResponseModel);
+            }
+            return ollamaResponseModel.isDone();
+        } catch (JsonProcessingException e) {
+            LOG.error("Error parsing the Ollama chat response!", e);
+            return true;
+        }
     }
 
+    public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler)
+            throws OllamaBaseException, IOException, InterruptedException {
+        streamObserver = new OllamaChatStreamObserver(streamHandler);
+        return super.callSync(body);
+    }
 }
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaEndpointCaller.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaEndpointCaller.java
index 93b2b2f..ad8d5bb 100644
--- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaEndpointCaller.java
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaEndpointCaller.java
@@ -46,7 +46,7 @@ public abstract class OllamaEndpointCaller {
 
     protected abstract boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer);
 
-    
+
     /**
      * Calls the api server on the given host and endpoint suffix synchronously, i.e. waiting for the response.
      *
@@ -56,7 +56,7 @@ public abstract class OllamaEndpointCaller {
      * @throws IOException in case the responseStream cannot be read
      * @throws InterruptedException in case the server is not reachable or network issues happen
      */
-    public OllamaResult generateSync(OllamaRequestBody body) throws OllamaBaseException, IOException, InterruptedException{
+    public OllamaResult callSync(OllamaRequestBody body) throws OllamaBaseException, IOException, InterruptedException{
         // Create Request
         long startTime = System.currentTimeMillis();
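Side note: `OllamaChatStreamObserver` is referenced above but its source is not included in this patch. Below is a minimal sketch of what such an observer could look like, consistent with the "all previous messages will be concatenated" caution in the chat javadoc; the field names and the body of `notify` are assumptions for illustration, not the class's actual implementation.

```java
package io.github.amithkoujalgi.ollama4j.core.models.chat;

import java.util.ArrayList;
import java.util.List;

import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler;

// Sketch only: buffers every streamed response part and forwards the
// running, concatenated message to the user-supplied stream handler.
public class OllamaChatStreamObserver {

    private final OllamaStreamHandler streamHandler;
    private final List<OllamaChatResponseModel> responseParts = new ArrayList<>();
    private String message = "";

    public OllamaChatStreamObserver(OllamaStreamHandler streamHandler) {
        this.streamHandler = streamHandler;
    }

    // Called by OllamaChatEndpointCaller.parseResponseAndAddToBuffer()
    // for every line of the streamed response.
    public void notify(OllamaChatResponseModel currentResponsePart) {
        responseParts.add(currentResponsePart);
        // Append the new content and hand the whole message so far to the
        // callback -- which is why "all previous messages will be concatenated".
        message += currentResponsePart.getMessage().getContent();
        streamHandler.accept(message);
    }
}
```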