Merge pull request #88 from seeseemelk/feature/token-streamer

Add ability to stream tokens in chat
--- a/src/main/java/io/github/ollama4j/OllamaAPI.java
+++ b/src/main/java/io/github/ollama4j/OllamaAPI.java
@@ -11,6 +11,7 @@ import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
 import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel;
 import io.github.ollama4j.models.generate.OllamaGenerateRequest;
 import io.github.ollama4j.models.generate.OllamaStreamHandler;
+import io.github.ollama4j.models.generate.OllamaTokenHandler;
 import io.github.ollama4j.models.ps.ModelsProcessResponse;
 import io.github.ollama4j.models.request.*;
 import io.github.ollama4j.models.response.*;
@@ -785,15 +786,34 @@ public class OllamaAPI {
      * @throws InterruptedException if the operation is interrupted
      */
     public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
+        return chatStreaming(request, new OllamaChatStreamObserver(streamHandler));
+    }
+
+    /**
+     * Ask a question to a model using an {@link OllamaChatRequest}. This can be constructed using an {@link OllamaChatRequestBuilder}.
+     * <p>
+     * Hint: the OllamaChatRequestModel#getStream() property is not implemented.
+     *
+     * @param request       request object to be sent to the server
+     * @param tokenHandler  callback handler to handle the last token from stream (caution: all previous messages from stream will be concatenated)
+     * @return {@link OllamaChatResult}
+     * @throws OllamaBaseException  any response code than 200 has been returned
+     * @throws IOException          in case the responseStream can not be read
+     * @throws InterruptedException in case the server is not reachable or network issues happen
+     * @throws OllamaBaseException  if the response indicates an error status
+     * @throws IOException          if an I/O error occurs during the HTTP request
+     * @throws InterruptedException if the operation is interrupted
+     */
+    public OllamaChatResult chatStreaming(OllamaChatRequest request, OllamaTokenHandler tokenHandler) throws OllamaBaseException, IOException, InterruptedException {
         OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
         OllamaChatResult result;
 
         // add all registered tools to Request
         request.setTools(toolRegistry.getRegisteredSpecs().stream().map(Tools.ToolSpecification::getToolPrompt).collect(Collectors.toList()));
 
-        if (streamHandler != null) {
+        if (tokenHandler != null) {
             request.setStream(true);
-            result = requestCaller.call(request, streamHandler);
+            result = requestCaller.call(request, tokenHandler);
         } else {
             result = requestCaller.callSync(request);
         }
@@ -810,8 +830,8 @@ public class OllamaAPI {
                 request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL,"[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() +") : " + res + "[/TOOL_RESULTS]"));
             }
 
-            if (streamHandler != null) {
-                result = requestCaller.call(request, streamHandler);
+            if (tokenHandler != null) {
+                result = requestCaller.call(request, tokenHandler);
             } else {
                 result = requestCaller.callSync(request);
             }
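A minimal usage sketch of the new API (not part of the diff): the host URL, model name, and prompt below are placeholders, and the builder calls follow the patterns in the project's README, so treat exact signatures as assumptions. Because OllamaTokenHandler extends Consumer<OllamaChatResponseModel>, a lambda can be passed directly.

    import io.github.ollama4j.OllamaAPI;
    import io.github.ollama4j.models.chat.OllamaChatMessageRole;
    import io.github.ollama4j.models.chat.OllamaChatRequest;
    import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;

    public class ChatStreamingExample {
        public static void main(String[] args) throws Exception {
            // Placeholder host; point this at a running Ollama server.
            OllamaAPI api = new OllamaAPI("http://localhost:11434");

            // Placeholder model name and prompt.
            OllamaChatRequest request = OllamaChatRequestBuilder.getInstance("llama3")
                    .withMessage(OllamaChatMessageRole.USER, "Why is the sky blue?")
                    .build();

            // Each callback receives one raw OllamaChatResponseModel token;
            // print just the newly generated content of each token.
            api.chatStreaming(request, token -> System.out.print(token.getMessage().getContent()));
        }
    }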
--- a/src/main/java/io/github/ollama4j/models/chat/OllamaChatStreamObserver.java
+++ b/src/main/java/io/github/ollama4j/models/chat/OllamaChatStreamObserver.java
@@ -1,31 +1,19 @@
 package io.github.ollama4j.models.chat;
 
 import io.github.ollama4j.models.generate.OllamaStreamHandler;
+import io.github.ollama4j.models.generate.OllamaTokenHandler;
+import lombok.RequiredArgsConstructor;
 
-import java.util.ArrayList;
-import java.util.List;
-
-public class OllamaChatStreamObserver {
-
-    private OllamaStreamHandler streamHandler;
-
-    private List<OllamaChatResponseModel> responseParts = new ArrayList<>();
-
+@RequiredArgsConstructor
+public class OllamaChatStreamObserver implements OllamaTokenHandler {
+    private final OllamaStreamHandler streamHandler;
     private String message = "";
 
-    public OllamaChatStreamObserver(OllamaStreamHandler streamHandler) {
-        this.streamHandler = streamHandler;
+    @Override
+    public void accept(OllamaChatResponseModel token) {
+        if (streamHandler != null) {
+            message += token.getMessage().getContent();
+            streamHandler.accept(message);
+        }
     }
-
-    public void notify(OllamaChatResponseModel currentResponsePart) {
-        responseParts.add(currentResponsePart);
-        handleCurrentResponsePart(currentResponsePart);
-    }
-
-    protected void handleCurrentResponsePart(OllamaChatResponseModel currentResponsePart) {
-        message = message + currentResponsePart.getMessage().getContent();
-        streamHandler.accept(message);
-    }
-
-
 }
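The rewrite turns the observer into an adapter: it implements the new OllamaTokenHandler interface but still drives an old-style OllamaStreamHandler, concatenating token contents so that handler keeps receiving the full message so far. A sketch of the two styles (handler bodies are illustrative, and this assumes OllamaStreamHandler remains a functional interface over the accumulated String, as its use in the diff implies):

    // Old style: receives the concatenated message so far.
    OllamaStreamHandler whole = messageSoFar -> System.out.println(messageSoFar);

    // New style: receives each raw response token.
    OllamaTokenHandler delta = token -> System.out.print(token.getMessage().getContent());

    // The observer adapts old to new; the constructor taking an
    // OllamaStreamHandler is generated by @RequiredArgsConstructor.
    OllamaTokenHandler adapted = new OllamaChatStreamObserver(whole);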
--- /dev/null
+++ b/src/main/java/io/github/ollama4j/models/generate/OllamaTokenHandler.java
@@ -0,0 +1,8 @@
+package io.github.ollama4j.models.generate;
+
+import io.github.ollama4j.models.chat.OllamaChatResponseModel;
+
+import java.util.function.Consumer;
+
+public interface OllamaTokenHandler extends Consumer<OllamaChatResponseModel> {
+}
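Since the new interface is a bare Consumer specialization, a custom implementation can reproduce the responseParts bookkeeping that the old observer dropped. A hypothetical example (CollectingTokenHandler is not part of the library):

    import io.github.ollama4j.models.chat.OllamaChatResponseModel;
    import io.github.ollama4j.models.generate.OllamaTokenHandler;

    import java.util.ArrayList;
    import java.util.List;

    // Hypothetical handler that keeps every raw token, mirroring the
    // responseParts list removed from OllamaChatStreamObserver above.
    class CollectingTokenHandler implements OllamaTokenHandler {
        private final List<OllamaChatResponseModel> tokens = new ArrayList<>();

        @Override
        public void accept(OllamaChatResponseModel token) {
            tokens.add(token);
        }

        public List<OllamaChatResponseModel> getTokens() {
            return tokens;
        }
    }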
--- a/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java
+++ b/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java
@@ -4,9 +4,8 @@ import com.fasterxml.jackson.core.JsonProcessingException;
 import com.fasterxml.jackson.core.type.TypeReference;
 import io.github.ollama4j.exceptions.OllamaBaseException;
 import io.github.ollama4j.models.chat.*;
+import io.github.ollama4j.models.generate.OllamaTokenHandler;
 import io.github.ollama4j.models.response.OllamaErrorResponse;
-import io.github.ollama4j.models.generate.OllamaStreamHandler;
 import io.github.ollama4j.tools.Tools;
 import io.github.ollama4j.utils.Utils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -29,7 +28,7 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
 
     private static final Logger LOG = LoggerFactory.getLogger(OllamaChatEndpointCaller.class);
 
-    private OllamaChatStreamObserver streamObserver;
+    private OllamaTokenHandler tokenHandler;
 
     public OllamaChatEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) {
         super(host, basicAuth, requestTimeoutSeconds, verbose);
@@ -60,8 +59,8 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
             OllamaChatMessage message = ollamaResponseModel.getMessage();
             if(message != null) {
                 responseBuffer.append(message.getContent());
-                if (streamObserver != null) {
-                    streamObserver.notify(ollamaResponseModel);
+                if (tokenHandler != null) {
+                    tokenHandler.accept(ollamaResponseModel);
                 }
             }
             return ollamaResponseModel.isDone();
@@ -71,9 +70,9 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
         }
     }
 
-    public OllamaChatResult call(OllamaChatRequest body, OllamaStreamHandler streamHandler)
+    public OllamaChatResult call(OllamaChatRequest body, OllamaTokenHandler tokenHandler)
             throws OllamaBaseException, IOException, InterruptedException {
-        streamObserver = new OllamaChatStreamObserver(streamHandler);
+        this.tokenHandler = tokenHandler;
         return callSync(body);
     }
 
@@ -86,7 +85,7 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
                         .POST(
                                 body.getBodyPublisher());
         HttpRequest request = requestBuilder.build();
-        if (isVerbose()) LOG.info("Asking model: " + body.toString());
+        if (isVerbose()) LOG.info("Asking model: " + body);
         HttpResponse<InputStream> response =
                 httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
 