Merge pull request #88 from seeseemelk/feature/token-streamer

Add ability to stream tokens in chat
This commit is contained in:
Amith Koujalgi 2025-01-26 17:37:22 +05:30 committed by GitHub
commit dda807d818
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 51 additions and 36 deletions

View File

@ -11,6 +11,7 @@ import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel; import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel;
import io.github.ollama4j.models.generate.OllamaGenerateRequest; import io.github.ollama4j.models.generate.OllamaGenerateRequest;
import io.github.ollama4j.models.generate.OllamaStreamHandler; import io.github.ollama4j.models.generate.OllamaStreamHandler;
import io.github.ollama4j.models.generate.OllamaTokenHandler;
import io.github.ollama4j.models.ps.ModelsProcessResponse; import io.github.ollama4j.models.ps.ModelsProcessResponse;
import io.github.ollama4j.models.request.*; import io.github.ollama4j.models.request.*;
import io.github.ollama4j.models.response.*; import io.github.ollama4j.models.response.*;
@ -785,15 +786,34 @@ public class OllamaAPI {
* @throws InterruptedException if the operation is interrupted * @throws InterruptedException if the operation is interrupted
*/ */
public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
return chatStreaming(request, new OllamaChatStreamObserver(streamHandler));
}
/**
* Ask a question to a model using an {@link OllamaChatRequest}. This can be constructed using an {@link OllamaChatRequestBuilder}.
* <p>
* Hint: the OllamaChatRequestModel#getStream() property is not implemented.
*
* @param request request object to be sent to the server
* @param tokenHandler callback handler to handle the last token from stream (caution: all previous messages from stream will be concatenated)
* @return {@link OllamaChatResult}
* @throws OllamaBaseException any response code than 200 has been returned
* @throws IOException in case the responseStream can not be read
* @throws InterruptedException in case the server is not reachable or network issues happen
* @throws OllamaBaseException if the response indicates an error status
* @throws IOException if an I/O error occurs during the HTTP request
* @throws InterruptedException if the operation is interrupted
*/
public OllamaChatResult chatStreaming(OllamaChatRequest request, OllamaTokenHandler tokenHandler) throws OllamaBaseException, IOException, InterruptedException {
OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose); OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
OllamaChatResult result; OllamaChatResult result;
// add all registered tools to Request // add all registered tools to Request
request.setTools(toolRegistry.getRegisteredSpecs().stream().map(Tools.ToolSpecification::getToolPrompt).collect(Collectors.toList())); request.setTools(toolRegistry.getRegisteredSpecs().stream().map(Tools.ToolSpecification::getToolPrompt).collect(Collectors.toList()));
if (streamHandler != null) { if (tokenHandler != null) {
request.setStream(true); request.setStream(true);
result = requestCaller.call(request, streamHandler); result = requestCaller.call(request, tokenHandler);
} else { } else {
result = requestCaller.callSync(request); result = requestCaller.callSync(request);
} }
@ -810,8 +830,8 @@ public class OllamaAPI {
request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL,"[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() +") : " + res + "[/TOOL_RESULTS]")); request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL,"[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() +") : " + res + "[/TOOL_RESULTS]"));
} }
if (streamHandler != null) { if (tokenHandler != null) {
result = requestCaller.call(request, streamHandler); result = requestCaller.call(request, tokenHandler);
} else { } else {
result = requestCaller.callSync(request); result = requestCaller.callSync(request);
} }

View File

@ -1,31 +1,19 @@
package io.github.ollama4j.models.chat; package io.github.ollama4j.models.chat;
import io.github.ollama4j.models.generate.OllamaStreamHandler; import io.github.ollama4j.models.generate.OllamaStreamHandler;
import io.github.ollama4j.models.generate.OllamaTokenHandler;
import lombok.RequiredArgsConstructor;
import java.util.ArrayList; @RequiredArgsConstructor
import java.util.List; public class OllamaChatStreamObserver implements OllamaTokenHandler {
private final OllamaStreamHandler streamHandler;
public class OllamaChatStreamObserver {
private OllamaStreamHandler streamHandler;
private List<OllamaChatResponseModel> responseParts = new ArrayList<>();
private String message = ""; private String message = "";
public OllamaChatStreamObserver(OllamaStreamHandler streamHandler) { @Override
this.streamHandler = streamHandler; public void accept(OllamaChatResponseModel token) {
if (streamHandler != null) {
message += token.getMessage().getContent();
streamHandler.accept(message);
}
} }
public void notify(OllamaChatResponseModel currentResponsePart) {
responseParts.add(currentResponsePart);
handleCurrentResponsePart(currentResponsePart);
}
protected void handleCurrentResponsePart(OllamaChatResponseModel currentResponsePart) {
message = message + currentResponsePart.getMessage().getContent();
streamHandler.accept(message);
}
} }

View File

@ -0,0 +1,8 @@
package io.github.ollama4j.models.generate;
import io.github.ollama4j.models.chat.OllamaChatResponseModel;
import java.util.function.Consumer;
public interface OllamaTokenHandler extends Consumer<OllamaChatResponseModel> {
}

View File

@ -4,9 +4,8 @@ import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.core.type.TypeReference;
import io.github.ollama4j.exceptions.OllamaBaseException; import io.github.ollama4j.exceptions.OllamaBaseException;
import io.github.ollama4j.models.chat.*; import io.github.ollama4j.models.chat.*;
import io.github.ollama4j.models.generate.OllamaTokenHandler;
import io.github.ollama4j.models.response.OllamaErrorResponse; import io.github.ollama4j.models.response.OllamaErrorResponse;
import io.github.ollama4j.models.generate.OllamaStreamHandler;
import io.github.ollama4j.tools.Tools;
import io.github.ollama4j.utils.Utils; import io.github.ollama4j.utils.Utils;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@ -29,7 +28,7 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
private static final Logger LOG = LoggerFactory.getLogger(OllamaChatEndpointCaller.class); private static final Logger LOG = LoggerFactory.getLogger(OllamaChatEndpointCaller.class);
private OllamaChatStreamObserver streamObserver; private OllamaTokenHandler tokenHandler;
public OllamaChatEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) { public OllamaChatEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) {
super(host, basicAuth, requestTimeoutSeconds, verbose); super(host, basicAuth, requestTimeoutSeconds, verbose);
@ -60,8 +59,8 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
OllamaChatMessage message = ollamaResponseModel.getMessage(); OllamaChatMessage message = ollamaResponseModel.getMessage();
if(message != null) { if(message != null) {
responseBuffer.append(message.getContent()); responseBuffer.append(message.getContent());
if (streamObserver != null) { if (tokenHandler != null) {
streamObserver.notify(ollamaResponseModel); tokenHandler.accept(ollamaResponseModel);
} }
} }
return ollamaResponseModel.isDone(); return ollamaResponseModel.isDone();
@ -71,9 +70,9 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
} }
} }
public OllamaChatResult call(OllamaChatRequest body, OllamaStreamHandler streamHandler) public OllamaChatResult call(OllamaChatRequest body, OllamaTokenHandler tokenHandler)
throws OllamaBaseException, IOException, InterruptedException { throws OllamaBaseException, IOException, InterruptedException {
streamObserver = new OllamaChatStreamObserver(streamHandler); this.tokenHandler = tokenHandler;
return callSync(body); return callSync(body);
} }
@ -86,7 +85,7 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
.POST( .POST(
body.getBodyPublisher()); body.getBodyPublisher());
HttpRequest request = requestBuilder.build(); HttpRequest request = requestBuilder.build();
if (isVerbose()) LOG.info("Asking model: " + body.toString()); if (isVerbose()) LOG.info("Asking model: " + body);
HttpResponse<InputStream> response = HttpResponse<InputStream> response =
httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream()); httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());

View File

@ -321,7 +321,7 @@ class TestRealAPIs {
assertEquals(1, function.getArguments().size()); assertEquals(1, function.getArguments().size());
Object noOfDigits = function.getArguments().get("noOfDigits"); Object noOfDigits = function.getArguments().get("noOfDigits");
assertNotNull(noOfDigits); assertNotNull(noOfDigits);
assertEquals("5",noOfDigits); assertEquals("5", noOfDigits.toString());
assertTrue(chatResult.getChatHistory().size()>2); assertTrue(chatResult.getChatHistory().size()>2);
List<OllamaChatToolCalls> finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls(); List<OllamaChatToolCalls> finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls();
assertNull(finalToolCalls); assertNull(finalToolCalls);