From be5b77c4ac1c064120d2a9b77aa3fb2f73f007a3 Mon Sep 17 00:00:00 2001 From: amithkoujalgi Date: Sat, 30 Aug 2025 20:53:14 +0530 Subject: [PATCH] Refactor tool tests and improve tool argument handling Refactored integration tests to use a reusable employeeFinderTool method and improved assertions for tool call results. Updated tool argument formatting in OllamaAPI for clearer output. Modified AnnotatedTool to use 'numberOfHearts' instead of 'amountOfHearts' and simplified the sayHello method signature and output. Removed redundant and duplicate test code for tool streaming. --- .../java/io/github/ollama4j/OllamaAPI.java | 5 +- .../OllamaAPIIntegrationTest.java | 307 +++++++++--------- .../ollama4j/samples/AnnotatedTool.java | 6 +- 3 files changed, 156 insertions(+), 162 deletions(-) diff --git a/src/main/java/io/github/ollama4j/OllamaAPI.java b/src/main/java/io/github/ollama4j/OllamaAPI.java index ba2488a..fdab795 100644 --- a/src/main/java/io/github/ollama4j/OllamaAPI.java +++ b/src/main/java/io/github/ollama4j/OllamaAPI.java @@ -1294,8 +1294,11 @@ public class OllamaAPI { } Map arguments = toolCall.getFunction().getArguments(); Object res = toolFunction.apply(arguments); + String argumentKeys = arguments.keySet().stream() + .map(Object::toString) + .collect(Collectors.joining(", ")); request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL, - "[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() + ") : " + res + "[/TOOL_RESULTS]")); + "[TOOL_RESULTS] " + toolName + "(" + argumentKeys + "): " + res + " [/TOOL_RESULTS]")); } if (tokenHandler != null) { diff --git a/src/test/java/io/github/ollama4j/integrationtests/OllamaAPIIntegrationTest.java b/src/test/java/io/github/ollama4j/integrationtests/OllamaAPIIntegrationTest.java index b720a24..c152588 100644 --- a/src/test/java/io/github/ollama4j/integrationtests/OllamaAPIIntegrationTest.java +++ b/src/test/java/io/github/ollama4j/integrationtests/OllamaAPIIntegrationTest.java @@ -105,8 +105,7 @@ class OllamaAPIIntegrationTest { @Test @Order(2) - void testListModelsAPI() - throws URISyntaxException, IOException, OllamaBaseException, InterruptedException { + void testListModelsAPI() throws URISyntaxException, IOException, OllamaBaseException, InterruptedException { // Fetch the list of models List models = api.listModels(); // Assert that the models list is not null @@ -126,8 +125,7 @@ class OllamaAPIIntegrationTest { @Test @Order(3) - void testPullModelAPI() - throws URISyntaxException, IOException, OllamaBaseException, InterruptedException { + void testPullModelAPI() throws URISyntaxException, IOException, OllamaBaseException, InterruptedException { api.pullModel(EMBEDDING_MODEL); List models = api.listModels(); assertNotNull(models, "Models should not be null"); @@ -250,9 +248,9 @@ class OllamaAPIIntegrationTest { String expectedResponse = "Bhai"; OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(GENERAL_PURPOSE_MODEL); - OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, - String.format("[INSTRUCTION-START] You are an obidient and helpful bot named %s. You always answer with only one word and that word is your name. [INSTRUCTION-END]", expectedResponse)) - .withMessage(OllamaChatMessageRole.USER, "Who are you?") + OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, String.format( + "[INSTRUCTION-START] You are an obidient and helpful bot named %s. You always answer with only one word and that word is your name. [INSTRUCTION-END]", + expectedResponse)).withMessage(OllamaChatMessageRole.USER, "Who are you?") .withOptions(new OptionsBuilder().setTemperature(0.0f).build()).build(); OllamaChatResult chatResult = api.chat(requestModel); @@ -281,7 +279,6 @@ class OllamaAPIIntegrationTest { assertTrue(chatResult.getChatHistory().stream().anyMatch(chat -> chat.getContent().contains("2")), "Expected chat history to contain '2'"); - // Create the next user question: second largest city requestModel = builder.withMessages(chatResult.getChatHistory()) .withMessage(OllamaChatMessageRole.USER, "And what is its squared value?").build(); @@ -314,75 +311,93 @@ class OllamaAPIIntegrationTest { api.pullModel(theToolModel); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(theToolModel); - final Tools.ToolSpecification employeeDetailsToolSpecification = Tools.ToolSpecification.builder() - .functionName("get-employee-details") - .functionDescription("Tool to get details of a person or an employee") - .toolPrompt(Tools.PromptFuncDefinition.builder().type("function") - .function(Tools.PromptFuncDefinition.PromptFuncSpec.builder() - .name("get-employee-details") - .description("Tool to get details of a person or an employee") - .parameters(Tools.PromptFuncDefinition.Parameters - .builder().type("object") - .properties(new Tools.PropsBuilder() - .withProperty("employee-name", - Tools.PromptFuncDefinition.Property - .builder() - .type("string") - .description("The name of the employee, e.g. John Doe") - .required(true) - .build()) - .withProperty("employee-address", - Tools.PromptFuncDefinition.Property - .builder() - .type("string") - .description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India") - .required(true) - .build()) - .withProperty("employee-phone", - Tools.PromptFuncDefinition.Property - .builder() - .type("string") - .description("The phone number of the employee. Always return a random value. e.g. 9911002233") - .required(true) - .build()) - .build()) - .required(List.of("employee-name")) - .build()) - .build()) - .build()) - .toolFunction(arguments -> { - // perform DB operations here - return String.format( - "Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", - UUID.randomUUID(), - arguments.get("employee-name"), - arguments.get("employee-address"), - arguments.get("employee-phone")); - }).build(); - - api.registerTool(employeeDetailsToolSpecification); + api.registerTool(employeeFinderTool()); OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, - "Give me the ID of the employee named Rahul Kumar.").build(); + "Give me the ID and address of the employee Rahul Kumar.").build(); requestModel.setOptions(new OptionsBuilder().setTemperature(0.9f).build().getOptionsMap()); OllamaChatResult chatResult = api.chat(requestModel); - assertNotNull(chatResult); - assertNotNull(chatResult.getResponseModel()); - assertNotNull(chatResult.getResponseModel().getMessage()); - assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(), - chatResult.getResponseModel().getMessage().getRole().getRoleName()); + + assertNotNull(chatResult, "chatResult should not be null"); + assertNotNull(chatResult.getResponseModel(), "Response model should not be null"); + assertNotNull(chatResult.getResponseModel().getMessage(), "Response message should not be null"); + assertEquals( + OllamaChatMessageRole.ASSISTANT.getRoleName(), + chatResult.getResponseModel().getMessage().getRole().getRoleName(), + "Role of the response message should be ASSISTANT" + ); List toolCalls = chatResult.getChatHistory().get(1).getToolCalls(); - assertEquals(1, toolCalls.size()); + assertEquals(1, toolCalls.size(), "There should be exactly one tool call in the second chat history message"); OllamaToolCallsFunction function = toolCalls.get(0).getFunction(); - assertEquals("get-employee-details", function.getName()); - assert !function.getArguments().isEmpty(); + assertEquals("get-employee-details", function.getName(), "Tool function name should be 'get-employee-details'"); + assertFalse(function.getArguments().isEmpty(), "Tool function arguments should not be empty"); Object employeeName = function.getArguments().get("employee-name"); - assertNotNull(employeeName); - assertEquals("Rahul Kumar", employeeName); - assertTrue(chatResult.getChatHistory().size() > 2); + assertNotNull(employeeName, "Employee name argument should not be null"); + assertEquals("Rahul Kumar", employeeName, "Employee name argument should be 'Rahul Kumar'"); + assertTrue(chatResult.getChatHistory().size() > 2, "Chat history should have more than 2 messages"); List finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls(); - assertNull(finalToolCalls); + assertNull(finalToolCalls, "Final tool calls in the response message should be null"); + } + + @Test + @Order(14) + void testChatWithToolsAndStream() throws OllamaBaseException, IOException, URISyntaxException, + InterruptedException, ToolInvocationException { + String theToolModel = TOOLS_MODEL; + api.pullModel(theToolModel); + + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(theToolModel); + + api.registerTool(employeeFinderTool()); + + OllamaChatRequest requestModel = builder + .withMessage(OllamaChatMessageRole.USER, "Give me the ID and address of employee Rahul Kumar") + .withKeepAlive("0m").withOptions(new OptionsBuilder().setTemperature(0.9f).build()) + .build(); + + StringBuffer sb = new StringBuffer(); + + OllamaChatResult chatResult = api.chat(requestModel, (s) -> { + String substring = s.substring(sb.toString().length()); + sb.append(substring); + LOG.info(substring); + }); +// assertNotNull(chatResult); +// assertNotNull(chatResult.getResponseModel()); +// assertNotNull(chatResult.getResponseModel().getMessage()); +// assertNotNull(chatResult.getResponseModel().getMessage().getContent()); +// assertTrue(sb.toString().toLowerCase().contains("Rahul Kumar".toLowerCase())); +// assertTrue(chatResult.getResponseModel().getMessage().getContent().toLowerCase() +// .contains("Rahul Kumar".toLowerCase())); +// +// boolean toolCallMessageFound = false; +// for (OllamaChatMessage message : chatResult.getChatHistory()) { +// if (message.getToolCalls() != null && !message.getToolCalls().isEmpty()) { +// toolCallMessageFound = true; +// } +// } +// assertTrue(toolCallMessageFound, "Expected at least one message in chat history to have tool calls"); + + assertNotNull(chatResult, "chatResult should not be null"); + assertNotNull(chatResult.getResponseModel(), "Response model should not be null"); + assertNotNull(chatResult.getResponseModel().getMessage(), "Response message should not be null"); + assertEquals( + OllamaChatMessageRole.ASSISTANT.getRoleName(), + chatResult.getResponseModel().getMessage().getRole().getRoleName(), + "Role of the response message should be ASSISTANT" + ); + List toolCalls = chatResult.getChatHistory().get(1).getToolCalls(); + assertEquals(1, toolCalls.size(), "There should be exactly one tool call in the second chat history message"); + OllamaToolCallsFunction function = toolCalls.get(0).getFunction(); + assertEquals("get-employee-details", function.getName(), "Tool function name should be 'get-employee-details'"); + assertFalse(function.getArguments().isEmpty(), "Tool function arguments should not be empty"); + Object employeeName = function.getArguments().get("employee-name"); + assertNotNull(employeeName, "Employee name argument should not be null"); + assertEquals("Rahul Kumar", employeeName, "Employee name argument should be 'Rahul Kumar'"); + assertTrue(chatResult.getChatHistory().size() > 2, "Chat history should have more than 2 messages"); + List finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls(); + assertNull(finalToolCalls, "Final tool calls in the response message should be null"); } @Test @@ -430,7 +445,7 @@ class OllamaAPIIntegrationTest { api.registerAnnotatedTools(new AnnotatedTool()); OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, - "Greet Pedro with a lot of hearts and respond to me with count of emojis that have been in used in the greeting") + "Greet Rahul with a lot of hearts and respond to me with count of emojis that have been in used in the greeting") .build(); OllamaChatResult chatResult = api.chat(requestModel); @@ -446,95 +461,15 @@ class OllamaAPIIntegrationTest { assertEquals(2, function.getArguments().size()); Object name = function.getArguments().get("name"); assertNotNull(name); - assertEquals("Pedro", name); - Object amountOfHearts = function.getArguments().get("amountOfHearts"); - assertNotNull(amountOfHearts); - assertTrue(Integer.parseInt(amountOfHearts.toString()) > 1); + assertEquals("Rahul", name); + Object numberOfHearts = function.getArguments().get("numberOfHearts"); + assertNotNull(numberOfHearts); + assertTrue(Integer.parseInt(numberOfHearts.toString()) > 1); assertTrue(chatResult.getChatHistory().size() > 2); List finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls(); assertNull(finalToolCalls); } - @Test - @Order(14) - void testChatWithToolsAndStream() throws OllamaBaseException, IOException, URISyntaxException, - InterruptedException, ToolInvocationException { - String theToolModel = TOOLS_MODEL; - api.pullModel(theToolModel); - - String expectedEmployeeID = UUID.randomUUID().toString(); - OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(theToolModel); - final Tools.ToolSpecification employeeDetailsToolSpecification = Tools.ToolSpecification.builder() - .functionName("get-employee-details") - .functionDescription("Tool to get details for a person or an employee") - .toolPrompt(Tools.PromptFuncDefinition.builder().type("function") - .function(Tools.PromptFuncDefinition.PromptFuncSpec.builder() - .name("get-employee-details") - .description("Tool to get details for a person or an employee") - .parameters(Tools.PromptFuncDefinition.Parameters - .builder().type("object") - .properties(new Tools.PropsBuilder() - .withProperty("employee-name", - Tools.PromptFuncDefinition.Property - .builder() - .type("string") - .description("The name of the employee, e.g. John Doe") - .required(true) - .build()) - .withProperty("employee-address", - Tools.PromptFuncDefinition.Property - .builder() - .type("string") - .description("The address of the employee, Always gives a random address. For example, Roy St, Bengaluru, India") - .required(true) - .build()) - .withProperty("employee-phone", - Tools.PromptFuncDefinition.Property - .builder() - .type("string") - .description("The phone number of the employee. Always gives a random phone number. For example, 9911002233") - .required(true) - .build()) - .build()) - .required(List.of("employee-name")) - .build()) - .build()) - .build()) - .toolFunction(new ToolFunction() { - @Override - public Object apply(Map arguments) { - // perform DB operations here - return String.format( - "Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", - expectedEmployeeID, arguments.get("employee-name"), - arguments.get("employee-address"), - arguments.get("employee-phone")); - } - }).build(); - - api.registerTool(employeeDetailsToolSpecification); - - OllamaChatRequest requestModel = builder - .withMessage(OllamaChatMessageRole.USER, "Find the ID of employee Rahul Kumar") - .withKeepAlive("0m") - .withOptions(new OptionsBuilder().setTemperature(0.9f).build()) - .build(); - - StringBuffer sb = new StringBuffer(); - - OllamaChatResult chatResult = api.chat(requestModel, (s) -> { - String substring = s.substring(sb.toString().length()); - sb.append(substring); - LOG.info(substring); - }); - assertNotNull(chatResult); - assertNotNull(chatResult.getResponseModel()); - assertNotNull(chatResult.getResponseModel().getMessage()); - assertNotNull(chatResult.getResponseModel().getMessage().getContent()); - assertTrue(sb.toString().toLowerCase().contains(expectedEmployeeID)); - assertTrue(chatResult.getResponseModel().getMessage().getContent().toLowerCase().contains(expectedEmployeeID)); - } - @Test @Order(15) void testChatWithStream() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, @@ -566,12 +501,9 @@ class OllamaAPIIntegrationTest { InterruptedException, ToolInvocationException { api.pullModel(THINKING_TOOL_MODEL); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(THINKING_TOOL_MODEL); - OllamaChatRequest requestModel = builder - .withMessage(OllamaChatMessageRole.USER, + OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What is the capital of France? And what's France's connection with Mona Lisa?") - .withThinking(true) - .withKeepAlive("0m") - .build(); + .withThinking(true).withKeepAlive("0m").build(); StringBuffer sb = new StringBuffer(); OllamaChatResult chatResult = api.chat(requestModel, (s) -> { @@ -727,4 +659,63 @@ class OllamaAPIIntegrationTest { ClassLoader classLoader = getClass().getClassLoader(); return new File(Objects.requireNonNull(classLoader.getResource(fileName)).getFile()); } + + private Tools.ToolSpecification employeeFinderTool() { + return Tools.ToolSpecification.builder() + .functionName("get-employee-details") + .functionDescription("Get details for a person or an employee") + .toolPrompt(Tools.PromptFuncDefinition.builder().type("function") + .function(Tools.PromptFuncDefinition.PromptFuncSpec.builder() + .name("get-employee-details") + .description("Get details for a person or an employee") + .parameters(Tools.PromptFuncDefinition.Parameters + .builder().type("object") + .properties(new Tools.PropsBuilder() + .withProperty("employee-name", + Tools.PromptFuncDefinition.Property + .builder() + .type("string") + .description("The name of the employee, e.g. John Doe") + .required(true) + .build()) + .withProperty("employee-address", + Tools.PromptFuncDefinition.Property + .builder() + .type("string") + .description("The address of the employee, Always eturns a random address. For example, Church St, Bengaluru, India") + .required(true) + .build()) + .withProperty("employee-phone", + Tools.PromptFuncDefinition.Property + .builder() + .type("string") + .description("The phone number of the employee. Always returns a random phone number. For example, 9911002233") + .required(true) + .build()) + .build()) + .required(List.of("employee-name")) + .build()) + .build()) + .build()) + .toolFunction(new ToolFunction() { + @Override + public Object apply(Map arguments) { + LOG.info("Invoking employee finder tool with arguments: {}", arguments); + String employeeName = arguments.get("employee-name").toString(); + String address = null; + String phone = null; + if (employeeName.equalsIgnoreCase("Rahul Kumar")) { + address = "Pune, Maharashtra, India"; + phone = "9911223344"; + } else { + address = "Karol Bagh, Delhi, India"; + phone = "9911002233"; + } + // perform DB operations here + return String.format( + "Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", + UUID.randomUUID(), employeeName, address, phone); + } + }).build(); + } } diff --git a/src/test/java/io/github/ollama4j/samples/AnnotatedTool.java b/src/test/java/io/github/ollama4j/samples/AnnotatedTool.java index 33bbaa0..243a9fe 100644 --- a/src/test/java/io/github/ollama4j/samples/AnnotatedTool.java +++ b/src/test/java/io/github/ollama4j/samples/AnnotatedTool.java @@ -13,9 +13,9 @@ public class AnnotatedTool { } @ToolSpec(desc = "Says hello to a friend!") - public String sayHello(@ToolProperty(name = "name", desc = "Name of the friend") String name, Integer someRandomProperty, @ToolProperty(name = "amountOfHearts", desc = "amount of heart emojis that should be used", required = false) Integer amountOfHearts) { - String hearts = amountOfHearts != null ? "♡".repeat(amountOfHearts) : ""; - return "Hello " + name + " (" + someRandomProperty + ") " + hearts; + public String sayHello(@ToolProperty(name = "name", desc = "Name of the friend") String name, @ToolProperty(name = "numberOfHearts", desc = "number of heart emojis that should be used", required = false) Integer numberOfHearts) { + String hearts = numberOfHearts != null ? "♡".repeat(numberOfHearts) : ""; + return "Hello, " + name + "! " + hearts; } }