diff --git a/src/main/java/io/github/ollama4j/OllamaAPI.java b/src/main/java/io/github/ollama4j/OllamaAPI.java index ba2488a..fdab795 100644 --- a/src/main/java/io/github/ollama4j/OllamaAPI.java +++ b/src/main/java/io/github/ollama4j/OllamaAPI.java @@ -1294,8 +1294,11 @@ public class OllamaAPI { } Map arguments = toolCall.getFunction().getArguments(); Object res = toolFunction.apply(arguments); + String argumentKeys = arguments.keySet().stream() + .map(Object::toString) + .collect(Collectors.joining(", ")); request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL, - "[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() + ") : " + res + "[/TOOL_RESULTS]")); + "[TOOL_RESULTS] " + toolName + "(" + argumentKeys + "): " + res + " [/TOOL_RESULTS]")); } if (tokenHandler != null) { diff --git a/src/test/java/io/github/ollama4j/integrationtests/OllamaAPIIntegrationTest.java b/src/test/java/io/github/ollama4j/integrationtests/OllamaAPIIntegrationTest.java index b720a24..c152588 100644 --- a/src/test/java/io/github/ollama4j/integrationtests/OllamaAPIIntegrationTest.java +++ b/src/test/java/io/github/ollama4j/integrationtests/OllamaAPIIntegrationTest.java @@ -105,8 +105,7 @@ class OllamaAPIIntegrationTest { @Test @Order(2) - void testListModelsAPI() - throws URISyntaxException, IOException, OllamaBaseException, InterruptedException { + void testListModelsAPI() throws URISyntaxException, IOException, OllamaBaseException, InterruptedException { // Fetch the list of models List models = api.listModels(); // Assert that the models list is not null @@ -126,8 +125,7 @@ class OllamaAPIIntegrationTest { @Test @Order(3) - void testPullModelAPI() - throws URISyntaxException, IOException, OllamaBaseException, InterruptedException { + void testPullModelAPI() throws URISyntaxException, IOException, OllamaBaseException, InterruptedException { api.pullModel(EMBEDDING_MODEL); List models = api.listModels(); assertNotNull(models, "Models should not be null"); @@ -250,9 +248,9 @@ class OllamaAPIIntegrationTest { String expectedResponse = "Bhai"; OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(GENERAL_PURPOSE_MODEL); - OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, - String.format("[INSTRUCTION-START] You are an obidient and helpful bot named %s. You always answer with only one word and that word is your name. [INSTRUCTION-END]", expectedResponse)) - .withMessage(OllamaChatMessageRole.USER, "Who are you?") + OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, String.format( + "[INSTRUCTION-START] You are an obidient and helpful bot named %s. You always answer with only one word and that word is your name. [INSTRUCTION-END]", + expectedResponse)).withMessage(OllamaChatMessageRole.USER, "Who are you?") .withOptions(new OptionsBuilder().setTemperature(0.0f).build()).build(); OllamaChatResult chatResult = api.chat(requestModel); @@ -281,7 +279,6 @@ class OllamaAPIIntegrationTest { assertTrue(chatResult.getChatHistory().stream().anyMatch(chat -> chat.getContent().contains("2")), "Expected chat history to contain '2'"); - // Create the next user question: second largest city requestModel = builder.withMessages(chatResult.getChatHistory()) .withMessage(OllamaChatMessageRole.USER, "And what is its squared value?").build(); @@ -314,75 +311,93 @@ class OllamaAPIIntegrationTest { api.pullModel(theToolModel); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(theToolModel); - final Tools.ToolSpecification employeeDetailsToolSpecification = Tools.ToolSpecification.builder() - .functionName("get-employee-details") - .functionDescription("Tool to get details of a person or an employee") - .toolPrompt(Tools.PromptFuncDefinition.builder().type("function") - .function(Tools.PromptFuncDefinition.PromptFuncSpec.builder() - .name("get-employee-details") - .description("Tool to get details of a person or an employee") - .parameters(Tools.PromptFuncDefinition.Parameters - .builder().type("object") - .properties(new Tools.PropsBuilder() - .withProperty("employee-name", - Tools.PromptFuncDefinition.Property - .builder() - .type("string") - .description("The name of the employee, e.g. John Doe") - .required(true) - .build()) - .withProperty("employee-address", - Tools.PromptFuncDefinition.Property - .builder() - .type("string") - .description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India") - .required(true) - .build()) - .withProperty("employee-phone", - Tools.PromptFuncDefinition.Property - .builder() - .type("string") - .description("The phone number of the employee. Always return a random value. e.g. 9911002233") - .required(true) - .build()) - .build()) - .required(List.of("employee-name")) - .build()) - .build()) - .build()) - .toolFunction(arguments -> { - // perform DB operations here - return String.format( - "Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", - UUID.randomUUID(), - arguments.get("employee-name"), - arguments.get("employee-address"), - arguments.get("employee-phone")); - }).build(); - - api.registerTool(employeeDetailsToolSpecification); + api.registerTool(employeeFinderTool()); OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, - "Give me the ID of the employee named Rahul Kumar.").build(); + "Give me the ID and address of the employee Rahul Kumar.").build(); requestModel.setOptions(new OptionsBuilder().setTemperature(0.9f).build().getOptionsMap()); OllamaChatResult chatResult = api.chat(requestModel); - assertNotNull(chatResult); - assertNotNull(chatResult.getResponseModel()); - assertNotNull(chatResult.getResponseModel().getMessage()); - assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(), - chatResult.getResponseModel().getMessage().getRole().getRoleName()); + + assertNotNull(chatResult, "chatResult should not be null"); + assertNotNull(chatResult.getResponseModel(), "Response model should not be null"); + assertNotNull(chatResult.getResponseModel().getMessage(), "Response message should not be null"); + assertEquals( + OllamaChatMessageRole.ASSISTANT.getRoleName(), + chatResult.getResponseModel().getMessage().getRole().getRoleName(), + "Role of the response message should be ASSISTANT" + ); List toolCalls = chatResult.getChatHistory().get(1).getToolCalls(); - assertEquals(1, toolCalls.size()); + assertEquals(1, toolCalls.size(), "There should be exactly one tool call in the second chat history message"); OllamaToolCallsFunction function = toolCalls.get(0).getFunction(); - assertEquals("get-employee-details", function.getName()); - assert !function.getArguments().isEmpty(); + assertEquals("get-employee-details", function.getName(), "Tool function name should be 'get-employee-details'"); + assertFalse(function.getArguments().isEmpty(), "Tool function arguments should not be empty"); Object employeeName = function.getArguments().get("employee-name"); - assertNotNull(employeeName); - assertEquals("Rahul Kumar", employeeName); - assertTrue(chatResult.getChatHistory().size() > 2); + assertNotNull(employeeName, "Employee name argument should not be null"); + assertEquals("Rahul Kumar", employeeName, "Employee name argument should be 'Rahul Kumar'"); + assertTrue(chatResult.getChatHistory().size() > 2, "Chat history should have more than 2 messages"); List finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls(); - assertNull(finalToolCalls); + assertNull(finalToolCalls, "Final tool calls in the response message should be null"); + } + + @Test + @Order(14) + void testChatWithToolsAndStream() throws OllamaBaseException, IOException, URISyntaxException, + InterruptedException, ToolInvocationException { + String theToolModel = TOOLS_MODEL; + api.pullModel(theToolModel); + + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(theToolModel); + + api.registerTool(employeeFinderTool()); + + OllamaChatRequest requestModel = builder + .withMessage(OllamaChatMessageRole.USER, "Give me the ID and address of employee Rahul Kumar") + .withKeepAlive("0m").withOptions(new OptionsBuilder().setTemperature(0.9f).build()) + .build(); + + StringBuffer sb = new StringBuffer(); + + OllamaChatResult chatResult = api.chat(requestModel, (s) -> { + String substring = s.substring(sb.toString().length()); + sb.append(substring); + LOG.info(substring); + }); +// assertNotNull(chatResult); +// assertNotNull(chatResult.getResponseModel()); +// assertNotNull(chatResult.getResponseModel().getMessage()); +// assertNotNull(chatResult.getResponseModel().getMessage().getContent()); +// assertTrue(sb.toString().toLowerCase().contains("Rahul Kumar".toLowerCase())); +// assertTrue(chatResult.getResponseModel().getMessage().getContent().toLowerCase() +// .contains("Rahul Kumar".toLowerCase())); +// +// boolean toolCallMessageFound = false; +// for (OllamaChatMessage message : chatResult.getChatHistory()) { +// if (message.getToolCalls() != null && !message.getToolCalls().isEmpty()) { +// toolCallMessageFound = true; +// } +// } +// assertTrue(toolCallMessageFound, "Expected at least one message in chat history to have tool calls"); + + assertNotNull(chatResult, "chatResult should not be null"); + assertNotNull(chatResult.getResponseModel(), "Response model should not be null"); + assertNotNull(chatResult.getResponseModel().getMessage(), "Response message should not be null"); + assertEquals( + OllamaChatMessageRole.ASSISTANT.getRoleName(), + chatResult.getResponseModel().getMessage().getRole().getRoleName(), + "Role of the response message should be ASSISTANT" + ); + List toolCalls = chatResult.getChatHistory().get(1).getToolCalls(); + assertEquals(1, toolCalls.size(), "There should be exactly one tool call in the second chat history message"); + OllamaToolCallsFunction function = toolCalls.get(0).getFunction(); + assertEquals("get-employee-details", function.getName(), "Tool function name should be 'get-employee-details'"); + assertFalse(function.getArguments().isEmpty(), "Tool function arguments should not be empty"); + Object employeeName = function.getArguments().get("employee-name"); + assertNotNull(employeeName, "Employee name argument should not be null"); + assertEquals("Rahul Kumar", employeeName, "Employee name argument should be 'Rahul Kumar'"); + assertTrue(chatResult.getChatHistory().size() > 2, "Chat history should have more than 2 messages"); + List finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls(); + assertNull(finalToolCalls, "Final tool calls in the response message should be null"); } @Test @@ -430,7 +445,7 @@ class OllamaAPIIntegrationTest { api.registerAnnotatedTools(new AnnotatedTool()); OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, - "Greet Pedro with a lot of hearts and respond to me with count of emojis that have been in used in the greeting") + "Greet Rahul with a lot of hearts and respond to me with count of emojis that have been in used in the greeting") .build(); OllamaChatResult chatResult = api.chat(requestModel); @@ -446,95 +461,15 @@ class OllamaAPIIntegrationTest { assertEquals(2, function.getArguments().size()); Object name = function.getArguments().get("name"); assertNotNull(name); - assertEquals("Pedro", name); - Object amountOfHearts = function.getArguments().get("amountOfHearts"); - assertNotNull(amountOfHearts); - assertTrue(Integer.parseInt(amountOfHearts.toString()) > 1); + assertEquals("Rahul", name); + Object numberOfHearts = function.getArguments().get("numberOfHearts"); + assertNotNull(numberOfHearts); + assertTrue(Integer.parseInt(numberOfHearts.toString()) > 1); assertTrue(chatResult.getChatHistory().size() > 2); List finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls(); assertNull(finalToolCalls); } - @Test - @Order(14) - void testChatWithToolsAndStream() throws OllamaBaseException, IOException, URISyntaxException, - InterruptedException, ToolInvocationException { - String theToolModel = TOOLS_MODEL; - api.pullModel(theToolModel); - - String expectedEmployeeID = UUID.randomUUID().toString(); - OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(theToolModel); - final Tools.ToolSpecification employeeDetailsToolSpecification = Tools.ToolSpecification.builder() - .functionName("get-employee-details") - .functionDescription("Tool to get details for a person or an employee") - .toolPrompt(Tools.PromptFuncDefinition.builder().type("function") - .function(Tools.PromptFuncDefinition.PromptFuncSpec.builder() - .name("get-employee-details") - .description("Tool to get details for a person or an employee") - .parameters(Tools.PromptFuncDefinition.Parameters - .builder().type("object") - .properties(new Tools.PropsBuilder() - .withProperty("employee-name", - Tools.PromptFuncDefinition.Property - .builder() - .type("string") - .description("The name of the employee, e.g. John Doe") - .required(true) - .build()) - .withProperty("employee-address", - Tools.PromptFuncDefinition.Property - .builder() - .type("string") - .description("The address of the employee, Always gives a random address. For example, Roy St, Bengaluru, India") - .required(true) - .build()) - .withProperty("employee-phone", - Tools.PromptFuncDefinition.Property - .builder() - .type("string") - .description("The phone number of the employee. Always gives a random phone number. For example, 9911002233") - .required(true) - .build()) - .build()) - .required(List.of("employee-name")) - .build()) - .build()) - .build()) - .toolFunction(new ToolFunction() { - @Override - public Object apply(Map arguments) { - // perform DB operations here - return String.format( - "Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", - expectedEmployeeID, arguments.get("employee-name"), - arguments.get("employee-address"), - arguments.get("employee-phone")); - } - }).build(); - - api.registerTool(employeeDetailsToolSpecification); - - OllamaChatRequest requestModel = builder - .withMessage(OllamaChatMessageRole.USER, "Find the ID of employee Rahul Kumar") - .withKeepAlive("0m") - .withOptions(new OptionsBuilder().setTemperature(0.9f).build()) - .build(); - - StringBuffer sb = new StringBuffer(); - - OllamaChatResult chatResult = api.chat(requestModel, (s) -> { - String substring = s.substring(sb.toString().length()); - sb.append(substring); - LOG.info(substring); - }); - assertNotNull(chatResult); - assertNotNull(chatResult.getResponseModel()); - assertNotNull(chatResult.getResponseModel().getMessage()); - assertNotNull(chatResult.getResponseModel().getMessage().getContent()); - assertTrue(sb.toString().toLowerCase().contains(expectedEmployeeID)); - assertTrue(chatResult.getResponseModel().getMessage().getContent().toLowerCase().contains(expectedEmployeeID)); - } - @Test @Order(15) void testChatWithStream() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, @@ -566,12 +501,9 @@ class OllamaAPIIntegrationTest { InterruptedException, ToolInvocationException { api.pullModel(THINKING_TOOL_MODEL); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(THINKING_TOOL_MODEL); - OllamaChatRequest requestModel = builder - .withMessage(OllamaChatMessageRole.USER, + OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What is the capital of France? And what's France's connection with Mona Lisa?") - .withThinking(true) - .withKeepAlive("0m") - .build(); + .withThinking(true).withKeepAlive("0m").build(); StringBuffer sb = new StringBuffer(); OllamaChatResult chatResult = api.chat(requestModel, (s) -> { @@ -727,4 +659,63 @@ class OllamaAPIIntegrationTest { ClassLoader classLoader = getClass().getClassLoader(); return new File(Objects.requireNonNull(classLoader.getResource(fileName)).getFile()); } + + private Tools.ToolSpecification employeeFinderTool() { + return Tools.ToolSpecification.builder() + .functionName("get-employee-details") + .functionDescription("Get details for a person or an employee") + .toolPrompt(Tools.PromptFuncDefinition.builder().type("function") + .function(Tools.PromptFuncDefinition.PromptFuncSpec.builder() + .name("get-employee-details") + .description("Get details for a person or an employee") + .parameters(Tools.PromptFuncDefinition.Parameters + .builder().type("object") + .properties(new Tools.PropsBuilder() + .withProperty("employee-name", + Tools.PromptFuncDefinition.Property + .builder() + .type("string") + .description("The name of the employee, e.g. John Doe") + .required(true) + .build()) + .withProperty("employee-address", + Tools.PromptFuncDefinition.Property + .builder() + .type("string") + .description("The address of the employee, Always eturns a random address. For example, Church St, Bengaluru, India") + .required(true) + .build()) + .withProperty("employee-phone", + Tools.PromptFuncDefinition.Property + .builder() + .type("string") + .description("The phone number of the employee. Always returns a random phone number. For example, 9911002233") + .required(true) + .build()) + .build()) + .required(List.of("employee-name")) + .build()) + .build()) + .build()) + .toolFunction(new ToolFunction() { + @Override + public Object apply(Map arguments) { + LOG.info("Invoking employee finder tool with arguments: {}", arguments); + String employeeName = arguments.get("employee-name").toString(); + String address = null; + String phone = null; + if (employeeName.equalsIgnoreCase("Rahul Kumar")) { + address = "Pune, Maharashtra, India"; + phone = "9911223344"; + } else { + address = "Karol Bagh, Delhi, India"; + phone = "9911002233"; + } + // perform DB operations here + return String.format( + "Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", + UUID.randomUUID(), employeeName, address, phone); + } + }).build(); + } } diff --git a/src/test/java/io/github/ollama4j/samples/AnnotatedTool.java b/src/test/java/io/github/ollama4j/samples/AnnotatedTool.java index 33bbaa0..243a9fe 100644 --- a/src/test/java/io/github/ollama4j/samples/AnnotatedTool.java +++ b/src/test/java/io/github/ollama4j/samples/AnnotatedTool.java @@ -13,9 +13,9 @@ public class AnnotatedTool { } @ToolSpec(desc = "Says hello to a friend!") - public String sayHello(@ToolProperty(name = "name", desc = "Name of the friend") String name, Integer someRandomProperty, @ToolProperty(name = "amountOfHearts", desc = "amount of heart emojis that should be used", required = false) Integer amountOfHearts) { - String hearts = amountOfHearts != null ? "♡".repeat(amountOfHearts) : ""; - return "Hello " + name + " (" + someRandomProperty + ") " + hearts; + public String sayHello(@ToolProperty(name = "name", desc = "Name of the friend") String name, @ToolProperty(name = "numberOfHearts", desc = "number of heart emojis that should be used", required = false) Integer numberOfHearts) { + String hearts = numberOfHearts != null ? "♡".repeat(numberOfHearts) : ""; + return "Hello, " + name + "! " + hearts; } }