forked from Mirror/ollama4j
Adds implicit tool calling for streamed chat requests (requires Ollama v0.4.6)
This commit is contained in:
@@ -777,22 +777,27 @@ public class OllamaAPI {
|
||||
result = requestCaller.call(request, streamHandler);
|
||||
} else {
|
||||
result = requestCaller.callSync(request);
|
||||
// check if toolCallIsWanted
|
||||
List<OllamaChatToolCalls> toolCalls = result.getResponseModel().getMessage().getToolCalls();
|
||||
int toolCallTries = 0;
|
||||
while(toolCalls != null && !toolCalls.isEmpty() && toolCallTries <3){
|
||||
for (OllamaChatToolCalls toolCall : toolCalls){
|
||||
String toolName = toolCall.getFunction().getName();
|
||||
ToolFunction toolFunction = toolRegistry.getToolFunction(toolName);
|
||||
Map<String, Object> arguments = toolCall.getFunction().getArguments();
|
||||
Object res = toolFunction.apply(arguments);
|
||||
request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL,"[ToolCall-Result]" + toolName + "(" + arguments.keySet() +") : " + res + "[/ToolCall-Result]"));
|
||||
}
|
||||
result = requestCaller.callSync(request);
|
||||
toolCalls = result.getResponseModel().getMessage().getToolCalls();
|
||||
toolCallTries++;
|
||||
}
|
||||
|
||||
// check if toolCallIsWanted
|
||||
List<OllamaChatToolCalls> toolCalls = result.getResponseModel().getMessage().getToolCalls();
|
||||
int toolCallTries = 0;
|
||||
while(toolCalls != null && !toolCalls.isEmpty() && toolCallTries <3){
|
||||
for (OllamaChatToolCalls toolCall : toolCalls){
|
||||
String toolName = toolCall.getFunction().getName();
|
||||
ToolFunction toolFunction = toolRegistry.getToolFunction(toolName);
|
||||
Map<String, Object> arguments = toolCall.getFunction().getArguments();
|
||||
Object res = toolFunction.apply(arguments);
|
||||
request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL,"[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() +") : " + res + "[/TOOL_RESULTS]"));
|
||||
}
|
||||
|
||||
if (streamHandler != null) {
|
||||
result = requestCaller.call(request, streamHandler);
|
||||
} else {
|
||||
result = requestCaller.callSync(request);
|
||||
}
|
||||
toolCalls = result.getResponseModel().getMessage().getToolCalls();
|
||||
toolCallTries++;
|
||||
}
|
||||
|
||||
return result;
|
||||
|
||||
@@ -3,13 +3,10 @@ package io.github.ollama4j.models.request;
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.core.type.TypeReference;
|
||||
import io.github.ollama4j.exceptions.OllamaBaseException;
|
||||
import io.github.ollama4j.models.chat.OllamaChatMessage;
|
||||
import io.github.ollama4j.models.chat.OllamaChatRequest;
|
||||
import io.github.ollama4j.models.chat.OllamaChatResult;
|
||||
import io.github.ollama4j.models.chat.*;
|
||||
import io.github.ollama4j.models.response.OllamaErrorResponse;
|
||||
import io.github.ollama4j.models.chat.OllamaChatResponseModel;
|
||||
import io.github.ollama4j.models.chat.OllamaChatStreamObserver;
|
||||
import io.github.ollama4j.models.generate.OllamaStreamHandler;
|
||||
import io.github.ollama4j.tools.Tools;
|
||||
import io.github.ollama4j.utils.Utils;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
@@ -23,6 +20,7 @@ import java.net.http.HttpClient;
|
||||
import java.net.http.HttpRequest;
|
||||
import java.net.http.HttpResponse;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Specialization class for requests
|
||||
@@ -96,6 +94,7 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
|
||||
InputStream responseBodyStream = response.body();
|
||||
StringBuilder responseBuffer = new StringBuilder();
|
||||
OllamaChatResponseModel ollamaChatResponseModel = null;
|
||||
List<OllamaChatToolCalls> wantedToolsForStream = null;
|
||||
try (BufferedReader reader =
|
||||
new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
|
||||
|
||||
@@ -120,6 +119,9 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
|
||||
} else {
|
||||
boolean finished = parseResponseAndAddToBuffer(line, responseBuffer);
|
||||
ollamaChatResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class);
|
||||
if(body.stream && ollamaChatResponseModel.getMessage().getToolCalls() != null){
|
||||
wantedToolsForStream = ollamaChatResponseModel.getMessage().getToolCalls();
|
||||
}
|
||||
if (finished && body.stream) {
|
||||
ollamaChatResponseModel.getMessage().setContent(responseBuffer.toString());
|
||||
break;
|
||||
@@ -131,6 +133,9 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
|
||||
LOG.error("Status code " + statusCode);
|
||||
throw new OllamaBaseException(responseBuffer.toString());
|
||||
} else {
|
||||
if(wantedToolsForStream != null) {
|
||||
ollamaChatResponseModel.getMessage().setToolCalls(wantedToolsForStream);
|
||||
}
|
||||
OllamaChatResult ollamaResult =
|
||||
new OllamaChatResult(ollamaChatResponseModel,body.getMessages());
|
||||
if (isVerbose()) LOG.info("Model response: " + ollamaResult);
|
||||
|
||||
Reference in New Issue
Block a user