From 69f6fd81cf614a5a10a27b87499f9c1dfbc88169 Mon Sep 17 00:00:00 2001 From: Markus Klenke Date: Sat, 7 Dec 2024 01:04:13 +0100 Subject: [PATCH] Enables in chat tool calling --- src/main/java/io/github/ollama4j/OllamaAPI.java | 16 ++++++++++++++++ .../ollama4j/tools/OllamaToolCallsFunction.java | 2 +- .../ollama4j/integrationtests/TestRealAPIs.java | 10 ++++++---- 3 files changed, 23 insertions(+), 5 deletions(-) diff --git a/src/main/java/io/github/ollama4j/OllamaAPI.java b/src/main/java/io/github/ollama4j/OllamaAPI.java index d76ecd9..810b2c4 100644 --- a/src/main/java/io/github/ollama4j/OllamaAPI.java +++ b/src/main/java/io/github/ollama4j/OllamaAPI.java @@ -777,6 +777,22 @@ public class OllamaAPI { result = requestCaller.call(request, streamHandler); } else { result = requestCaller.callSync(request); + // check if toolCallIsWanted + List toolCalls = result.getResponseModel().getMessage().getToolCalls(); + int toolCallTries = 0; + while(toolCalls != null && !toolCalls.isEmpty() && toolCallTries <3){ + for (OllamaChatToolCalls toolCall : toolCalls){ + String toolName = toolCall.getFunction().getName(); + ToolFunction toolFunction = toolRegistry.getToolFunction(toolName); + Map arguments = toolCall.getFunction().getArguments(); + Object res = toolFunction.apply(arguments); + request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL,"[ToolCall-Result]" + toolName + "(" + arguments.keySet() +") : " + res + "[/ToolCall-Result]")); + } + result = requestCaller.callSync(request); + toolCalls = result.getResponseModel().getMessage().getToolCalls(); + toolCallTries++; + } + } return result; diff --git a/src/main/java/io/github/ollama4j/tools/OllamaToolCallsFunction.java b/src/main/java/io/github/ollama4j/tools/OllamaToolCallsFunction.java index dfa4d84..4be7194 100644 --- a/src/main/java/io/github/ollama4j/tools/OllamaToolCallsFunction.java +++ b/src/main/java/io/github/ollama4j/tools/OllamaToolCallsFunction.java @@ -12,5 +12,5 @@ import java.util.Map; public class OllamaToolCallsFunction { private String name; - private Map arguments; + private Map arguments; } diff --git a/src/test/java/io/github/ollama4j/integrationtests/TestRealAPIs.java b/src/test/java/io/github/ollama4j/integrationtests/TestRealAPIs.java index fd834dd..fbf518d 100644 --- a/src/test/java/io/github/ollama4j/integrationtests/TestRealAPIs.java +++ b/src/test/java/io/github/ollama4j/integrationtests/TestRealAPIs.java @@ -273,14 +273,16 @@ class TestRealAPIs { assertNotNull(chatResult.getResponseModel()); assertNotNull(chatResult.getResponseModel().getMessage()); assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(),chatResult.getResponseModel().getMessage().getRole().getRoleName()); - List toolCalls = chatResult.getResponseModel().getMessage().getToolCalls(); + List toolCalls = chatResult.getChatHistory().get(1).getToolCalls(); assertEquals(1, toolCalls.size()); assertEquals("get-employee-details",toolCalls.get(0).getFunction().getName()); assertEquals(1, toolCalls.get(0).getFunction().getArguments().size()); - String employeeName = toolCalls.get(0).getFunction().getArguments().get("employee-name"); + Object employeeName = toolCalls.get(0).getFunction().getArguments().get("employee-name"); assertNotNull(employeeName); assertEquals("Rahul Kumar",employeeName); - assertEquals(2, chatResult.getChatHistory().size()); + assertTrue(chatResult.getChatHistory().size()>2); + List finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls(); + assertNull(finalToolCalls); } catch (IOException | OllamaBaseException | InterruptedException e) { fail(e); } @@ -448,7 +450,7 @@ class DBQueryFunction implements ToolFunction { @Override public Object apply(Map arguments) { // perform DB operations here - return String.format("Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", UUID.randomUUID(), arguments.get("employee-name").toString(), arguments.get("employee-address").toString(), arguments.get("employee-phone").toString()); + return String.format("Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", UUID.randomUUID(), arguments.get("employee-name"), arguments.get("employee-address"), arguments.get("employee-phone")); } }