forked from Mirror/ollama4j
		
	Merge pull request #88 from seeseemelk/feature/token-streamer
Add ability to stream tokens in chat
This commit is contained in:
		@@ -11,6 +11,7 @@ import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
 | 
				
			|||||||
import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel;
 | 
					import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel;
 | 
				
			||||||
import io.github.ollama4j.models.generate.OllamaGenerateRequest;
 | 
					import io.github.ollama4j.models.generate.OllamaGenerateRequest;
 | 
				
			||||||
import io.github.ollama4j.models.generate.OllamaStreamHandler;
 | 
					import io.github.ollama4j.models.generate.OllamaStreamHandler;
 | 
				
			||||||
 | 
					import io.github.ollama4j.models.generate.OllamaTokenHandler;
 | 
				
			||||||
import io.github.ollama4j.models.ps.ModelsProcessResponse;
 | 
					import io.github.ollama4j.models.ps.ModelsProcessResponse;
 | 
				
			||||||
import io.github.ollama4j.models.request.*;
 | 
					import io.github.ollama4j.models.request.*;
 | 
				
			||||||
import io.github.ollama4j.models.response.*;
 | 
					import io.github.ollama4j.models.response.*;
 | 
				
			||||||
@@ -785,15 +786,34 @@ public class OllamaAPI {
 | 
				
			|||||||
     * @throws InterruptedException if the operation is interrupted
 | 
					     * @throws InterruptedException if the operation is interrupted
 | 
				
			||||||
     */
 | 
					     */
 | 
				
			||||||
    public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
 | 
					    public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
 | 
				
			||||||
 | 
					        return chatStreaming(request, new OllamaChatStreamObserver(streamHandler));
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    /**
 | 
				
			||||||
 | 
					     * Ask a question to a model using an {@link OllamaChatRequest}. This can be constructed using an {@link OllamaChatRequestBuilder}.
 | 
				
			||||||
 | 
					     * <p>
 | 
				
			||||||
 | 
					     * Hint: the OllamaChatRequestModel#getStream() property is not implemented.
 | 
				
			||||||
 | 
					     *
 | 
				
			||||||
 | 
					     * @param request       request object to be sent to the server
 | 
				
			||||||
 | 
					     * @param tokenHandler  callback handler invoked with each response part (token) as it is received from the stream; each part is passed as-is, without concatenating the previously received parts
 | 
				
			||||||
 | 
					     * @return {@link OllamaChatResult}
 | 
				
			||||||
 | 
					     * @throws OllamaBaseException  if any response code other than 200 has been returned
 | 
				
			||||||
 | 
					     * @throws IOException          in case the responseStream can not be read
 | 
				
			||||||
 | 
					     * @throws InterruptedException if the calling thread is interrupted while waiting for the server response
 | 
				
			||||||
 | 
					     * @throws OllamaBaseException  if the response indicates an error status
 | 
				
			||||||
 | 
					     * @throws IOException          if an I/O error occurs during the HTTP request
 | 
				
			||||||
 | 
					     * @throws InterruptedException if the operation is interrupted
 | 
				
			||||||
 | 
					     */
 | 
				
			||||||
 | 
					    public OllamaChatResult chatStreaming(OllamaChatRequest request, OllamaTokenHandler tokenHandler) throws OllamaBaseException, IOException, InterruptedException {
 | 
				
			||||||
        OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
 | 
					        OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
 | 
				
			||||||
        OllamaChatResult result;
 | 
					        OllamaChatResult result;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // add all registered tools to Request
 | 
					        // add all registered tools to Request
 | 
				
			||||||
        request.setTools(toolRegistry.getRegisteredSpecs().stream().map(Tools.ToolSpecification::getToolPrompt).collect(Collectors.toList()));
 | 
					        request.setTools(toolRegistry.getRegisteredSpecs().stream().map(Tools.ToolSpecification::getToolPrompt).collect(Collectors.toList()));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if (streamHandler != null) {
 | 
					        if (tokenHandler != null) {
 | 
				
			||||||
            request.setStream(true);
 | 
					            request.setStream(true);
 | 
				
			||||||
            result = requestCaller.call(request, streamHandler);
 | 
					            result = requestCaller.call(request, tokenHandler);
 | 
				
			||||||
        } else {
 | 
					        } else {
 | 
				
			||||||
            result = requestCaller.callSync(request);
 | 
					            result = requestCaller.callSync(request);
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
@@ -810,8 +830,8 @@ public class OllamaAPI {
 | 
				
			|||||||
                request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL,"[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() +") : " + res + "[/TOOL_RESULTS]"));
 | 
					                request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL,"[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() +") : " + res + "[/TOOL_RESULTS]"));
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            if (streamHandler != null) {
 | 
					            if (tokenHandler != null) {
 | 
				
			||||||
                result = requestCaller.call(request, streamHandler);
 | 
					                result = requestCaller.call(request, tokenHandler);
 | 
				
			||||||
            } else {
 | 
					            } else {
 | 
				
			||||||
                result = requestCaller.callSync(request);
 | 
					                result = requestCaller.callSync(request);
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,31 +1,19 @@
 | 
				
			|||||||
package io.github.ollama4j.models.chat;
 | 
					package io.github.ollama4j.models.chat;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import io.github.ollama4j.models.generate.OllamaStreamHandler;
 | 
					import io.github.ollama4j.models.generate.OllamaStreamHandler;
 | 
				
			||||||
 | 
					import io.github.ollama4j.models.generate.OllamaTokenHandler;
 | 
				
			||||||
 | 
					import lombok.RequiredArgsConstructor;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import java.util.ArrayList;
 | 
					@RequiredArgsConstructor
 | 
				
			||||||
import java.util.List;
 | 
					public class OllamaChatStreamObserver implements OllamaTokenHandler {
 | 
				
			||||||
 | 
					    private final OllamaStreamHandler streamHandler;
 | 
				
			||||||
public class OllamaChatStreamObserver {
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    private OllamaStreamHandler streamHandler;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    private List<OllamaChatResponseModel> responseParts = new ArrayList<>();
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    private String message = "";
 | 
					    private String message = "";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public OllamaChatStreamObserver(OllamaStreamHandler streamHandler) {
 | 
					    @Override
 | 
				
			||||||
        this.streamHandler = streamHandler;
 | 
					    public void accept(OllamaChatResponseModel token) {
 | 
				
			||||||
 | 
					        if (streamHandler != null) {
 | 
				
			||||||
 | 
					            message += token.getMessage().getContent();
 | 
				
			||||||
 | 
					            streamHandler.accept(message);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					 | 
				
			||||||
    public void notify(OllamaChatResponseModel currentResponsePart) {
 | 
					 | 
				
			||||||
        responseParts.add(currentResponsePart);
 | 
					 | 
				
			||||||
        handleCurrentResponsePart(currentResponsePart);
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    protected void handleCurrentResponsePart(OllamaChatResponseModel currentResponsePart) {
 | 
					 | 
				
			||||||
        message = message + currentResponsePart.getMessage().getContent();
 | 
					 | 
				
			||||||
        streamHandler.accept(message);
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -0,0 +1,8 @@
 | 
				
			|||||||
 | 
					package io.github.ollama4j.models.generate;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import io.github.ollama4j.models.chat.OllamaChatResponseModel;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import java.util.function.Consumer;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					public interface OllamaTokenHandler extends Consumer<OllamaChatResponseModel> {
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -4,9 +4,8 @@ import com.fasterxml.jackson.core.JsonProcessingException;
 | 
				
			|||||||
import com.fasterxml.jackson.core.type.TypeReference;
 | 
					import com.fasterxml.jackson.core.type.TypeReference;
 | 
				
			||||||
import io.github.ollama4j.exceptions.OllamaBaseException;
 | 
					import io.github.ollama4j.exceptions.OllamaBaseException;
 | 
				
			||||||
import io.github.ollama4j.models.chat.*;
 | 
					import io.github.ollama4j.models.chat.*;
 | 
				
			||||||
 | 
					import io.github.ollama4j.models.generate.OllamaTokenHandler;
 | 
				
			||||||
import io.github.ollama4j.models.response.OllamaErrorResponse;
 | 
					import io.github.ollama4j.models.response.OllamaErrorResponse;
 | 
				
			||||||
import io.github.ollama4j.models.generate.OllamaStreamHandler;
 | 
					 | 
				
			||||||
import io.github.ollama4j.tools.Tools;
 | 
					 | 
				
			||||||
import io.github.ollama4j.utils.Utils;
 | 
					import io.github.ollama4j.utils.Utils;
 | 
				
			||||||
import org.slf4j.Logger;
 | 
					import org.slf4j.Logger;
 | 
				
			||||||
import org.slf4j.LoggerFactory;
 | 
					import org.slf4j.LoggerFactory;
 | 
				
			||||||
@@ -29,7 +28,7 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    private static final Logger LOG = LoggerFactory.getLogger(OllamaChatEndpointCaller.class);
 | 
					    private static final Logger LOG = LoggerFactory.getLogger(OllamaChatEndpointCaller.class);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    private OllamaChatStreamObserver streamObserver;
 | 
					    private OllamaTokenHandler tokenHandler;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public OllamaChatEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) {
 | 
					    public OllamaChatEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) {
 | 
				
			||||||
        super(host, basicAuth, requestTimeoutSeconds, verbose);
 | 
					        super(host, basicAuth, requestTimeoutSeconds, verbose);
 | 
				
			||||||
@@ -60,8 +59,8 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
 | 
				
			|||||||
            OllamaChatMessage message = ollamaResponseModel.getMessage();
 | 
					            OllamaChatMessage message = ollamaResponseModel.getMessage();
 | 
				
			||||||
            if(message != null) {
 | 
					            if(message != null) {
 | 
				
			||||||
                responseBuffer.append(message.getContent());
 | 
					                responseBuffer.append(message.getContent());
 | 
				
			||||||
                if (streamObserver != null) {
 | 
					                if (tokenHandler != null) {
 | 
				
			||||||
                    streamObserver.notify(ollamaResponseModel);
 | 
					                    tokenHandler.accept(ollamaResponseModel);
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
            return ollamaResponseModel.isDone();
 | 
					            return ollamaResponseModel.isDone();
 | 
				
			||||||
@@ -71,9 +70,9 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
 | 
				
			|||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public OllamaChatResult call(OllamaChatRequest body, OllamaStreamHandler streamHandler)
 | 
					    public OllamaChatResult call(OllamaChatRequest body, OllamaTokenHandler tokenHandler)
 | 
				
			||||||
            throws OllamaBaseException, IOException, InterruptedException {
 | 
					            throws OllamaBaseException, IOException, InterruptedException {
 | 
				
			||||||
        streamObserver = new OllamaChatStreamObserver(streamHandler);
 | 
					        this.tokenHandler = tokenHandler;
 | 
				
			||||||
        return callSync(body);
 | 
					        return callSync(body);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -86,7 +85,7 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
 | 
				
			|||||||
                        .POST(
 | 
					                        .POST(
 | 
				
			||||||
                                body.getBodyPublisher());
 | 
					                                body.getBodyPublisher());
 | 
				
			||||||
        HttpRequest request = requestBuilder.build();
 | 
					        HttpRequest request = requestBuilder.build();
 | 
				
			||||||
        if (isVerbose()) LOG.info("Asking model: " + body.toString());
 | 
					        if (isVerbose()) LOG.info("Asking model: " + body);
 | 
				
			||||||
        HttpResponse<InputStream> response =
 | 
					        HttpResponse<InputStream> response =
 | 
				
			||||||
                httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
 | 
					                httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -321,7 +321,7 @@ class TestRealAPIs {
 | 
				
			|||||||
            assertEquals(1, function.getArguments().size());
 | 
					            assertEquals(1, function.getArguments().size());
 | 
				
			||||||
            Object noOfDigits = function.getArguments().get("noOfDigits");
 | 
					            Object noOfDigits = function.getArguments().get("noOfDigits");
 | 
				
			||||||
            assertNotNull(noOfDigits);
 | 
					            assertNotNull(noOfDigits);
 | 
				
			||||||
            assertEquals("5",noOfDigits);
 | 
					            assertEquals("5", noOfDigits.toString());
 | 
				
			||||||
            assertTrue(chatResult.getChatHistory().size()>2);
 | 
					            assertTrue(chatResult.getChatHistory().size()>2);
 | 
				
			||||||
            List<OllamaChatToolCalls> finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls();
 | 
					            List<OllamaChatToolCalls> finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls();
 | 
				
			||||||
            assertNull(finalToolCalls);
 | 
					            assertNull(finalToolCalls);
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user