mirror of
https://github.com/amithkoujalgi/ollama4j.git
synced 2025-05-15 20:07:10 +02:00
Enables in chat tool calling
This commit is contained in:
parent
b6a293add7
commit
69f6fd81cf
@ -777,6 +777,22 @@ public class OllamaAPI {
|
|||||||
result = requestCaller.call(request, streamHandler);
|
result = requestCaller.call(request, streamHandler);
|
||||||
} else {
|
} else {
|
||||||
result = requestCaller.callSync(request);
|
result = requestCaller.callSync(request);
|
||||||
|
// check if toolCallIsWanted
|
||||||
|
List<OllamaChatToolCalls> toolCalls = result.getResponseModel().getMessage().getToolCalls();
|
||||||
|
int toolCallTries = 0;
|
||||||
|
while(toolCalls != null && !toolCalls.isEmpty() && toolCallTries <3){
|
||||||
|
for (OllamaChatToolCalls toolCall : toolCalls){
|
||||||
|
String toolName = toolCall.getFunction().getName();
|
||||||
|
ToolFunction toolFunction = toolRegistry.getToolFunction(toolName);
|
||||||
|
Map<String, Object> arguments = toolCall.getFunction().getArguments();
|
||||||
|
Object res = toolFunction.apply(arguments);
|
||||||
|
request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL,"[ToolCall-Result]" + toolName + "(" + arguments.keySet() +") : " + res + "[/ToolCall-Result]"));
|
||||||
|
}
|
||||||
|
result = requestCaller.callSync(request);
|
||||||
|
toolCalls = result.getResponseModel().getMessage().getToolCalls();
|
||||||
|
toolCallTries++;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
|
@ -12,5 +12,5 @@ import java.util.Map;
|
|||||||
public class OllamaToolCallsFunction
|
public class OllamaToolCallsFunction
|
||||||
{
|
{
|
||||||
private String name;
|
private String name;
|
||||||
private Map<String,String> arguments;
|
private Map<String,Object> arguments;
|
||||||
}
|
}
|
||||||
|
@ -273,14 +273,16 @@ class TestRealAPIs {
|
|||||||
assertNotNull(chatResult.getResponseModel());
|
assertNotNull(chatResult.getResponseModel());
|
||||||
assertNotNull(chatResult.getResponseModel().getMessage());
|
assertNotNull(chatResult.getResponseModel().getMessage());
|
||||||
assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(),chatResult.getResponseModel().getMessage().getRole().getRoleName());
|
assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(),chatResult.getResponseModel().getMessage().getRole().getRoleName());
|
||||||
List<OllamaChatToolCalls> toolCalls = chatResult.getResponseModel().getMessage().getToolCalls();
|
List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls();
|
||||||
assertEquals(1, toolCalls.size());
|
assertEquals(1, toolCalls.size());
|
||||||
assertEquals("get-employee-details",toolCalls.get(0).getFunction().getName());
|
assertEquals("get-employee-details",toolCalls.get(0).getFunction().getName());
|
||||||
assertEquals(1, toolCalls.get(0).getFunction().getArguments().size());
|
assertEquals(1, toolCalls.get(0).getFunction().getArguments().size());
|
||||||
String employeeName = toolCalls.get(0).getFunction().getArguments().get("employee-name");
|
Object employeeName = toolCalls.get(0).getFunction().getArguments().get("employee-name");
|
||||||
assertNotNull(employeeName);
|
assertNotNull(employeeName);
|
||||||
assertEquals("Rahul Kumar",employeeName);
|
assertEquals("Rahul Kumar",employeeName);
|
||||||
assertEquals(2, chatResult.getChatHistory().size());
|
assertTrue(chatResult.getChatHistory().size()>2);
|
||||||
|
List<OllamaChatToolCalls> finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls();
|
||||||
|
assertNull(finalToolCalls);
|
||||||
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||||
fail(e);
|
fail(e);
|
||||||
}
|
}
|
||||||
@ -448,7 +450,7 @@ class DBQueryFunction implements ToolFunction {
|
|||||||
@Override
|
@Override
|
||||||
public Object apply(Map<String, Object> arguments) {
|
public Object apply(Map<String, Object> arguments) {
|
||||||
// perform DB operations here
|
// perform DB operations here
|
||||||
return String.format("Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", UUID.randomUUID(), arguments.get("employee-name").toString(), arguments.get("employee-address").toString(), arguments.get("employee-phone").toString());
|
return String.format("Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", UUID.randomUUID(), arguments.get("employee-name"), arguments.get("employee-address"), arguments.get("employee-phone"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user