diff --git a/src/main/java/io/github/ollama4j/OllamaAPI.java b/src/main/java/io/github/ollama4j/OllamaAPI.java index 810b2c4..3b525df 100644 --- a/src/main/java/io/github/ollama4j/OllamaAPI.java +++ b/src/main/java/io/github/ollama4j/OllamaAPI.java @@ -777,22 +777,27 @@ public class OllamaAPI { result = requestCaller.call(request, streamHandler); } else { result = requestCaller.callSync(request); - // check if toolCallIsWanted - List toolCalls = result.getResponseModel().getMessage().getToolCalls(); - int toolCallTries = 0; - while(toolCalls != null && !toolCalls.isEmpty() && toolCallTries <3){ - for (OllamaChatToolCalls toolCall : toolCalls){ - String toolName = toolCall.getFunction().getName(); - ToolFunction toolFunction = toolRegistry.getToolFunction(toolName); - Map arguments = toolCall.getFunction().getArguments(); - Object res = toolFunction.apply(arguments); - request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL,"[ToolCall-Result]" + toolName + "(" + arguments.keySet() +") : " + res + "[/ToolCall-Result]")); - } - result = requestCaller.callSync(request); - toolCalls = result.getResponseModel().getMessage().getToolCalls(); - toolCallTries++; + } + + // check if toolCallIsWanted + List toolCalls = result.getResponseModel().getMessage().getToolCalls(); + int toolCallTries = 0; + while(toolCalls != null && !toolCalls.isEmpty() && toolCallTries <3){ + for (OllamaChatToolCalls toolCall : toolCalls){ + String toolName = toolCall.getFunction().getName(); + ToolFunction toolFunction = toolRegistry.getToolFunction(toolName); + Map arguments = toolCall.getFunction().getArguments(); + Object res = toolFunction.apply(arguments); + request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL,"[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() +") : " + res + "[/TOOL_RESULTS]")); } + if (streamHandler != null) { + result = requestCaller.call(request, streamHandler); + } else { + result = requestCaller.callSync(request); + } + toolCalls = result.getResponseModel().getMessage().getToolCalls(); + toolCallTries++; } return result; diff --git a/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java b/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java index 71c2b9b..57c9ee3 100644 --- a/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java +++ b/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java @@ -3,13 +3,10 @@ package io.github.ollama4j.models.request; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; import io.github.ollama4j.exceptions.OllamaBaseException; -import io.github.ollama4j.models.chat.OllamaChatMessage; -import io.github.ollama4j.models.chat.OllamaChatRequest; -import io.github.ollama4j.models.chat.OllamaChatResult; +import io.github.ollama4j.models.chat.*; import io.github.ollama4j.models.response.OllamaErrorResponse; -import io.github.ollama4j.models.chat.OllamaChatResponseModel; -import io.github.ollama4j.models.chat.OllamaChatStreamObserver; import io.github.ollama4j.models.generate.OllamaStreamHandler; +import io.github.ollama4j.tools.Tools; import io.github.ollama4j.utils.Utils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -23,6 +20,7 @@ import java.net.http.HttpClient; import java.net.http.HttpRequest; import java.net.http.HttpResponse; import java.nio.charset.StandardCharsets; +import java.util.List; /** * Specialization class for requests @@ -96,6 +94,7 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller { InputStream responseBodyStream = response.body(); StringBuilder responseBuffer = new StringBuilder(); OllamaChatResponseModel ollamaChatResponseModel = null; + List wantedToolsForStream = null; try (BufferedReader reader = new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { @@ -120,6 +119,9 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller { } else { boolean finished = parseResponseAndAddToBuffer(line, responseBuffer); ollamaChatResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class); + if(body.stream && ollamaChatResponseModel.getMessage().getToolCalls() != null){ + wantedToolsForStream = ollamaChatResponseModel.getMessage().getToolCalls(); + } if (finished && body.stream) { ollamaChatResponseModel.getMessage().setContent(responseBuffer.toString()); break; @@ -131,6 +133,9 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller { LOG.error("Status code " + statusCode); throw new OllamaBaseException(responseBuffer.toString()); } else { + if(wantedToolsForStream != null) { + ollamaChatResponseModel.getMessage().setToolCalls(wantedToolsForStream); + } OllamaChatResult ollamaResult = new OllamaChatResult(ollamaChatResponseModel,body.getMessages()); if (isVerbose()) LOG.info("Model response: " + ollamaResult); diff --git a/src/test/java/io/github/ollama4j/integrationtests/TestRealAPIs.java b/src/test/java/io/github/ollama4j/integrationtests/TestRealAPIs.java index fbf518d..668a5dc 100644 --- a/src/test/java/io/github/ollama4j/integrationtests/TestRealAPIs.java +++ b/src/test/java/io/github/ollama4j/integrationtests/TestRealAPIs.java @@ -265,7 +265,7 @@ class TestRealAPIs { OllamaChatRequest requestModel = builder .withMessage(OllamaChatMessageRole.USER, - "Give me the details of the employee named 'Rahul Kumar'?") + "Give me the ID of the employee named 'Rahul Kumar'?") .build(); OllamaChatResult chatResult = ollamaAPI.chat(requestModel); @@ -288,6 +288,63 @@ class TestRealAPIs { } } + @Test + @Order(3) + void testChatWithToolsAndStream() { + testEndpointReachability(); + try { + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel()); + final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder() + .functionName("get-employee-details") + .functionDescription("Get employee details from the database") + .toolPrompt( + Tools.PromptFuncDefinition.builder().type("function").function( + Tools.PromptFuncDefinition.PromptFuncSpec.builder() + .name("get-employee-details") + .description("Get employee details from the database") + .parameters( + Tools.PromptFuncDefinition.Parameters.builder() + .type("object") + .properties( + new Tools.PropsBuilder() + .withProperty("employee-name", Tools.PromptFuncDefinition.Property.builder().type("string").description("The name of the employee, e.g. John Doe").required(true).build()) + .withProperty("employee-address", Tools.PromptFuncDefinition.Property.builder().type("string").description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India").required(true).build()) + .withProperty("employee-phone", Tools.PromptFuncDefinition.Property.builder().type("string").description("The phone number of the employee. Always return a random value. e.g. 9911002233").required(true).build()) + .build() + ) + .required(List.of("employee-name")) + .build() + ).build() + ).build() + ) + .toolFunction(new DBQueryFunction()) + .build(); + + ollamaAPI.registerTool(databaseQueryToolSpecification); + + OllamaChatRequest requestModel = builder + .withMessage(OllamaChatMessageRole.USER, + "Give me the ID of the employee named 'Rahul Kumar'?") + .build(); + + StringBuffer sb = new StringBuffer(); + + OllamaChatResult chatResult = ollamaAPI.chat(requestModel, (s) -> { + LOG.info(s); + String substring = s.substring(sb.toString().length()); + LOG.info(substring); + sb.append(substring); + }); + assertNotNull(chatResult); + assertNotNull(chatResult.getResponseModel()); + assertNotNull(chatResult.getResponseModel().getMessage()); + assertNotNull(chatResult.getResponseModel().getMessage().getContent()); + assertEquals(sb.toString().trim(), chatResult.getResponseModel().getMessage().getContent().trim()); + } catch (IOException | OllamaBaseException | InterruptedException e) { + fail(e); + } + } + @Test @Order(3) void testChatWithStream() {