Refactor tool tests and improve tool argument handling

Refactored integration tests to use a reusable employeeFinderTool method and improved assertions for tool call results. Updated tool argument formatting in OllamaAPI for clearer output. Modified AnnotatedTool to use 'numberOfHearts' instead of 'amountOfHearts' and simplified the sayHello method signature and output. Removed redundant and duplicate test code for tool streaming.
This commit is contained in:
amithkoujalgi 2025-08-30 20:53:14 +05:30
parent 6078db6157
commit be5b77c4ac
No known key found for this signature in database
GPG Key ID: E29A37746AF94B70
3 changed files with 156 additions and 162 deletions

View File

@ -1294,8 +1294,11 @@ public class OllamaAPI {
}
Map<String, Object> arguments = toolCall.getFunction().getArguments();
Object res = toolFunction.apply(arguments);
String argumentKeys = arguments.keySet().stream()
.map(Object::toString)
.collect(Collectors.joining(", "));
request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL,
"[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() + ") : " + res + "[/TOOL_RESULTS]"));
"[TOOL_RESULTS] " + toolName + "(" + argumentKeys + "): " + res + " [/TOOL_RESULTS]"));
}
if (tokenHandler != null) {

View File

@ -105,8 +105,7 @@ class OllamaAPIIntegrationTest {
@Test
@Order(2)
void testListModelsAPI()
throws URISyntaxException, IOException, OllamaBaseException, InterruptedException {
void testListModelsAPI() throws URISyntaxException, IOException, OllamaBaseException, InterruptedException {
// Fetch the list of models
List<Model> models = api.listModels();
// Assert that the models list is not null
@ -126,8 +125,7 @@ class OllamaAPIIntegrationTest {
@Test
@Order(3)
void testPullModelAPI()
throws URISyntaxException, IOException, OllamaBaseException, InterruptedException {
void testPullModelAPI() throws URISyntaxException, IOException, OllamaBaseException, InterruptedException {
api.pullModel(EMBEDDING_MODEL);
List<Model> models = api.listModels();
assertNotNull(models, "Models should not be null");
@ -250,9 +248,9 @@ class OllamaAPIIntegrationTest {
String expectedResponse = "Bhai";
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(GENERAL_PURPOSE_MODEL);
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM,
String.format("[INSTRUCTION-START] You are an obidient and helpful bot named %s. You always answer with only one word and that word is your name. [INSTRUCTION-END]", expectedResponse))
.withMessage(OllamaChatMessageRole.USER, "Who are you?")
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, String.format(
"[INSTRUCTION-START] You are an obidient and helpful bot named %s. You always answer with only one word and that word is your name. [INSTRUCTION-END]",
expectedResponse)).withMessage(OllamaChatMessageRole.USER, "Who are you?")
.withOptions(new OptionsBuilder().setTemperature(0.0f).build()).build();
OllamaChatResult chatResult = api.chat(requestModel);
@ -281,7 +279,6 @@ class OllamaAPIIntegrationTest {
assertTrue(chatResult.getChatHistory().stream().anyMatch(chat -> chat.getContent().contains("2")),
"Expected chat history to contain '2'");
// Create the next user question: second largest city
requestModel = builder.withMessages(chatResult.getChatHistory())
.withMessage(OllamaChatMessageRole.USER, "And what is its squared value?").build();
@ -314,75 +311,93 @@ class OllamaAPIIntegrationTest {
api.pullModel(theToolModel);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(theToolModel);
final Tools.ToolSpecification employeeDetailsToolSpecification = Tools.ToolSpecification.builder()
.functionName("get-employee-details")
.functionDescription("Tool to get details of a person or an employee")
.toolPrompt(Tools.PromptFuncDefinition.builder().type("function")
.function(Tools.PromptFuncDefinition.PromptFuncSpec.builder()
.name("get-employee-details")
.description("Tool to get details of a person or an employee")
.parameters(Tools.PromptFuncDefinition.Parameters
.builder().type("object")
.properties(new Tools.PropsBuilder()
.withProperty("employee-name",
Tools.PromptFuncDefinition.Property
.builder()
.type("string")
.description("The name of the employee, e.g. John Doe")
.required(true)
.build())
.withProperty("employee-address",
Tools.PromptFuncDefinition.Property
.builder()
.type("string")
.description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India")
.required(true)
.build())
.withProperty("employee-phone",
Tools.PromptFuncDefinition.Property
.builder()
.type("string")
.description("The phone number of the employee. Always return a random value. e.g. 9911002233")
.required(true)
.build())
.build())
.required(List.of("employee-name"))
.build())
.build())
.build())
.toolFunction(arguments -> {
// perform DB operations here
return String.format(
"Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}",
UUID.randomUUID(),
arguments.get("employee-name"),
arguments.get("employee-address"),
arguments.get("employee-phone"));
}).build();
api.registerTool(employeeDetailsToolSpecification);
api.registerTool(employeeFinderTool());
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER,
"Give me the ID of the employee named Rahul Kumar.").build();
"Give me the ID and address of the employee Rahul Kumar.").build();
requestModel.setOptions(new OptionsBuilder().setTemperature(0.9f).build().getOptionsMap());
OllamaChatResult chatResult = api.chat(requestModel);
assertNotNull(chatResult);
assertNotNull(chatResult.getResponseModel());
assertNotNull(chatResult.getResponseModel().getMessage());
assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(),
chatResult.getResponseModel().getMessage().getRole().getRoleName());
assertNotNull(chatResult, "chatResult should not be null");
assertNotNull(chatResult.getResponseModel(), "Response model should not be null");
assertNotNull(chatResult.getResponseModel().getMessage(), "Response message should not be null");
assertEquals(
OllamaChatMessageRole.ASSISTANT.getRoleName(),
chatResult.getResponseModel().getMessage().getRole().getRoleName(),
"Role of the response message should be ASSISTANT"
);
List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls();
assertEquals(1, toolCalls.size());
assertEquals(1, toolCalls.size(), "There should be exactly one tool call in the second chat history message");
OllamaToolCallsFunction function = toolCalls.get(0).getFunction();
assertEquals("get-employee-details", function.getName());
assert !function.getArguments().isEmpty();
assertEquals("get-employee-details", function.getName(), "Tool function name should be 'get-employee-details'");
assertFalse(function.getArguments().isEmpty(), "Tool function arguments should not be empty");
Object employeeName = function.getArguments().get("employee-name");
assertNotNull(employeeName);
assertEquals("Rahul Kumar", employeeName);
assertTrue(chatResult.getChatHistory().size() > 2);
assertNotNull(employeeName, "Employee name argument should not be null");
assertEquals("Rahul Kumar", employeeName, "Employee name argument should be 'Rahul Kumar'");
assertTrue(chatResult.getChatHistory().size() > 2, "Chat history should have more than 2 messages");
List<OllamaChatToolCalls> finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls();
assertNull(finalToolCalls);
assertNull(finalToolCalls, "Final tool calls in the response message should be null");
}
@Test
@Order(14)
void testChatWithToolsAndStream() throws OllamaBaseException, IOException, URISyntaxException,
InterruptedException, ToolInvocationException {
String theToolModel = TOOLS_MODEL;
api.pullModel(theToolModel);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(theToolModel);
api.registerTool(employeeFinderTool());
OllamaChatRequest requestModel = builder
.withMessage(OllamaChatMessageRole.USER, "Give me the ID and address of employee Rahul Kumar")
.withKeepAlive("0m").withOptions(new OptionsBuilder().setTemperature(0.9f).build())
.build();
StringBuffer sb = new StringBuffer();
OllamaChatResult chatResult = api.chat(requestModel, (s) -> {
String substring = s.substring(sb.toString().length());
sb.append(substring);
LOG.info(substring);
});
// assertNotNull(chatResult);
// assertNotNull(chatResult.getResponseModel());
// assertNotNull(chatResult.getResponseModel().getMessage());
// assertNotNull(chatResult.getResponseModel().getMessage().getContent());
// assertTrue(sb.toString().toLowerCase().contains("Rahul Kumar".toLowerCase()));
// assertTrue(chatResult.getResponseModel().getMessage().getContent().toLowerCase()
// .contains("Rahul Kumar".toLowerCase()));
//
// boolean toolCallMessageFound = false;
// for (OllamaChatMessage message : chatResult.getChatHistory()) {
// if (message.getToolCalls() != null && !message.getToolCalls().isEmpty()) {
// toolCallMessageFound = true;
// }
// }
// assertTrue(toolCallMessageFound, "Expected at least one message in chat history to have tool calls");
assertNotNull(chatResult, "chatResult should not be null");
assertNotNull(chatResult.getResponseModel(), "Response model should not be null");
assertNotNull(chatResult.getResponseModel().getMessage(), "Response message should not be null");
assertEquals(
OllamaChatMessageRole.ASSISTANT.getRoleName(),
chatResult.getResponseModel().getMessage().getRole().getRoleName(),
"Role of the response message should be ASSISTANT"
);
List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls();
assertEquals(1, toolCalls.size(), "There should be exactly one tool call in the second chat history message");
OllamaToolCallsFunction function = toolCalls.get(0).getFunction();
assertEquals("get-employee-details", function.getName(), "Tool function name should be 'get-employee-details'");
assertFalse(function.getArguments().isEmpty(), "Tool function arguments should not be empty");
Object employeeName = function.getArguments().get("employee-name");
assertNotNull(employeeName, "Employee name argument should not be null");
assertEquals("Rahul Kumar", employeeName, "Employee name argument should be 'Rahul Kumar'");
assertTrue(chatResult.getChatHistory().size() > 2, "Chat history should have more than 2 messages");
List<OllamaChatToolCalls> finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls();
assertNull(finalToolCalls, "Final tool calls in the response message should be null");
}
@Test
@ -430,7 +445,7 @@ class OllamaAPIIntegrationTest {
api.registerAnnotatedTools(new AnnotatedTool());
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER,
"Greet Pedro with a lot of hearts and respond to me with count of emojis that have been in used in the greeting")
"Greet Rahul with a lot of hearts and respond to me with count of emojis that have been in used in the greeting")
.build();
OllamaChatResult chatResult = api.chat(requestModel);
@ -446,95 +461,15 @@ class OllamaAPIIntegrationTest {
assertEquals(2, function.getArguments().size());
Object name = function.getArguments().get("name");
assertNotNull(name);
assertEquals("Pedro", name);
Object amountOfHearts = function.getArguments().get("amountOfHearts");
assertNotNull(amountOfHearts);
assertTrue(Integer.parseInt(amountOfHearts.toString()) > 1);
assertEquals("Rahul", name);
Object numberOfHearts = function.getArguments().get("numberOfHearts");
assertNotNull(numberOfHearts);
assertTrue(Integer.parseInt(numberOfHearts.toString()) > 1);
assertTrue(chatResult.getChatHistory().size() > 2);
List<OllamaChatToolCalls> finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls();
assertNull(finalToolCalls);
}
@Test
@Order(14)
void testChatWithToolsAndStream() throws OllamaBaseException, IOException, URISyntaxException,
InterruptedException, ToolInvocationException {
String theToolModel = TOOLS_MODEL;
api.pullModel(theToolModel);
String expectedEmployeeID = UUID.randomUUID().toString();
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(theToolModel);
final Tools.ToolSpecification employeeDetailsToolSpecification = Tools.ToolSpecification.builder()
.functionName("get-employee-details")
.functionDescription("Tool to get details for a person or an employee")
.toolPrompt(Tools.PromptFuncDefinition.builder().type("function")
.function(Tools.PromptFuncDefinition.PromptFuncSpec.builder()
.name("get-employee-details")
.description("Tool to get details for a person or an employee")
.parameters(Tools.PromptFuncDefinition.Parameters
.builder().type("object")
.properties(new Tools.PropsBuilder()
.withProperty("employee-name",
Tools.PromptFuncDefinition.Property
.builder()
.type("string")
.description("The name of the employee, e.g. John Doe")
.required(true)
.build())
.withProperty("employee-address",
Tools.PromptFuncDefinition.Property
.builder()
.type("string")
.description("The address of the employee, Always gives a random address. For example, Roy St, Bengaluru, India")
.required(true)
.build())
.withProperty("employee-phone",
Tools.PromptFuncDefinition.Property
.builder()
.type("string")
.description("The phone number of the employee. Always gives a random phone number. For example, 9911002233")
.required(true)
.build())
.build())
.required(List.of("employee-name"))
.build())
.build())
.build())
.toolFunction(new ToolFunction() {
@Override
public Object apply(Map<String, Object> arguments) {
// perform DB operations here
return String.format(
"Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}",
expectedEmployeeID, arguments.get("employee-name"),
arguments.get("employee-address"),
arguments.get("employee-phone"));
}
}).build();
api.registerTool(employeeDetailsToolSpecification);
OllamaChatRequest requestModel = builder
.withMessage(OllamaChatMessageRole.USER, "Find the ID of employee Rahul Kumar")
.withKeepAlive("0m")
.withOptions(new OptionsBuilder().setTemperature(0.9f).build())
.build();
StringBuffer sb = new StringBuffer();
OllamaChatResult chatResult = api.chat(requestModel, (s) -> {
String substring = s.substring(sb.toString().length());
sb.append(substring);
LOG.info(substring);
});
assertNotNull(chatResult);
assertNotNull(chatResult.getResponseModel());
assertNotNull(chatResult.getResponseModel().getMessage());
assertNotNull(chatResult.getResponseModel().getMessage().getContent());
assertTrue(sb.toString().toLowerCase().contains(expectedEmployeeID));
assertTrue(chatResult.getResponseModel().getMessage().getContent().toLowerCase().contains(expectedEmployeeID));
}
@Test
@Order(15)
void testChatWithStream() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException,
@ -566,12 +501,9 @@ class OllamaAPIIntegrationTest {
InterruptedException, ToolInvocationException {
api.pullModel(THINKING_TOOL_MODEL);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(THINKING_TOOL_MODEL);
OllamaChatRequest requestModel = builder
.withMessage(OllamaChatMessageRole.USER,
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER,
"What is the capital of France? And what's France's connection with Mona Lisa?")
.withThinking(true)
.withKeepAlive("0m")
.build();
.withThinking(true).withKeepAlive("0m").build();
StringBuffer sb = new StringBuffer();
OllamaChatResult chatResult = api.chat(requestModel, (s) -> {
@ -727,4 +659,63 @@ class OllamaAPIIntegrationTest {
ClassLoader classLoader = getClass().getClassLoader();
return new File(Objects.requireNonNull(classLoader.getResource(fileName)).getFile());
}
private Tools.ToolSpecification employeeFinderTool() {
return Tools.ToolSpecification.builder()
.functionName("get-employee-details")
.functionDescription("Get details for a person or an employee")
.toolPrompt(Tools.PromptFuncDefinition.builder().type("function")
.function(Tools.PromptFuncDefinition.PromptFuncSpec.builder()
.name("get-employee-details")
.description("Get details for a person or an employee")
.parameters(Tools.PromptFuncDefinition.Parameters
.builder().type("object")
.properties(new Tools.PropsBuilder()
.withProperty("employee-name",
Tools.PromptFuncDefinition.Property
.builder()
.type("string")
.description("The name of the employee, e.g. John Doe")
.required(true)
.build())
.withProperty("employee-address",
Tools.PromptFuncDefinition.Property
.builder()
.type("string")
.description("The address of the employee, Always eturns a random address. For example, Church St, Bengaluru, India")
.required(true)
.build())
.withProperty("employee-phone",
Tools.PromptFuncDefinition.Property
.builder()
.type("string")
.description("The phone number of the employee. Always returns a random phone number. For example, 9911002233")
.required(true)
.build())
.build())
.required(List.of("employee-name"))
.build())
.build())
.build())
.toolFunction(new ToolFunction() {
@Override
public Object apply(Map<String, Object> arguments) {
LOG.info("Invoking employee finder tool with arguments: {}", arguments);
String employeeName = arguments.get("employee-name").toString();
String address = null;
String phone = null;
if (employeeName.equalsIgnoreCase("Rahul Kumar")) {
address = "Pune, Maharashtra, India";
phone = "9911223344";
} else {
address = "Karol Bagh, Delhi, India";
phone = "9911002233";
}
// perform DB operations here
return String.format(
"Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}",
UUID.randomUUID(), employeeName, address, phone);
}
}).build();
}
}

View File

@ -13,9 +13,9 @@ public class AnnotatedTool {
}
@ToolSpec(desc = "Says hello to a friend!")
public String sayHello(@ToolProperty(name = "name", desc = "Name of the friend") String name, Integer someRandomProperty, @ToolProperty(name = "amountOfHearts", desc = "amount of heart emojis that should be used", required = false) Integer amountOfHearts) {
String hearts = amountOfHearts != null ? "".repeat(amountOfHearts) : "";
return "Hello " + name + " (" + someRandomProperty + ") " + hearts;
public String sayHello(@ToolProperty(name = "name", desc = "Name of the friend") String name, @ToolProperty(name = "numberOfHearts", desc = "number of heart emojis that should be used", required = false) Integer numberOfHearts) {
String hearts = numberOfHearts != null ? "".repeat(numberOfHearts) : "";
return "Hello, " + name + "! " + hearts;
}
}