forked from Mirror/ollama4j
		
	Adds implicit tool calling for streamed chat requests (requires Ollama v0.4.6)
This commit is contained in:
		@@ -777,6 +777,8 @@ public class OllamaAPI {
 | 
				
			|||||||
            result = requestCaller.call(request, streamHandler);
 | 
					            result = requestCaller.call(request, streamHandler);
 | 
				
			||||||
        } else {
 | 
					        } else {
 | 
				
			||||||
            result = requestCaller.callSync(request);
 | 
					            result = requestCaller.callSync(request);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // check if toolCallIsWanted
 | 
					        // check if toolCallIsWanted
 | 
				
			||||||
        List<OllamaChatToolCalls> toolCalls = result.getResponseModel().getMessage().getToolCalls();
 | 
					        List<OllamaChatToolCalls> toolCalls = result.getResponseModel().getMessage().getToolCalls();
 | 
				
			||||||
        int toolCallTries = 0;
 | 
					        int toolCallTries = 0;
 | 
				
			||||||
@@ -786,13 +788,16 @@ public class OllamaAPI {
 | 
				
			|||||||
                ToolFunction toolFunction = toolRegistry.getToolFunction(toolName);
 | 
					                ToolFunction toolFunction = toolRegistry.getToolFunction(toolName);
 | 
				
			||||||
                Map<String, Object> arguments = toolCall.getFunction().getArguments();
 | 
					                Map<String, Object> arguments = toolCall.getFunction().getArguments();
 | 
				
			||||||
                Object res = toolFunction.apply(arguments);
 | 
					                Object res = toolFunction.apply(arguments);
 | 
				
			||||||
                    request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL,"[ToolCall-Result]" + toolName + "(" + arguments.keySet() +") : " + res + "[/ToolCall-Result]"));
 | 
					                request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL,"[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() +") : " + res + "[/TOOL_RESULTS]"));
 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
                result = requestCaller.callSync(request);
 | 
					 | 
				
			||||||
                toolCalls = result.getResponseModel().getMessage().getToolCalls();
 | 
					 | 
				
			||||||
                toolCallTries++;
 | 
					 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if (streamHandler != null) {
 | 
				
			||||||
 | 
					                result = requestCaller.call(request, streamHandler);
 | 
				
			||||||
 | 
					            } else {
 | 
				
			||||||
 | 
					                result = requestCaller.callSync(request);
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					            toolCalls = result.getResponseModel().getMessage().getToolCalls();
 | 
				
			||||||
 | 
					            toolCallTries++;
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return result;
 | 
					        return result;
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -3,13 +3,10 @@ package io.github.ollama4j.models.request;
 | 
				
			|||||||
import com.fasterxml.jackson.core.JsonProcessingException;
 | 
					import com.fasterxml.jackson.core.JsonProcessingException;
 | 
				
			||||||
import com.fasterxml.jackson.core.type.TypeReference;
 | 
					import com.fasterxml.jackson.core.type.TypeReference;
 | 
				
			||||||
import io.github.ollama4j.exceptions.OllamaBaseException;
 | 
					import io.github.ollama4j.exceptions.OllamaBaseException;
 | 
				
			||||||
import io.github.ollama4j.models.chat.OllamaChatMessage;
 | 
					import io.github.ollama4j.models.chat.*;
 | 
				
			||||||
import io.github.ollama4j.models.chat.OllamaChatRequest;
 | 
					 | 
				
			||||||
import io.github.ollama4j.models.chat.OllamaChatResult;
 | 
					 | 
				
			||||||
import io.github.ollama4j.models.response.OllamaErrorResponse;
 | 
					import io.github.ollama4j.models.response.OllamaErrorResponse;
 | 
				
			||||||
import io.github.ollama4j.models.chat.OllamaChatResponseModel;
 | 
					 | 
				
			||||||
import io.github.ollama4j.models.chat.OllamaChatStreamObserver;
 | 
					 | 
				
			||||||
import io.github.ollama4j.models.generate.OllamaStreamHandler;
 | 
					import io.github.ollama4j.models.generate.OllamaStreamHandler;
 | 
				
			||||||
 | 
					import io.github.ollama4j.tools.Tools;
 | 
				
			||||||
import io.github.ollama4j.utils.Utils;
 | 
					import io.github.ollama4j.utils.Utils;
 | 
				
			||||||
import org.slf4j.Logger;
 | 
					import org.slf4j.Logger;
 | 
				
			||||||
import org.slf4j.LoggerFactory;
 | 
					import org.slf4j.LoggerFactory;
 | 
				
			||||||
@@ -23,6 +20,7 @@ import java.net.http.HttpClient;
 | 
				
			|||||||
import java.net.http.HttpRequest;
 | 
					import java.net.http.HttpRequest;
 | 
				
			||||||
import java.net.http.HttpResponse;
 | 
					import java.net.http.HttpResponse;
 | 
				
			||||||
import java.nio.charset.StandardCharsets;
 | 
					import java.nio.charset.StandardCharsets;
 | 
				
			||||||
 | 
					import java.util.List;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/**
 | 
					/**
 | 
				
			||||||
 * Specialization class for requests
 | 
					 * Specialization class for requests
 | 
				
			||||||
@@ -96,6 +94,7 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
 | 
				
			|||||||
        InputStream responseBodyStream = response.body();
 | 
					        InputStream responseBodyStream = response.body();
 | 
				
			||||||
        StringBuilder responseBuffer = new StringBuilder();
 | 
					        StringBuilder responseBuffer = new StringBuilder();
 | 
				
			||||||
        OllamaChatResponseModel ollamaChatResponseModel = null;
 | 
					        OllamaChatResponseModel ollamaChatResponseModel = null;
 | 
				
			||||||
 | 
					        List<OllamaChatToolCalls> wantedToolsForStream = null;
 | 
				
			||||||
        try (BufferedReader reader =
 | 
					        try (BufferedReader reader =
 | 
				
			||||||
                     new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
 | 
					                     new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -120,6 +119,9 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
 | 
				
			|||||||
                } else {
 | 
					                } else {
 | 
				
			||||||
                    boolean finished = parseResponseAndAddToBuffer(line, responseBuffer);
 | 
					                    boolean finished = parseResponseAndAddToBuffer(line, responseBuffer);
 | 
				
			||||||
                        ollamaChatResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class);
 | 
					                        ollamaChatResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class);
 | 
				
			||||||
 | 
					                    if(body.stream && ollamaChatResponseModel.getMessage().getToolCalls() != null){
 | 
				
			||||||
 | 
					                        wantedToolsForStream = ollamaChatResponseModel.getMessage().getToolCalls();
 | 
				
			||||||
 | 
					                    }
 | 
				
			||||||
                    if (finished && body.stream) {
 | 
					                    if (finished && body.stream) {
 | 
				
			||||||
                        ollamaChatResponseModel.getMessage().setContent(responseBuffer.toString());
 | 
					                        ollamaChatResponseModel.getMessage().setContent(responseBuffer.toString());
 | 
				
			||||||
                        break;
 | 
					                        break;
 | 
				
			||||||
@@ -131,6 +133,9 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
 | 
				
			|||||||
            LOG.error("Status code " + statusCode);
 | 
					            LOG.error("Status code " + statusCode);
 | 
				
			||||||
            throw new OllamaBaseException(responseBuffer.toString());
 | 
					            throw new OllamaBaseException(responseBuffer.toString());
 | 
				
			||||||
        } else {
 | 
					        } else {
 | 
				
			||||||
 | 
					            if(wantedToolsForStream != null) {
 | 
				
			||||||
 | 
					                ollamaChatResponseModel.getMessage().setToolCalls(wantedToolsForStream);
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
            OllamaChatResult ollamaResult =
 | 
					            OllamaChatResult ollamaResult =
 | 
				
			||||||
                    new OllamaChatResult(ollamaChatResponseModel,body.getMessages());
 | 
					                    new OllamaChatResult(ollamaChatResponseModel,body.getMessages());
 | 
				
			||||||
            if (isVerbose()) LOG.info("Model response: " + ollamaResult);
 | 
					            if (isVerbose()) LOG.info("Model response: " + ollamaResult);
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -265,7 +265,7 @@ class TestRealAPIs {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
            OllamaChatRequest requestModel = builder
 | 
					            OllamaChatRequest requestModel = builder
 | 
				
			||||||
                    .withMessage(OllamaChatMessageRole.USER,
 | 
					                    .withMessage(OllamaChatMessageRole.USER,
 | 
				
			||||||
                            "Give me the details of the employee named 'Rahul Kumar'?")
 | 
					                            "Give me the ID of the employee named 'Rahul Kumar'?")
 | 
				
			||||||
                    .build();
 | 
					                    .build();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
 | 
					            OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
 | 
				
			||||||
@@ -288,6 +288,63 @@ class TestRealAPIs {
 | 
				
			|||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @Test
 | 
				
			||||||
 | 
					    @Order(3)
 | 
				
			||||||
 | 
					    void testChatWithToolsAndStream() {
 | 
				
			||||||
 | 
					        testEndpointReachability();
 | 
				
			||||||
 | 
					        try {
 | 
				
			||||||
 | 
					            OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
 | 
				
			||||||
 | 
					            final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder()
 | 
				
			||||||
 | 
					                    .functionName("get-employee-details")
 | 
				
			||||||
 | 
					                    .functionDescription("Get employee details from the database")
 | 
				
			||||||
 | 
					                    .toolPrompt(
 | 
				
			||||||
 | 
					                            Tools.PromptFuncDefinition.builder().type("function").function(
 | 
				
			||||||
 | 
					                                    Tools.PromptFuncDefinition.PromptFuncSpec.builder()
 | 
				
			||||||
 | 
					                                            .name("get-employee-details")
 | 
				
			||||||
 | 
					                                            .description("Get employee details from the database")
 | 
				
			||||||
 | 
					                                            .parameters(
 | 
				
			||||||
 | 
					                                                    Tools.PromptFuncDefinition.Parameters.builder()
 | 
				
			||||||
 | 
					                                                            .type("object")
 | 
				
			||||||
 | 
					                                                            .properties(
 | 
				
			||||||
 | 
					                                                                    new Tools.PropsBuilder()
 | 
				
			||||||
 | 
					                                                                            .withProperty("employee-name", Tools.PromptFuncDefinition.Property.builder().type("string").description("The name of the employee, e.g. John Doe").required(true).build())
 | 
				
			||||||
 | 
					                                                                            .withProperty("employee-address", Tools.PromptFuncDefinition.Property.builder().type("string").description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India").required(true).build())
 | 
				
			||||||
 | 
					                                                                            .withProperty("employee-phone", Tools.PromptFuncDefinition.Property.builder().type("string").description("The phone number of the employee. Always return a random value. e.g. 9911002233").required(true).build())
 | 
				
			||||||
 | 
					                                                                            .build()
 | 
				
			||||||
 | 
					                                                            )
 | 
				
			||||||
 | 
					                                                            .required(List.of("employee-name"))
 | 
				
			||||||
 | 
					                                                            .build()
 | 
				
			||||||
 | 
					                                            ).build()
 | 
				
			||||||
 | 
					                            ).build()
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
 | 
					                    .toolFunction(new DBQueryFunction())
 | 
				
			||||||
 | 
					                    .build();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            ollamaAPI.registerTool(databaseQueryToolSpecification);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            OllamaChatRequest requestModel = builder
 | 
				
			||||||
 | 
					                    .withMessage(OllamaChatMessageRole.USER,
 | 
				
			||||||
 | 
					                            "Give me the ID of the employee named 'Rahul Kumar'?")
 | 
				
			||||||
 | 
					                    .build();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            StringBuffer sb = new StringBuffer();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            OllamaChatResult chatResult = ollamaAPI.chat(requestModel, (s) -> {
 | 
				
			||||||
 | 
					                LOG.info(s);
 | 
				
			||||||
 | 
					                String substring = s.substring(sb.toString().length());
 | 
				
			||||||
 | 
					                LOG.info(substring);
 | 
				
			||||||
 | 
					                sb.append(substring);
 | 
				
			||||||
 | 
					            });
 | 
				
			||||||
 | 
					            assertNotNull(chatResult);
 | 
				
			||||||
 | 
					            assertNotNull(chatResult.getResponseModel());
 | 
				
			||||||
 | 
					            assertNotNull(chatResult.getResponseModel().getMessage());
 | 
				
			||||||
 | 
					            assertNotNull(chatResult.getResponseModel().getMessage().getContent());
 | 
				
			||||||
 | 
					            assertEquals(sb.toString().trim(), chatResult.getResponseModel().getMessage().getContent().trim());
 | 
				
			||||||
 | 
					        } catch (IOException | OllamaBaseException | InterruptedException e) {
 | 
				
			||||||
 | 
					            fail(e);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @Test
 | 
					    @Test
 | 
				
			||||||
    @Order(3)
 | 
					    @Order(3)
 | 
				
			||||||
    void testChatWithStream() {
 | 
					    void testChatWithStream() {
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user