Extends chat API to automatically load registered Tools

This commit is contained in:
Markus Klenke 2024-12-04 09:12:55 +01:00 committed by Markus Klenke
parent 903a8176cd
commit e9c33ab0b2
3 changed files with 29 additions and 2 deletions

View File

@ -775,8 +775,8 @@ public class OllamaAPI {
result = requestCaller.callSync(request); result = requestCaller.callSync(request);
} }
//add registered Tools to Request // add all registered tools to Request
request.setTools(toolRegistry.getRegisteredSpecs().stream().map(Tools.ToolSpecification::getToolPrompt).collect(Collectors.toList()));
return new OllamaChatResult(result.getResponse(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages()); return new OllamaChatResult(result.getResponse(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages());
} }

View File

@ -1,5 +1,6 @@
package io.github.ollama4j.tools; package io.github.ollama4j.tools;
import java.util.Collection;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
@ -14,4 +15,8 @@ public class ToolRegistry {
public void addTool (String name, Tools.ToolSpecification specification) { public void addTool (String name, Tools.ToolSpecification specification) {
tools.put(name, specification); tools.put(name, specification);
} }
public Collection<Tools.ToolSpecification> getRegisteredSpecs(){
return tools.values();
}
} }

View File

@ -225,6 +225,28 @@ class TestRealAPIs {
} }
} }
@Test
@Order(3)
void testChatWithTools() {
testEndpointReachability();
try {
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM,
"You are a silent bot that only says 'NI'. Do not say anything else under any circumstances!")
.withMessage(OllamaChatMessageRole.USER,
"What is the capital of France? And what's France's connection with Mona Lisa?")
.build();
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
assertNotNull(chatResult);
assertFalse(chatResult.getResponse().isBlank());
assertTrue(chatResult.getResponse().startsWith("NI"));
assertEquals(3, chatResult.getChatHistory().size());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
}
}
@Test @Test
@Order(3) @Order(3)
void testChatWithStream() { void testChatWithStream() {