mirror of https://github.com/amithkoujalgi/ollama4j.git
synced 2025-05-15 20:07:10 +02:00

Adds implicit tool calling for streamed chat requests (requires Ollama v0.4.6)

parent c4b7830614
commit 7ffbc5d3f2
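
For orientation, here is a minimal usage sketch (not part of this commit, modelled on the test added at the bottom of this diff) of how a caller would exercise the new behavior; the host URL and model name are placeholder assumptions:

    import io.github.ollama4j.OllamaAPI;
    import io.github.ollama4j.models.chat.OllamaChatMessageRole;
    import io.github.ollama4j.models.chat.OllamaChatRequest;
    import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
    import io.github.ollama4j.models.chat.OllamaChatResult;

    public class StreamedToolChatExample {
        public static void main(String[] args) throws Exception {
            // Host and model are placeholders; adjust to your local Ollama setup (v0.4.6+).
            OllamaAPI ollamaAPI = new OllamaAPI("http://localhost:11434");

            // Register a tool first, e.g. the "get-employee-details" spec shown in the test below:
            // ollamaAPI.registerTool(databaseQueryToolSpecification);

            OllamaChatRequest requestModel = OllamaChatRequestBuilder.getInstance("llama3.1")
                    .withMessage(OllamaChatMessageRole.USER,
                            "Give me the ID of the employee named 'Rahul Kumar'?")
                    .build();

            // Passing a stream handler streams partial responses; with this commit,
            // tool calls emitted during streaming are executed implicitly and the
            // follow-up request is sent through the same streaming path.
            OllamaChatResult chatResult = ollamaAPI.chat(requestModel, (s) -> System.out.println(s));
            System.out.println(chatResult.getResponseModel().getMessage().getContent());
        }
    }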
@@ -777,6 +777,8 @@ public class OllamaAPI {
             result = requestCaller.call(request, streamHandler);
         } else {
             result = requestCaller.callSync(request);
+        }
+
         // check if toolCallIsWanted
         List<OllamaChatToolCalls> toolCalls = result.getResponseModel().getMessage().getToolCalls();
         int toolCallTries = 0;
@@ -786,13 +788,16 @@ public class OllamaAPI {
                 ToolFunction toolFunction = toolRegistry.getToolFunction(toolName);
                 Map<String, Object> arguments = toolCall.getFunction().getArguments();
                 Object res = toolFunction.apply(arguments);
-                request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL,"[ToolCall-Result]" + toolName + "(" + arguments.keySet() +") : " + res + "[/ToolCall-Result]"));
+                request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL,"[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() +") : " + res + "[/TOOL_RESULTS]"));
-            }
-            result = requestCaller.callSync(request);
-            toolCalls = result.getResponseModel().getMessage().getToolCalls();
-            toolCallTries++;
             }

+            if (streamHandler != null) {
+                result = requestCaller.call(request, streamHandler);
+            } else {
+                result = requestCaller.callSync(request);
+            }
+            toolCalls = result.getResponseModel().getMessage().getToolCalls();
+            toolCallTries++;
         }

         return result;
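
To make the message format concrete, here is a small illustration (not from the commit) of the TOOL-role text the loop above appends before re-issuing the request; the tool name, argument, and result value are made up:

    import java.util.Map;

    public class ToolResultMessageExample {
        public static void main(String[] args) {
            String toolName = "get-employee-details";                     // hypothetical tool
            Map<String, Object> arguments = Map.of("employee-name", "Rahul Kumar");
            Object res = "E-1234";                                        // pretend ToolFunction.apply(...) result

            // Same expression as in the loop above.
            String text = "[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() + ") : " + res + "[/TOOL_RESULTS]";
            System.out.println(text);
            // prints: [TOOL_RESULTS]get-employee-details([employee-name]) : E-1234[/TOOL_RESULTS]
        }
    }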
@@ -3,13 +3,10 @@ package io.github.ollama4j.models.request;
 import com.fasterxml.jackson.core.JsonProcessingException;
 import com.fasterxml.jackson.core.type.TypeReference;
 import io.github.ollama4j.exceptions.OllamaBaseException;
-import io.github.ollama4j.models.chat.OllamaChatMessage;
+import io.github.ollama4j.models.chat.*;
-import io.github.ollama4j.models.chat.OllamaChatRequest;
-import io.github.ollama4j.models.chat.OllamaChatResult;
 import io.github.ollama4j.models.response.OllamaErrorResponse;
-import io.github.ollama4j.models.chat.OllamaChatResponseModel;
-import io.github.ollama4j.models.chat.OllamaChatStreamObserver;
 import io.github.ollama4j.models.generate.OllamaStreamHandler;
+import io.github.ollama4j.tools.Tools;
 import io.github.ollama4j.utils.Utils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -23,6 +20,7 @@ import java.net.http.HttpClient;
 import java.net.http.HttpRequest;
 import java.net.http.HttpResponse;
 import java.nio.charset.StandardCharsets;
+import java.util.List;

 /**
  * Specialization class for requests
@@ -96,6 +94,7 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
         InputStream responseBodyStream = response.body();
         StringBuilder responseBuffer = new StringBuilder();
         OllamaChatResponseModel ollamaChatResponseModel = null;
+        List<OllamaChatToolCalls> wantedToolsForStream = null;
         try (BufferedReader reader =
                      new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {

@@ -120,6 +119,9 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
                 } else {
                     boolean finished = parseResponseAndAddToBuffer(line, responseBuffer);
                     ollamaChatResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class);
+                    if(body.stream && ollamaChatResponseModel.getMessage().getToolCalls() != null){
+                        wantedToolsForStream = ollamaChatResponseModel.getMessage().getToolCalls();
+                    }
                     if (finished && body.stream) {
                         ollamaChatResponseModel.getMessage().setContent(responseBuffer.toString());
                         break;
@@ -131,6 +133,9 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
             LOG.error("Status code " + statusCode);
             throw new OllamaBaseException(responseBuffer.toString());
         } else {
+            if(wantedToolsForStream != null) {
+                ollamaChatResponseModel.getMessage().setToolCalls(wantedToolsForStream);
+            }
             OllamaChatResult ollamaResult =
                     new OllamaChatResult(ollamaChatResponseModel,body.getMessages());
             if (isVerbose()) LOG.info("Model response: " + ollamaResult);
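
The hunks above stash tool calls seen in intermediate streamed chunks in wantedToolsForStream and re-attach them to the final message, so the implicit tool-calling loop in OllamaAPI can react to them even when streaming. On the caller side, the stream handler still receives the cumulative response text; below is a sketch of the delta-printing pattern the test further down uses (it assumes an OllamaAPI instance and an OllamaChatRequest built as in the example near the top of this page):

    // Sketch: print only the newly streamed tokens. The handler is invoked with
    // the cumulative response text, so the delta is the suffix not yet seen.
    static OllamaChatResult chatPrintingDeltas(OllamaAPI ollamaAPI, OllamaChatRequest requestModel) throws Exception {
        StringBuffer seen = new StringBuffer();
        return ollamaAPI.chat(requestModel, (s) -> {
            String delta = s.substring(seen.length());
            System.out.print(delta);
            seen.append(delta);
        });
    }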
@@ -265,7 +265,7 @@ class TestRealAPIs {

             OllamaChatRequest requestModel = builder
                     .withMessage(OllamaChatMessageRole.USER,
-                            "Give me the details of the employee named 'Rahul Kumar'?")
+                            "Give me the ID of the employee named 'Rahul Kumar'?")
                     .build();

             OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
@@ -288,6 +288,63 @@ class TestRealAPIs {
         }
     }

+    @Test
+    @Order(3)
+    void testChatWithToolsAndStream() {
+        testEndpointReachability();
+        try {
+            OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
+            final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder()
+                    .functionName("get-employee-details")
+                    .functionDescription("Get employee details from the database")
+                    .toolPrompt(
+                            Tools.PromptFuncDefinition.builder().type("function").function(
+                                    Tools.PromptFuncDefinition.PromptFuncSpec.builder()
+                                            .name("get-employee-details")
+                                            .description("Get employee details from the database")
+                                            .parameters(
+                                                    Tools.PromptFuncDefinition.Parameters.builder()
+                                                            .type("object")
+                                                            .properties(
+                                                                    new Tools.PropsBuilder()
+                                                                            .withProperty("employee-name", Tools.PromptFuncDefinition.Property.builder().type("string").description("The name of the employee, e.g. John Doe").required(true).build())
+                                                                            .withProperty("employee-address", Tools.PromptFuncDefinition.Property.builder().type("string").description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India").required(true).build())
+                                                                            .withProperty("employee-phone", Tools.PromptFuncDefinition.Property.builder().type("string").description("The phone number of the employee. Always return a random value. e.g. 9911002233").required(true).build())
+                                                                            .build()
+                                                            )
+                                                            .required(List.of("employee-name"))
+                                                            .build()
+                                            ).build()
+                            ).build()
+                    )
+                    .toolFunction(new DBQueryFunction())
+                    .build();
+
+            ollamaAPI.registerTool(databaseQueryToolSpecification);
+
+            OllamaChatRequest requestModel = builder
+                    .withMessage(OllamaChatMessageRole.USER,
+                            "Give me the ID of the employee named 'Rahul Kumar'?")
+                    .build();
+
+            StringBuffer sb = new StringBuffer();
+
+            OllamaChatResult chatResult = ollamaAPI.chat(requestModel, (s) -> {
+                LOG.info(s);
+                String substring = s.substring(sb.toString().length());
+                LOG.info(substring);
+                sb.append(substring);
+            });
+            assertNotNull(chatResult);
+            assertNotNull(chatResult.getResponseModel());
+            assertNotNull(chatResult.getResponseModel().getMessage());
+            assertNotNull(chatResult.getResponseModel().getMessage().getContent());
+            assertEquals(sb.toString().trim(), chatResult.getResponseModel().getMessage().getContent().trim());
+        } catch (IOException | OllamaBaseException | InterruptedException e) {
+            fail(e);
+        }
+    }
+
     @Test
     @Order(3)
     void testChatWithStream() {
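
The test registers new DBQueryFunction() as the tool implementation, but that class is not part of this diff. Here is a hedged sketch of what such a ToolFunction could look like, taking only the apply(Map) contract from the code above; the package name and the method body are assumptions:

    import java.util.Map;
    import java.util.UUID;

    import io.github.ollama4j.tools.ToolFunction; // package assumed; apply(Map) is the contract used in the diff above

    // Illustrative only: the real DBQueryFunction is not shown in this commit.
    public class DBQueryFunction implements ToolFunction {
        @Override
        public Object apply(Map<String, Object> arguments) {
            // Pretend to look up the employee; echo the arguments with a random ID.
            return String.format("Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}",
                    UUID.randomUUID(),
                    arguments.get("employee-name"),
                    arguments.get("employee-address"),
                    arguments.get("employee-phone"));
        }
    }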