Adds implicit tool calling for streamed chat requests (requires Ollama v0.4.6)

This commit is contained in:
Markus Klenke 2024-12-09 23:07:25 +01:00
parent c4b7830614
commit 7ffbc5d3f2
3 changed files with 87 additions and 20 deletions

View File

@ -777,6 +777,8 @@ public class OllamaAPI {
result = requestCaller.call(request, streamHandler); result = requestCaller.call(request, streamHandler);
} else { } else {
result = requestCaller.callSync(request); result = requestCaller.callSync(request);
}
// check if toolCallIsWanted // check if toolCallIsWanted
List<OllamaChatToolCalls> toolCalls = result.getResponseModel().getMessage().getToolCalls(); List<OllamaChatToolCalls> toolCalls = result.getResponseModel().getMessage().getToolCalls();
int toolCallTries = 0; int toolCallTries = 0;
@ -786,13 +788,16 @@ public class OllamaAPI {
ToolFunction toolFunction = toolRegistry.getToolFunction(toolName); ToolFunction toolFunction = toolRegistry.getToolFunction(toolName);
Map<String, Object> arguments = toolCall.getFunction().getArguments(); Map<String, Object> arguments = toolCall.getFunction().getArguments();
Object res = toolFunction.apply(arguments); Object res = toolFunction.apply(arguments);
request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL,"[ToolCall-Result]" + toolName + "(" + arguments.keySet() +") : " + res + "[/ToolCall-Result]")); request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL,"[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() +") : " + res + "[/TOOL_RESULTS]"));
}
result = requestCaller.callSync(request);
toolCalls = result.getResponseModel().getMessage().getToolCalls();
toolCallTries++;
} }
if (streamHandler != null) {
result = requestCaller.call(request, streamHandler);
} else {
result = requestCaller.callSync(request);
}
toolCalls = result.getResponseModel().getMessage().getToolCalls();
toolCallTries++;
} }
return result; return result;

View File

@ -3,13 +3,10 @@ package io.github.ollama4j.models.request;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.core.type.TypeReference;
import io.github.ollama4j.exceptions.OllamaBaseException; import io.github.ollama4j.exceptions.OllamaBaseException;
import io.github.ollama4j.models.chat.OllamaChatMessage; import io.github.ollama4j.models.chat.*;
import io.github.ollama4j.models.chat.OllamaChatRequest;
import io.github.ollama4j.models.chat.OllamaChatResult;
import io.github.ollama4j.models.response.OllamaErrorResponse; import io.github.ollama4j.models.response.OllamaErrorResponse;
import io.github.ollama4j.models.chat.OllamaChatResponseModel;
import io.github.ollama4j.models.chat.OllamaChatStreamObserver;
import io.github.ollama4j.models.generate.OllamaStreamHandler; import io.github.ollama4j.models.generate.OllamaStreamHandler;
import io.github.ollama4j.tools.Tools;
import io.github.ollama4j.utils.Utils; import io.github.ollama4j.utils.Utils;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@ -23,6 +20,7 @@ import java.net.http.HttpClient;
import java.net.http.HttpRequest; import java.net.http.HttpRequest;
import java.net.http.HttpResponse; import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.util.List;
/** /**
* Specialization class for requests * Specialization class for requests
@ -96,6 +94,7 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
InputStream responseBodyStream = response.body(); InputStream responseBodyStream = response.body();
StringBuilder responseBuffer = new StringBuilder(); StringBuilder responseBuffer = new StringBuilder();
OllamaChatResponseModel ollamaChatResponseModel = null; OllamaChatResponseModel ollamaChatResponseModel = null;
List<OllamaChatToolCalls> wantedToolsForStream = null;
try (BufferedReader reader = try (BufferedReader reader =
new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
@ -120,6 +119,9 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
} else { } else {
boolean finished = parseResponseAndAddToBuffer(line, responseBuffer); boolean finished = parseResponseAndAddToBuffer(line, responseBuffer);
ollamaChatResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class); ollamaChatResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class);
if(body.stream && ollamaChatResponseModel.getMessage().getToolCalls() != null){
wantedToolsForStream = ollamaChatResponseModel.getMessage().getToolCalls();
}
if (finished && body.stream) { if (finished && body.stream) {
ollamaChatResponseModel.getMessage().setContent(responseBuffer.toString()); ollamaChatResponseModel.getMessage().setContent(responseBuffer.toString());
break; break;
@ -131,6 +133,9 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
LOG.error("Status code " + statusCode); LOG.error("Status code " + statusCode);
throw new OllamaBaseException(responseBuffer.toString()); throw new OllamaBaseException(responseBuffer.toString());
} else { } else {
if(wantedToolsForStream != null) {
ollamaChatResponseModel.getMessage().setToolCalls(wantedToolsForStream);
}
OllamaChatResult ollamaResult = OllamaChatResult ollamaResult =
new OllamaChatResult(ollamaChatResponseModel,body.getMessages()); new OllamaChatResult(ollamaChatResponseModel,body.getMessages());
if (isVerbose()) LOG.info("Model response: " + ollamaResult); if (isVerbose()) LOG.info("Model response: " + ollamaResult);

View File

@ -265,7 +265,7 @@ class TestRealAPIs {
OllamaChatRequest requestModel = builder OllamaChatRequest requestModel = builder
.withMessage(OllamaChatMessageRole.USER, .withMessage(OllamaChatMessageRole.USER,
"Give me the details of the employee named 'Rahul Kumar'?") "Give me the ID of the employee named 'Rahul Kumar'?")
.build(); .build();
OllamaChatResult chatResult = ollamaAPI.chat(requestModel); OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
@ -288,6 +288,63 @@ class TestRealAPIs {
} }
} }
/**
 * Sends a streamed chat request with the "get-employee-details" tool registered and
 * checks that the text collected by the stream handler equals the content of the final
 * response message (both trimmed).
 */
@Test
@Order(3)
void testChatWithToolsAndStream() {
    testEndpointReachability();
    try {
        OllamaChatRequestBuilder requestBuilder = OllamaChatRequestBuilder.getInstance(config.getModel());

        // Parameter descriptions advertised to the model for the employee lookup tool.
        Tools.PromptFuncDefinition.Property nameProperty = Tools.PromptFuncDefinition.Property.builder()
                .type("string")
                .description("The name of the employee, e.g. John Doe")
                .required(true)
                .build();
        Tools.PromptFuncDefinition.Property addressProperty = Tools.PromptFuncDefinition.Property.builder()
                .type("string")
                .description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India")
                .required(true)
                .build();
        Tools.PromptFuncDefinition.Property phoneProperty = Tools.PromptFuncDefinition.Property.builder()
                .type("string")
                .description("The phone number of the employee. Always return a random value. e.g. 9911002233")
                .required(true)
                .build();

        Tools.PromptFuncDefinition.Parameters toolParameters = Tools.PromptFuncDefinition.Parameters.builder()
                .type("object")
                .properties(
                        new Tools.PropsBuilder()
                                .withProperty("employee-name", nameProperty)
                                .withProperty("employee-address", addressProperty)
                                .withProperty("employee-phone", phoneProperty)
                                .build()
                )
                .required(List.of("employee-name"))
                .build();

        Tools.PromptFuncDefinition toolPrompt = Tools.PromptFuncDefinition.builder()
                .type("function")
                .function(
                        Tools.PromptFuncDefinition.PromptFuncSpec.builder()
                                .name("get-employee-details")
                                .description("Get employee details from the database")
                                .parameters(toolParameters)
                                .build()
                )
                .build();

        final Tools.ToolSpecification employeeLookupTool = Tools.ToolSpecification.builder()
                .functionName("get-employee-details")
                .functionDescription("Get employee details from the database")
                .toolPrompt(toolPrompt)
                .toolFunction(new DBQueryFunction())
                .build();
        ollamaAPI.registerTool(employeeLookupTool);

        OllamaChatRequest chatRequest = requestBuilder
                .withMessage(OllamaChatMessageRole.USER,
                        "Give me the ID of the employee named 'Rahul Kumar'?")
                .build();

        // The handler below treats every callback value as the cumulative text so far and
        // keeps only the part that extends beyond what has already been collected.
        StringBuffer streamedText = new StringBuffer();
        OllamaChatResult chatResult = ollamaAPI.chat(chatRequest, cumulative -> {
            LOG.info(cumulative);
            String delta = cumulative.substring(streamedText.length());
            LOG.info(delta);
            streamedText.append(delta);
        });

        assertNotNull(chatResult);
        assertNotNull(chatResult.getResponseModel());
        assertNotNull(chatResult.getResponseModel().getMessage());
        assertNotNull(chatResult.getResponseModel().getMessage().getContent());
        assertEquals(streamedText.toString().trim(),
                chatResult.getResponseModel().getMessage().getContent().trim());
    } catch (IOException | OllamaBaseException | InterruptedException e) {
        fail(e);
    }
}
@Test @Test
@Order(3) @Order(3)
void testChatWithStream() { void testChatWithStream() {