Enables in chat tool calling

This commit is contained in:
Markus Klenke 2024-12-07 01:04:13 +01:00
parent b6a293add7
commit 69f6fd81cf
3 changed files with 23 additions and 5 deletions

View File

@ -777,6 +777,22 @@ public class OllamaAPI {
result = requestCaller.call(request, streamHandler); result = requestCaller.call(request, streamHandler);
} else { } else {
result = requestCaller.callSync(request); result = requestCaller.callSync(request);
// check if toolCallIsWanted
List<OllamaChatToolCalls> toolCalls = result.getResponseModel().getMessage().getToolCalls();
int toolCallTries = 0;
while(toolCalls != null && !toolCalls.isEmpty() && toolCallTries <3){
for (OllamaChatToolCalls toolCall : toolCalls){
String toolName = toolCall.getFunction().getName();
ToolFunction toolFunction = toolRegistry.getToolFunction(toolName);
Map<String, Object> arguments = toolCall.getFunction().getArguments();
Object res = toolFunction.apply(arguments);
request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL,"[ToolCall-Result]" + toolName + "(" + arguments.keySet() +") : " + res + "[/ToolCall-Result]"));
}
result = requestCaller.callSync(request);
toolCalls = result.getResponseModel().getMessage().getToolCalls();
toolCallTries++;
}
} }
return result; return result;

View File

@ -12,5 +12,5 @@ import java.util.Map;
public class OllamaToolCallsFunction public class OllamaToolCallsFunction
{ {
private String name; private String name;
private Map<String,String> arguments; private Map<String,Object> arguments;
} }

View File

@ -273,14 +273,16 @@ class TestRealAPIs {
assertNotNull(chatResult.getResponseModel()); assertNotNull(chatResult.getResponseModel());
assertNotNull(chatResult.getResponseModel().getMessage()); assertNotNull(chatResult.getResponseModel().getMessage());
assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(),chatResult.getResponseModel().getMessage().getRole().getRoleName()); assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(),chatResult.getResponseModel().getMessage().getRole().getRoleName());
List<OllamaChatToolCalls> toolCalls = chatResult.getResponseModel().getMessage().getToolCalls(); List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls();
assertEquals(1, toolCalls.size()); assertEquals(1, toolCalls.size());
assertEquals("get-employee-details",toolCalls.get(0).getFunction().getName()); assertEquals("get-employee-details",toolCalls.get(0).getFunction().getName());
assertEquals(1, toolCalls.get(0).getFunction().getArguments().size()); assertEquals(1, toolCalls.get(0).getFunction().getArguments().size());
String employeeName = toolCalls.get(0).getFunction().getArguments().get("employee-name"); Object employeeName = toolCalls.get(0).getFunction().getArguments().get("employee-name");
assertNotNull(employeeName); assertNotNull(employeeName);
assertEquals("Rahul Kumar",employeeName); assertEquals("Rahul Kumar",employeeName);
assertEquals(2, chatResult.getChatHistory().size()); assertTrue(chatResult.getChatHistory().size()>2);
List<OllamaChatToolCalls> finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls();
assertNull(finalToolCalls);
} catch (IOException | OllamaBaseException | InterruptedException e) { } catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e); fail(e);
} }
@ -448,7 +450,7 @@ class DBQueryFunction implements ToolFunction {
@Override @Override
public Object apply(Map<String, Object> arguments) { public Object apply(Map<String, Object> arguments) {
// perform DB operations here // perform DB operations here
return String.format("Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", UUID.randomUUID(), arguments.get("employee-name").toString(), arguments.get("employee-address").toString(), arguments.get("employee-phone").toString()); return String.format("Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", UUID.randomUUID(), arguments.get("employee-name"), arguments.get("employee-address"), arguments.get("employee-phone"));
} }
} }