Adds implicit tool calling for streamed chat requests (requires Ollama v0.4.6)

This commit is contained in:
Markus Klenke
2024-12-09 23:07:25 +01:00
parent c4b7830614
commit 7ffbc5d3f2
3 changed files with 87 additions and 20 deletions

View File

@@ -265,7 +265,7 @@ class TestRealAPIs {
OllamaChatRequest requestModel = builder
.withMessage(OllamaChatMessageRole.USER,
"Give me the details of the employee named 'Rahul Kumar'?")
"Give me the ID of the employee named 'Rahul Kumar'?")
.build();
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
@@ -288,6 +288,63 @@ class TestRealAPIs {
}
}
@Test
@Order(3)
void testChatWithToolsAndStream() {
testEndpointReachability();
try {
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder()
.functionName("get-employee-details")
.functionDescription("Get employee details from the database")
.toolPrompt(
Tools.PromptFuncDefinition.builder().type("function").function(
Tools.PromptFuncDefinition.PromptFuncSpec.builder()
.name("get-employee-details")
.description("Get employee details from the database")
.parameters(
Tools.PromptFuncDefinition.Parameters.builder()
.type("object")
.properties(
new Tools.PropsBuilder()
.withProperty("employee-name", Tools.PromptFuncDefinition.Property.builder().type("string").description("The name of the employee, e.g. John Doe").required(true).build())
.withProperty("employee-address", Tools.PromptFuncDefinition.Property.builder().type("string").description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India").required(true).build())
.withProperty("employee-phone", Tools.PromptFuncDefinition.Property.builder().type("string").description("The phone number of the employee. Always return a random value. e.g. 9911002233").required(true).build())
.build()
)
.required(List.of("employee-name"))
.build()
).build()
).build()
)
.toolFunction(new DBQueryFunction())
.build();
ollamaAPI.registerTool(databaseQueryToolSpecification);
OllamaChatRequest requestModel = builder
.withMessage(OllamaChatMessageRole.USER,
"Give me the ID of the employee named 'Rahul Kumar'?")
.build();
StringBuffer sb = new StringBuffer();
OllamaChatResult chatResult = ollamaAPI.chat(requestModel, (s) -> {
LOG.info(s);
String substring = s.substring(sb.toString().length());
LOG.info(substring);
sb.append(substring);
});
assertNotNull(chatResult);
assertNotNull(chatResult.getResponseModel());
assertNotNull(chatResult.getResponseModel().getMessage());
assertNotNull(chatResult.getResponseModel().getMessage().getContent());
assertEquals(sb.toString().trim(), chatResult.getResponseModel().getMessage().getContent().trim());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
}
}
@Test
@Order(3)
void testChatWithStream() {