forked from Mirror/ollama4j
		
	Adds implicit tool calling for streamed chat requests (requires Ollama v0.4.6)
This commit is contained in:
		@@ -777,22 +777,27 @@ public class OllamaAPI {
 | 
			
		||||
            result = requestCaller.call(request, streamHandler);
 | 
			
		||||
        } else {
 | 
			
		||||
            result = requestCaller.callSync(request);
 | 
			
		||||
            // check if toolCallIsWanted
 | 
			
		||||
            List<OllamaChatToolCalls> toolCalls = result.getResponseModel().getMessage().getToolCalls();
 | 
			
		||||
            int toolCallTries = 0;
 | 
			
		||||
            while(toolCalls != null && !toolCalls.isEmpty() && toolCallTries <3){
 | 
			
		||||
                for (OllamaChatToolCalls toolCall : toolCalls){
 | 
			
		||||
                    String toolName = toolCall.getFunction().getName();
 | 
			
		||||
                    ToolFunction toolFunction = toolRegistry.getToolFunction(toolName);
 | 
			
		||||
                    Map<String, Object> arguments = toolCall.getFunction().getArguments();
 | 
			
		||||
                    Object res = toolFunction.apply(arguments);
 | 
			
		||||
                    request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL,"[ToolCall-Result]" + toolName + "(" + arguments.keySet() +") : " + res + "[/ToolCall-Result]"));
 | 
			
		||||
                }
 | 
			
		||||
                result = requestCaller.callSync(request);
 | 
			
		||||
                toolCalls = result.getResponseModel().getMessage().getToolCalls();
 | 
			
		||||
                toolCallTries++;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // check if toolCallIsWanted
 | 
			
		||||
        List<OllamaChatToolCalls> toolCalls = result.getResponseModel().getMessage().getToolCalls();
 | 
			
		||||
        int toolCallTries = 0;
 | 
			
		||||
        while(toolCalls != null && !toolCalls.isEmpty() && toolCallTries <3){
 | 
			
		||||
            for (OllamaChatToolCalls toolCall : toolCalls){
 | 
			
		||||
                String toolName = toolCall.getFunction().getName();
 | 
			
		||||
                ToolFunction toolFunction = toolRegistry.getToolFunction(toolName);
 | 
			
		||||
                Map<String, Object> arguments = toolCall.getFunction().getArguments();
 | 
			
		||||
                Object res = toolFunction.apply(arguments);
 | 
			
		||||
                request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL,"[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() +") : " + res + "[/TOOL_RESULTS]"));
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            if (streamHandler != null) {
 | 
			
		||||
                result = requestCaller.call(request, streamHandler);
 | 
			
		||||
            } else {
 | 
			
		||||
                result = requestCaller.callSync(request);
 | 
			
		||||
            }
 | 
			
		||||
            toolCalls = result.getResponseModel().getMessage().getToolCalls();
 | 
			
		||||
            toolCallTries++;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        return result;
 | 
			
		||||
 
 | 
			
		||||
@@ -3,13 +3,10 @@ package io.github.ollama4j.models.request;
 | 
			
		||||
import com.fasterxml.jackson.core.JsonProcessingException;
 | 
			
		||||
import com.fasterxml.jackson.core.type.TypeReference;
 | 
			
		||||
import io.github.ollama4j.exceptions.OllamaBaseException;
 | 
			
		||||
import io.github.ollama4j.models.chat.OllamaChatMessage;
 | 
			
		||||
import io.github.ollama4j.models.chat.OllamaChatRequest;
 | 
			
		||||
import io.github.ollama4j.models.chat.OllamaChatResult;
 | 
			
		||||
import io.github.ollama4j.models.chat.*;
 | 
			
		||||
import io.github.ollama4j.models.response.OllamaErrorResponse;
 | 
			
		||||
import io.github.ollama4j.models.chat.OllamaChatResponseModel;
 | 
			
		||||
import io.github.ollama4j.models.chat.OllamaChatStreamObserver;
 | 
			
		||||
import io.github.ollama4j.models.generate.OllamaStreamHandler;
 | 
			
		||||
import io.github.ollama4j.tools.Tools;
 | 
			
		||||
import io.github.ollama4j.utils.Utils;
 | 
			
		||||
import org.slf4j.Logger;
 | 
			
		||||
import org.slf4j.LoggerFactory;
 | 
			
		||||
@@ -23,6 +20,7 @@ import java.net.http.HttpClient;
 | 
			
		||||
import java.net.http.HttpRequest;
 | 
			
		||||
import java.net.http.HttpResponse;
 | 
			
		||||
import java.nio.charset.StandardCharsets;
 | 
			
		||||
import java.util.List;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Specialization class for requests
 | 
			
		||||
@@ -96,6 +94,7 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
 | 
			
		||||
        InputStream responseBodyStream = response.body();
 | 
			
		||||
        StringBuilder responseBuffer = new StringBuilder();
 | 
			
		||||
        OllamaChatResponseModel ollamaChatResponseModel = null;
 | 
			
		||||
        List<OllamaChatToolCalls> wantedToolsForStream = null;
 | 
			
		||||
        try (BufferedReader reader =
 | 
			
		||||
                     new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
 | 
			
		||||
 | 
			
		||||
@@ -120,6 +119,9 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
 | 
			
		||||
                } else {
 | 
			
		||||
                    boolean finished = parseResponseAndAddToBuffer(line, responseBuffer);
 | 
			
		||||
                        ollamaChatResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class);
 | 
			
		||||
                    if(body.stream && ollamaChatResponseModel.getMessage().getToolCalls() != null){
 | 
			
		||||
                        wantedToolsForStream = ollamaChatResponseModel.getMessage().getToolCalls();
 | 
			
		||||
                    }
 | 
			
		||||
                    if (finished && body.stream) {
 | 
			
		||||
                        ollamaChatResponseModel.getMessage().setContent(responseBuffer.toString());
 | 
			
		||||
                        break;
 | 
			
		||||
@@ -131,6 +133,9 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
 | 
			
		||||
            LOG.error("Status code " + statusCode);
 | 
			
		||||
            throw new OllamaBaseException(responseBuffer.toString());
 | 
			
		||||
        } else {
 | 
			
		||||
            if(wantedToolsForStream != null) {
 | 
			
		||||
                ollamaChatResponseModel.getMessage().setToolCalls(wantedToolsForStream);
 | 
			
		||||
            }
 | 
			
		||||
            OllamaChatResult ollamaResult =
 | 
			
		||||
                    new OllamaChatResult(ollamaChatResponseModel,body.getMessages());
 | 
			
		||||
            if (isVerbose()) LOG.info("Model response: " + ollamaResult);
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user