diff --git a/docs/docs/apis-generate/chat.md b/docs/docs/apis-generate/chat.md index 9ed9e79..c9c73b7 100644 --- a/docs/docs/apis-generate/chat.md +++ b/docs/docs/apis-generate/chat.md @@ -33,7 +33,7 @@ public class Main { // start conversation with model OllamaChatResult chatResult = ollamaAPI.chat(requestModel); - System.out.println("First answer: " + chatResult.getResponse()); + System.out.println("First answer: " + chatResult.getResponseModel().getMessage().getContent()); // create next userQuestion requestModel = builder.withMessages(chatResult.getChatHistory()).withMessage(OllamaChatMessageRole.USER, "And what is the second largest city?").build(); @@ -41,7 +41,7 @@ public class Main { // "continue" conversation with model chatResult = ollamaAPI.chat(requestModel); - System.out.println("Second answer: " + chatResult.getResponse()); + System.out.println("Second answer: " + chatResult.getResponseModel().getMessage().getContent()); System.out.println("Chat History: " + chatResult.getChatHistory()); } @@ -205,7 +205,7 @@ public class Main { // start conversation with model OllamaChatResult chatResult = ollamaAPI.chat(requestModel); - System.out.println(chatResult.getResponse()); + System.out.println(chatResult.getResponseModel()); } } @@ -244,7 +244,7 @@ public class Main { new File("/path/to/image"))).build(); OllamaChatResult chatResult = ollamaAPI.chat(requestModel); - System.out.println("First answer: " + chatResult.getResponse()); + System.out.println("First answer: " + chatResult.getResponseModel()); builder.reset(); @@ -254,7 +254,7 @@ public class Main { .withMessage(OllamaChatMessageRole.USER, "What's the dogs breed?").build(); chatResult = ollamaAPI.chat(requestModel); - System.out.println("Second answer: " + chatResult.getResponse()); + System.out.println("Second answer: " + chatResult.getResponseModel()); } } ``` diff --git a/docs/docs/apis-generate/generate-with-tools.md b/docs/docs/apis-generate/generate-with-tools.md index 3a40150..e0e5794 100644 --- a/docs/docs/apis-generate/generate-with-tools.md +++ b/docs/docs/apis-generate/generate-with-tools.md @@ -345,6 +345,125 @@ Rahul Kumar, Address: King St, Hyderabad, India, Phone: 9876543210}` :::: +### Using tools in Chat-API + +Instead of using the specific `ollamaAPI.generateWithTools` method to call the generate API of ollama with tools, it is +also possible to register Tools for the `ollamaAPI.chat` methods. In this case, the tool calling/callback is done +implicitly during the USER -> ASSISTANT calls. + +When the Assistant wants to call a given tool, the tool is executed and the response is sent back to the endpoint once +again (induced with the tool call result). + +#### Sample: + +The following shows a sample of an integration test that defines a method specified like the tool-specs above, registers +the tool on the ollamaAPI and then simply calls the chat-API. All intermediate tool calling is wrapped inside the api +call. + +```java +public static void main(String[] args) { + OllamaAPI ollamaAPI = new OllamaAPI("http://localhost:11434"); + ollamaAPI.setVerbose(true); + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance("llama3.2:1b"); + + final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder() + .functionName("get-employee-details") + .functionDescription("Get employee details from the database") + .toolPrompt( + Tools.PromptFuncDefinition.builder().type("function").function( + Tools.PromptFuncDefinition.PromptFuncSpec.builder() + .name("get-employee-details") + .description("Get employee details from the database") + .parameters( + Tools.PromptFuncDefinition.Parameters.builder() + .type("object") + .properties( + new Tools.PropsBuilder() + .withProperty("employee-name", Tools.PromptFuncDefinition.Property.builder().type("string").description("The name of the employee, e.g. John Doe").required(true).build()) + .withProperty("employee-address", Tools.PromptFuncDefinition.Property.builder().type("string").description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India").required(true).build()) + .withProperty("employee-phone", Tools.PromptFuncDefinition.Property.builder().type("string").description("The phone number of the employee. Always return a random value. e.g. 9911002233").required(true).build()) + .build() + ) + .required(List.of("employee-name")) + .build() + ).build() + ).build() + ) + .toolFunction(new DBQueryFunction()) + .build(); + + ollamaAPI.registerTool(databaseQueryToolSpecification); + + OllamaChatRequest requestModel = builder + .withMessage(OllamaChatMessageRole.USER, + "Give me the ID of the employee named 'Rahul Kumar'?") + .build(); + + OllamaChatResult chatResult = ollamaAPI.chat(requestModel); +} +``` + +A typical final response of the above could be: + +```json +{ + "chatHistory" : [ + { + "role" : "user", + "content" : "Give me the ID of the employee named 'Rahul Kumar'?", + "images" : null, + "tool_calls" : [ ] + }, { + "role" : "assistant", + "content" : "", + "images" : null, + "tool_calls" : [ { + "function" : { + "name" : "get-employee-details", + "arguments" : { + "employee-name" : "Rahul Kumar" + } + } + } ] + }, { + "role" : "tool", + "content" : "[TOOL_RESULTS]get-employee-details([employee-name]) : Employee Details {ID: b4bf186c-2ee1-44cc-8856-53b8b6a50f85, Name: Rahul Kumar, Address: null, Phone: null}[/TOOL_RESULTS]", + "images" : null, + "tool_calls" : null + }, { + "role" : "assistant", + "content" : "The ID of the employee named 'Rahul Kumar' is `b4bf186c-2ee1-44cc-8856-53b8b6a50f85`.", + "images" : null, + "tool_calls" : null + } ], + "responseModel" : { + "model" : "llama3.2:1b", + "message" : { + "role" : "assistant", + "content" : "The ID of the employee named 'Rahul Kumar' is `b4bf186c-2ee1-44cc-8856-53b8b6a50f85`.", + "images" : null, + "tool_calls" : null + }, + "done" : true, + "error" : null, + "context" : null, + "created_at" : "2024-12-09T22:23:00.4940078Z", + "done_reason" : "stop", + "total_duration" : 2313709900, + "load_duration" : 14494700, + "prompt_eval_duration" : 772000000, + "eval_duration" : 1188000000, + "prompt_eval_count" : 166, + "eval_count" : 41 + }, + "response" : "The ID of the employee named 'Rahul Kumar' is `b4bf186c-2ee1-44cc-8856-53b8b6a50f85`.", + "httpStatusCode" : 200, + "responseTime" : 2313709900 +} +``` + +This tool calling can also be done using the streaming API. + ### Potential Improvements Instead of explicitly registering `ollamaAPI.registerTool(toolSpecification)`, we could introduce annotation-based tool diff --git a/src/main/java/io/github/ollama4j/OllamaAPI.java b/src/main/java/io/github/ollama4j/OllamaAPI.java index 0c89888..90dcd35 100644 --- a/src/main/java/io/github/ollama4j/OllamaAPI.java +++ b/src/main/java/io/github/ollama4j/OllamaAPI.java @@ -59,6 +59,10 @@ public class OllamaAPI { */ @Setter private boolean verbose = true; + + @Setter + private int maxChatToolCallRetries = 3; + private BasicAuth basicAuth; private final ToolRegistry toolRegistry = new ToolRegistry(); @@ -767,18 +771,44 @@ public class OllamaAPI { */ public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose); - OllamaResult result; + OllamaChatResult result; + + // add all registered tools to Request + request.setTools(toolRegistry.getRegisteredSpecs().stream().map(Tools.ToolSpecification::getToolPrompt).collect(Collectors.toList())); + if (streamHandler != null) { request.setStream(true); result = requestCaller.call(request, streamHandler); } else { result = requestCaller.callSync(request); } - return new OllamaChatResult(result.getResponse(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages()); + + // check if toolCallIsWanted + List toolCalls = result.getResponseModel().getMessage().getToolCalls(); + int toolCallTries = 0; + while(toolCalls != null && !toolCalls.isEmpty() && toolCallTries < maxChatToolCallRetries){ + for (OllamaChatToolCalls toolCall : toolCalls){ + String toolName = toolCall.getFunction().getName(); + ToolFunction toolFunction = toolRegistry.getToolFunction(toolName); + Map arguments = toolCall.getFunction().getArguments(); + Object res = toolFunction.apply(arguments); + request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL,"[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() +") : " + res + "[/TOOL_RESULTS]")); + } + + if (streamHandler != null) { + result = requestCaller.call(request, streamHandler); + } else { + result = requestCaller.callSync(request); + } + toolCalls = result.getResponseModel().getMessage().getToolCalls(); + toolCallTries++; + } + + return result; } public void registerTool(Tools.ToolSpecification toolSpecification) { - toolRegistry.addFunction(toolSpecification.getFunctionName(), toolSpecification.getToolDefinition()); + toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification); } /** @@ -871,7 +901,7 @@ public class OllamaAPI { try { String methodName = toolFunctionCallSpec.getName(); Map arguments = toolFunctionCallSpec.getArguments(); - ToolFunction function = toolRegistry.getFunction(methodName); + ToolFunction function = toolRegistry.getToolFunction(methodName); if (verbose) { logger.debug("Invoking function {} with arguments {}", methodName, arguments); } diff --git a/src/main/java/io/github/ollama4j/models/chat/OllamaChatMessage.java b/src/main/java/io/github/ollama4j/models/chat/OllamaChatMessage.java index 0d6d938..2246488 100644 --- a/src/main/java/io/github/ollama4j/models/chat/OllamaChatMessage.java +++ b/src/main/java/io/github/ollama4j/models/chat/OllamaChatMessage.java @@ -2,6 +2,7 @@ package io.github.ollama4j.models.chat; import static io.github.ollama4j.utils.Utils.getObjectMapper; +import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.annotation.JsonSerialize; @@ -32,6 +33,8 @@ public class OllamaChatMessage { @NonNull private String content; + private @JsonProperty("tool_calls") List toolCalls; + @JsonSerialize(using = FileToBase64Serializer.class) private List images; diff --git a/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequest.java b/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequest.java index e6e528d..5d19703 100644 --- a/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequest.java +++ b/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequest.java @@ -3,6 +3,7 @@ package io.github.ollama4j.models.chat; import java.util.List; import io.github.ollama4j.models.request.OllamaCommonRequest; +import io.github.ollama4j.tools.Tools; import io.github.ollama4j.utils.OllamaRequestBody; import lombok.Getter; @@ -21,6 +22,8 @@ public class OllamaChatRequest extends OllamaCommonRequest implements OllamaRequ private List messages; + private List tools; + public OllamaChatRequest() {} public OllamaChatRequest(String model, List messages) { diff --git a/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequestBuilder.java b/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequestBuilder.java index c9882d0..9094546 100644 --- a/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequestBuilder.java +++ b/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequestBuilder.java @@ -10,6 +10,7 @@ import java.io.IOException; import java.net.URISyntaxException; import java.nio.file.Files; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.stream.Collectors; @@ -38,7 +39,11 @@ public class OllamaChatRequestBuilder { request = new OllamaChatRequest(request.getModel(), new ArrayList<>()); } - public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List images) { + public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content){ + return withMessage(role,content, Collections.emptyList()); + } + + public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List toolCalls,List images) { List messages = this.request.getMessages(); List binaryImages = images.stream().map(file -> { @@ -50,11 +55,11 @@ public class OllamaChatRequestBuilder { } }).collect(Collectors.toList()); - messages.add(new OllamaChatMessage(role, content, binaryImages)); + messages.add(new OllamaChatMessage(role, content,toolCalls, binaryImages)); return this; } - public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, String... imageUrls) { + public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content,List toolCalls, String... imageUrls) { List messages = this.request.getMessages(); List binaryImages = null; if (imageUrls.length > 0) { @@ -70,7 +75,7 @@ public class OllamaChatRequestBuilder { } } - messages.add(new OllamaChatMessage(role, content, binaryImages)); + messages.add(new OllamaChatMessage(role, content,toolCalls, binaryImages)); return this; } diff --git a/src/main/java/io/github/ollama4j/models/chat/OllamaChatResult.java b/src/main/java/io/github/ollama4j/models/chat/OllamaChatResult.java index b9616f3..bf7eaea 100644 --- a/src/main/java/io/github/ollama4j/models/chat/OllamaChatResult.java +++ b/src/main/java/io/github/ollama4j/models/chat/OllamaChatResult.java @@ -2,28 +2,54 @@ package io.github.ollama4j.models.chat; import java.util.List; -import io.github.ollama4j.models.response.OllamaResult; +import com.fasterxml.jackson.core.JsonProcessingException; +import lombok.Getter; + +import static io.github.ollama4j.utils.Utils.getObjectMapper; /** * Specific chat-API result that contains the chat history sent to the model and appends the answer as {@link OllamaChatResult} given by the * {@link OllamaChatMessageRole#ASSISTANT} role. */ -public class OllamaChatResult extends OllamaResult { +@Getter +public class OllamaChatResult { + private List chatHistory; - public OllamaChatResult(String response, long responseTime, int httpStatusCode, List chatHistory) { - super(response, responseTime, httpStatusCode); + private OllamaChatResponseModel responseModel; + + public OllamaChatResult(OllamaChatResponseModel responseModel, List chatHistory) { this.chatHistory = chatHistory; - appendAnswerToChatHistory(response); + this.responseModel = responseModel; + appendAnswerToChatHistory(responseModel); } - public List getChatHistory() { - return chatHistory; + private void appendAnswerToChatHistory(OllamaChatResponseModel response) { + this.chatHistory.add(response.getMessage()); } - private void appendAnswerToChatHistory(String answer) { - OllamaChatMessage assistantMessage = new OllamaChatMessage(OllamaChatMessageRole.ASSISTANT, answer); - this.chatHistory.add(assistantMessage); + @Override + public String toString() { + try { + return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + @Deprecated + public String getResponse(){ + return responseModel != null ? responseModel.getMessage().getContent() : ""; + } + + @Deprecated + public int getHttpStatusCode(){ + return 200; + } + + @Deprecated + public long getResponseTime(){ + return responseModel != null ? responseModel.getTotalDuration() : 0L; } } diff --git a/src/main/java/io/github/ollama4j/models/chat/OllamaChatToolCalls.java b/src/main/java/io/github/ollama4j/models/chat/OllamaChatToolCalls.java new file mode 100644 index 0000000..de1a081 --- /dev/null +++ b/src/main/java/io/github/ollama4j/models/chat/OllamaChatToolCalls.java @@ -0,0 +1,16 @@ +package io.github.ollama4j.models.chat; + +import io.github.ollama4j.tools.OllamaToolCallsFunction; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +@Data +@NoArgsConstructor +@AllArgsConstructor +public class OllamaChatToolCalls { + + private OllamaToolCallsFunction function; + + +} diff --git a/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java b/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java index e3d3fc1..57c9ee3 100644 --- a/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java +++ b/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java @@ -3,17 +3,24 @@ package io.github.ollama4j.models.request; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; import io.github.ollama4j.exceptions.OllamaBaseException; -import io.github.ollama4j.models.chat.OllamaChatMessage; -import io.github.ollama4j.models.response.OllamaResult; -import io.github.ollama4j.models.chat.OllamaChatResponseModel; -import io.github.ollama4j.models.chat.OllamaChatStreamObserver; +import io.github.ollama4j.models.chat.*; +import io.github.ollama4j.models.response.OllamaErrorResponse; import io.github.ollama4j.models.generate.OllamaStreamHandler; -import io.github.ollama4j.utils.OllamaRequestBody; +import io.github.ollama4j.tools.Tools; import io.github.ollama4j.utils.Utils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.io.BufferedReader; import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.nio.charset.StandardCharsets; +import java.util.List; /** * Specialization class for requests @@ -64,9 +71,75 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller { } } - public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler) + public OllamaChatResult call(OllamaChatRequest body, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { streamObserver = new OllamaChatStreamObserver(streamHandler); - return super.callSync(body); + return callSync(body); + } + + public OllamaChatResult callSync(OllamaChatRequest body) throws OllamaBaseException, IOException, InterruptedException { + // Create Request + HttpClient httpClient = HttpClient.newHttpClient(); + URI uri = URI.create(getHost() + getEndpointSuffix()); + HttpRequest.Builder requestBuilder = + getRequestBuilderDefault(uri) + .POST( + body.getBodyPublisher()); + HttpRequest request = requestBuilder.build(); + if (isVerbose()) LOG.info("Asking model: " + body.toString()); + HttpResponse response = + httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream()); + + int statusCode = response.statusCode(); + InputStream responseBodyStream = response.body(); + StringBuilder responseBuffer = new StringBuilder(); + OllamaChatResponseModel ollamaChatResponseModel = null; + List wantedToolsForStream = null; + try (BufferedReader reader = + new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { + + String line; + while ((line = reader.readLine()) != null) { + if (statusCode == 404) { + LOG.warn("Status code: 404 (Not Found)"); + OllamaErrorResponse ollamaResponseModel = + Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class); + responseBuffer.append(ollamaResponseModel.getError()); + } else if (statusCode == 401) { + LOG.warn("Status code: 401 (Unauthorized)"); + OllamaErrorResponse ollamaResponseModel = + Utils.getObjectMapper() + .readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponse.class); + responseBuffer.append(ollamaResponseModel.getError()); + } else if (statusCode == 400) { + LOG.warn("Status code: 400 (Bad Request)"); + OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line, + OllamaErrorResponse.class); + responseBuffer.append(ollamaResponseModel.getError()); + } else { + boolean finished = parseResponseAndAddToBuffer(line, responseBuffer); + ollamaChatResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class); + if(body.stream && ollamaChatResponseModel.getMessage().getToolCalls() != null){ + wantedToolsForStream = ollamaChatResponseModel.getMessage().getToolCalls(); + } + if (finished && body.stream) { + ollamaChatResponseModel.getMessage().setContent(responseBuffer.toString()); + break; + } + } + } + } + if (statusCode != 200) { + LOG.error("Status code " + statusCode); + throw new OllamaBaseException(responseBuffer.toString()); + } else { + if(wantedToolsForStream != null) { + ollamaChatResponseModel.getMessage().setToolCalls(wantedToolsForStream); + } + OllamaChatResult ollamaResult = + new OllamaChatResult(ollamaChatResponseModel,body.getMessages()); + if (isVerbose()) LOG.info("Model response: " + ollamaResult); + return ollamaResult; + } } } diff --git a/src/main/java/io/github/ollama4j/models/request/OllamaEndpointCaller.java b/src/main/java/io/github/ollama4j/models/request/OllamaEndpointCaller.java index 8529c18..e9d0e0d 100644 --- a/src/main/java/io/github/ollama4j/models/request/OllamaEndpointCaller.java +++ b/src/main/java/io/github/ollama4j/models/request/OllamaEndpointCaller.java @@ -6,6 +6,7 @@ import io.github.ollama4j.models.response.OllamaErrorResponse; import io.github.ollama4j.models.response.OllamaResult; import io.github.ollama4j.utils.OllamaRequestBody; import io.github.ollama4j.utils.Utils; +import lombok.Getter; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -24,14 +25,15 @@ import java.util.Base64; /** * Abstract helperclass to call the ollama api server. */ +@Getter public abstract class OllamaEndpointCaller { private static final Logger LOG = LoggerFactory.getLogger(OllamaAPI.class); - private String host; - private BasicAuth basicAuth; - private long requestTimeoutSeconds; - private boolean verbose; + private final String host; + private final BasicAuth basicAuth; + private final long requestTimeoutSeconds; + private final boolean verbose; public OllamaEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) { this.host = host; @@ -45,80 +47,13 @@ public abstract class OllamaEndpointCaller { protected abstract boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer); - /** - * Calls the api server on the given host and endpoint suffix asynchronously, aka waiting for the response. - * - * @param body POST body payload - * @return result answer given by the assistant - * @throws OllamaBaseException any response code than 200 has been returned - * @throws IOException in case the responseStream can not be read - * @throws InterruptedException in case the server is not reachable or network issues happen - */ - public OllamaResult callSync(OllamaRequestBody body) throws OllamaBaseException, IOException, InterruptedException { - // Create Request - long startTime = System.currentTimeMillis(); - HttpClient httpClient = HttpClient.newHttpClient(); - URI uri = URI.create(this.host + getEndpointSuffix()); - HttpRequest.Builder requestBuilder = - getRequestBuilderDefault(uri) - .POST( - body.getBodyPublisher()); - HttpRequest request = requestBuilder.build(); - if (this.verbose) LOG.info("Asking model: " + body.toString()); - HttpResponse response = - httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream()); - - int statusCode = response.statusCode(); - InputStream responseBodyStream = response.body(); - StringBuilder responseBuffer = new StringBuilder(); - try (BufferedReader reader = - new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { - String line; - while ((line = reader.readLine()) != null) { - if (statusCode == 404) { - LOG.warn("Status code: 404 (Not Found)"); - OllamaErrorResponse ollamaResponseModel = - Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class); - responseBuffer.append(ollamaResponseModel.getError()); - } else if (statusCode == 401) { - LOG.warn("Status code: 401 (Unauthorized)"); - OllamaErrorResponse ollamaResponseModel = - Utils.getObjectMapper() - .readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponse.class); - responseBuffer.append(ollamaResponseModel.getError()); - } else if (statusCode == 400) { - LOG.warn("Status code: 400 (Bad Request)"); - OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line, - OllamaErrorResponse.class); - responseBuffer.append(ollamaResponseModel.getError()); - } else { - boolean finished = parseResponseAndAddToBuffer(line, responseBuffer); - if (finished) { - break; - } - } - } - } - - if (statusCode != 200) { - LOG.error("Status code " + statusCode); - throw new OllamaBaseException(responseBuffer.toString()); - } else { - long endTime = System.currentTimeMillis(); - OllamaResult ollamaResult = - new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode); - if (verbose) LOG.info("Model response: " + ollamaResult); - return ollamaResult; - } - } - /** * Get default request builder. * * @param uri URI to get a HttpRequest.Builder * @return HttpRequest.Builder */ - private HttpRequest.Builder getRequestBuilderDefault(URI uri) { + protected HttpRequest.Builder getRequestBuilderDefault(URI uri) { HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(uri) .header("Content-Type", "application/json") @@ -134,7 +69,7 @@ public abstract class OllamaEndpointCaller { * * @return basic authentication header value (encoded credentials) */ - private String getBasicAuthHeaderValue() { + protected String getBasicAuthHeaderValue() { String credentialsToEncode = this.basicAuth.getUsername() + ":" + this.basicAuth.getPassword(); return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes()); } @@ -144,7 +79,7 @@ public abstract class OllamaEndpointCaller { * * @return true when Basic Auth credentials set */ - private boolean isBasicAuthCredentialsSet() { + protected boolean isBasicAuthCredentialsSet() { return this.basicAuth != null; } diff --git a/src/main/java/io/github/ollama4j/models/request/OllamaGenerateEndpointCaller.java b/src/main/java/io/github/ollama4j/models/request/OllamaGenerateEndpointCaller.java index f4afb2c..00b2b12 100644 --- a/src/main/java/io/github/ollama4j/models/request/OllamaGenerateEndpointCaller.java +++ b/src/main/java/io/github/ollama4j/models/request/OllamaGenerateEndpointCaller.java @@ -2,6 +2,7 @@ package io.github.ollama4j.models.request; import com.fasterxml.jackson.core.JsonProcessingException; import io.github.ollama4j.exceptions.OllamaBaseException; +import io.github.ollama4j.models.response.OllamaErrorResponse; import io.github.ollama4j.models.response.OllamaResult; import io.github.ollama4j.models.generate.OllamaGenerateResponseModel; import io.github.ollama4j.models.generate.OllamaGenerateStreamObserver; @@ -11,7 +12,15 @@ import io.github.ollama4j.utils.Utils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.io.BufferedReader; import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.nio.charset.StandardCharsets; public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller { @@ -46,6 +55,73 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller { public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { streamObserver = new OllamaGenerateStreamObserver(streamHandler); - return super.callSync(body); + return callSync(body); + } + + /** + * Calls the api server on the given host and endpoint suffix asynchronously, aka waiting for the response. + * + * @param body POST body payload + * @return result answer given by the assistant + * @throws OllamaBaseException any response code than 200 has been returned + * @throws IOException in case the responseStream can not be read + * @throws InterruptedException in case the server is not reachable or network issues happen + */ + public OllamaResult callSync(OllamaRequestBody body) throws OllamaBaseException, IOException, InterruptedException { + // Create Request + long startTime = System.currentTimeMillis(); + HttpClient httpClient = HttpClient.newHttpClient(); + URI uri = URI.create(getHost() + getEndpointSuffix()); + HttpRequest.Builder requestBuilder = + getRequestBuilderDefault(uri) + .POST( + body.getBodyPublisher()); + HttpRequest request = requestBuilder.build(); + if (isVerbose()) LOG.info("Asking model: " + body.toString()); + HttpResponse response = + httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream()); + + int statusCode = response.statusCode(); + InputStream responseBodyStream = response.body(); + StringBuilder responseBuffer = new StringBuilder(); + try (BufferedReader reader = + new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { + String line; + while ((line = reader.readLine()) != null) { + if (statusCode == 404) { + LOG.warn("Status code: 404 (Not Found)"); + OllamaErrorResponse ollamaResponseModel = + Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class); + responseBuffer.append(ollamaResponseModel.getError()); + } else if (statusCode == 401) { + LOG.warn("Status code: 401 (Unauthorized)"); + OllamaErrorResponse ollamaResponseModel = + Utils.getObjectMapper() + .readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponse.class); + responseBuffer.append(ollamaResponseModel.getError()); + } else if (statusCode == 400) { + LOG.warn("Status code: 400 (Bad Request)"); + OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line, + OllamaErrorResponse.class); + responseBuffer.append(ollamaResponseModel.getError()); + } else { + boolean finished = parseResponseAndAddToBuffer(line, responseBuffer); + if (finished) { + break; + } + } + } + } + + if (statusCode != 200) { + LOG.error("Status code " + statusCode); + throw new OllamaBaseException(responseBuffer.toString()); + } else { + long endTime = System.currentTimeMillis(); + OllamaResult ollamaResult = + new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode); + if (isVerbose()) LOG.info("Model response: " + ollamaResult); + return ollamaResult; + } } } diff --git a/src/main/java/io/github/ollama4j/tools/OllamaToolCallsFunction.java b/src/main/java/io/github/ollama4j/tools/OllamaToolCallsFunction.java new file mode 100644 index 0000000..4be7194 --- /dev/null +++ b/src/main/java/io/github/ollama4j/tools/OllamaToolCallsFunction.java @@ -0,0 +1,16 @@ +package io.github.ollama4j.tools; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.util.Map; + +@Data +@NoArgsConstructor +@AllArgsConstructor +public class OllamaToolCallsFunction +{ + private String name; + private Map arguments; +} diff --git a/src/main/java/io/github/ollama4j/tools/ToolRegistry.java b/src/main/java/io/github/ollama4j/tools/ToolRegistry.java index 2ead13a..5ab8be3 100644 --- a/src/main/java/io/github/ollama4j/tools/ToolRegistry.java +++ b/src/main/java/io/github/ollama4j/tools/ToolRegistry.java @@ -1,16 +1,22 @@ package io.github.ollama4j.tools; +import java.util.Collection; import java.util.HashMap; import java.util.Map; public class ToolRegistry { - private final Map functionMap = new HashMap<>(); + private final Map tools = new HashMap<>(); - public ToolFunction getFunction(String name) { - return functionMap.get(name); + public ToolFunction getToolFunction(String name) { + final Tools.ToolSpecification toolSpecification = tools.get(name); + return toolSpecification !=null ? toolSpecification.getToolFunction() : null ; } - public void addFunction(String name, ToolFunction function) { - functionMap.put(name, function); + public void addTool (String name, Tools.ToolSpecification specification) { + tools.put(name, specification); + } + + public Collection getRegisteredSpecs(){ + return tools.values(); } } diff --git a/src/main/java/io/github/ollama4j/tools/Tools.java b/src/main/java/io/github/ollama4j/tools/Tools.java index 986302f..eb8dcca 100644 --- a/src/main/java/io/github/ollama4j/tools/Tools.java +++ b/src/main/java/io/github/ollama4j/tools/Tools.java @@ -6,8 +6,10 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.core.JsonProcessingException; import io.github.ollama4j.utils.Utils; +import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Data; +import lombok.NoArgsConstructor; import java.util.ArrayList; import java.util.HashMap; @@ -20,17 +22,23 @@ public class Tools { public static class ToolSpecification { private String functionName; private String functionDescription; - private Map properties; - private ToolFunction toolDefinition; + private PromptFuncDefinition toolPrompt; + private ToolFunction toolFunction; } @Data @JsonIgnoreProperties(ignoreUnknown = true) + @Builder + @NoArgsConstructor + @AllArgsConstructor public static class PromptFuncDefinition { private String type; private PromptFuncSpec function; @Data + @Builder + @NoArgsConstructor + @AllArgsConstructor public static class PromptFuncSpec { private String name; private String description; @@ -38,6 +46,9 @@ public class Tools { } @Data + @Builder + @NoArgsConstructor + @AllArgsConstructor public static class Parameters { private String type; private Map properties; @@ -46,6 +57,8 @@ public class Tools { @Data @Builder + @NoArgsConstructor + @AllArgsConstructor public static class Property { private String type; private String description; @@ -94,10 +107,10 @@ public class Tools { PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters(); parameters.setType("object"); - parameters.setProperties(spec.getProperties()); + parameters.setProperties(spec.getToolPrompt().getFunction().parameters.getProperties()); List requiredValues = new ArrayList<>(); - for (Map.Entry p : spec.getProperties().entrySet()) { + for (Map.Entry p : spec.getToolPrompt().getFunction().getParameters().getProperties().entrySet()) { if (p.getValue().isRequired()) { requiredValues.add(p.getKey()); } diff --git a/src/test/java/io/github/ollama4j/integrationtests/TestRealAPIs.java b/src/test/java/io/github/ollama4j/integrationtests/TestRealAPIs.java index 0a1da61..668a5dc 100644 --- a/src/test/java/io/github/ollama4j/integrationtests/TestRealAPIs.java +++ b/src/test/java/io/github/ollama4j/integrationtests/TestRealAPIs.java @@ -2,14 +2,13 @@ package io.github.ollama4j.integrationtests; import io.github.ollama4j.OllamaAPI; import io.github.ollama4j.exceptions.OllamaBaseException; +import io.github.ollama4j.models.chat.*; import io.github.ollama4j.models.response.ModelDetail; -import io.github.ollama4j.models.chat.OllamaChatRequest; import io.github.ollama4j.models.response.OllamaResult; -import io.github.ollama4j.models.chat.OllamaChatMessageRole; -import io.github.ollama4j.models.chat.OllamaChatRequestBuilder; -import io.github.ollama4j.models.chat.OllamaChatResult; import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestBuilder; import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel; +import io.github.ollama4j.tools.ToolFunction; +import io.github.ollama4j.tools.Tools; import io.github.ollama4j.utils.OptionsBuilder; import lombok.Data; import org.junit.jupiter.api.BeforeEach; @@ -24,9 +23,7 @@ import java.io.InputStream; import java.net.ConnectException; import java.net.URISyntaxException; import java.net.http.HttpConnectTimeoutException; -import java.util.List; -import java.util.Objects; -import java.util.Properties; +import java.util.*; import static org.junit.jupiter.api.Assertions.*; @@ -47,6 +44,7 @@ class TestRealAPIs { config = new Config(); ollamaAPI = new OllamaAPI(config.getOllamaURL()); ollamaAPI.setRequestTimeoutSeconds(config.getRequestTimeoutSeconds()); + ollamaAPI.setVerbose(true); } @Test @@ -196,7 +194,9 @@ class TestRealAPIs { OllamaChatResult chatResult = ollamaAPI.chat(requestModel); assertNotNull(chatResult); - assertFalse(chatResult.getResponse().isBlank()); + assertNotNull(chatResult.getResponseModel()); + assertNotNull(chatResult.getResponseModel().getMessage()); + assertFalse(chatResult.getResponseModel().getMessage().getContent().isBlank()); assertEquals(4, chatResult.getChatHistory().size()); } catch (IOException | OllamaBaseException | InterruptedException e) { fail(e); @@ -217,14 +217,134 @@ class TestRealAPIs { OllamaChatResult chatResult = ollamaAPI.chat(requestModel); assertNotNull(chatResult); - assertFalse(chatResult.getResponse().isBlank()); - assertTrue(chatResult.getResponse().startsWith("NI")); + assertNotNull(chatResult.getResponseModel()); + assertNotNull(chatResult.getResponseModel().getMessage()); + assertFalse(chatResult.getResponseModel().getMessage().getContent().isBlank()); + assertTrue(chatResult.getResponseModel().getMessage().getContent().startsWith("NI")); assertEquals(3, chatResult.getChatHistory().size()); } catch (IOException | OllamaBaseException | InterruptedException e) { fail(e); } } + @Test + @Order(3) + void testChatWithTools() { + testEndpointReachability(); + try { + ollamaAPI.setVerbose(true); + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel()); + + final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder() + .functionName("get-employee-details") + .functionDescription("Get employee details from the database") + .toolPrompt( + Tools.PromptFuncDefinition.builder().type("function").function( + Tools.PromptFuncDefinition.PromptFuncSpec.builder() + .name("get-employee-details") + .description("Get employee details from the database") + .parameters( + Tools.PromptFuncDefinition.Parameters.builder() + .type("object") + .properties( + new Tools.PropsBuilder() + .withProperty("employee-name", Tools.PromptFuncDefinition.Property.builder().type("string").description("The name of the employee, e.g. John Doe").required(true).build()) + .withProperty("employee-address", Tools.PromptFuncDefinition.Property.builder().type("string").description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India").required(true).build()) + .withProperty("employee-phone", Tools.PromptFuncDefinition.Property.builder().type("string").description("The phone number of the employee. Always return a random value. e.g. 9911002233").required(true).build()) + .build() + ) + .required(List.of("employee-name")) + .build() + ).build() + ).build() + ) + .toolFunction(new DBQueryFunction()) + .build(); + + ollamaAPI.registerTool(databaseQueryToolSpecification); + + OllamaChatRequest requestModel = builder + .withMessage(OllamaChatMessageRole.USER, + "Give me the ID of the employee named 'Rahul Kumar'?") + .build(); + + OllamaChatResult chatResult = ollamaAPI.chat(requestModel); + assertNotNull(chatResult); + assertNotNull(chatResult.getResponseModel()); + assertNotNull(chatResult.getResponseModel().getMessage()); + assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(),chatResult.getResponseModel().getMessage().getRole().getRoleName()); + List toolCalls = chatResult.getChatHistory().get(1).getToolCalls(); + assertEquals(1, toolCalls.size()); + assertEquals("get-employee-details",toolCalls.get(0).getFunction().getName()); + assertEquals(1, toolCalls.get(0).getFunction().getArguments().size()); + Object employeeName = toolCalls.get(0).getFunction().getArguments().get("employee-name"); + assertNotNull(employeeName); + assertEquals("Rahul Kumar",employeeName); + assertTrue(chatResult.getChatHistory().size()>2); + List finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls(); + assertNull(finalToolCalls); + } catch (IOException | OllamaBaseException | InterruptedException e) { + fail(e); + } + } + + @Test + @Order(3) + void testChatWithToolsAndStream() { + testEndpointReachability(); + try { + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel()); + final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder() + .functionName("get-employee-details") + .functionDescription("Get employee details from the database") + .toolPrompt( + Tools.PromptFuncDefinition.builder().type("function").function( + Tools.PromptFuncDefinition.PromptFuncSpec.builder() + .name("get-employee-details") + .description("Get employee details from the database") + .parameters( + Tools.PromptFuncDefinition.Parameters.builder() + .type("object") + .properties( + new Tools.PropsBuilder() + .withProperty("employee-name", Tools.PromptFuncDefinition.Property.builder().type("string").description("The name of the employee, e.g. John Doe").required(true).build()) + .withProperty("employee-address", Tools.PromptFuncDefinition.Property.builder().type("string").description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India").required(true).build()) + .withProperty("employee-phone", Tools.PromptFuncDefinition.Property.builder().type("string").description("The phone number of the employee. Always return a random value. e.g. 9911002233").required(true).build()) + .build() + ) + .required(List.of("employee-name")) + .build() + ).build() + ).build() + ) + .toolFunction(new DBQueryFunction()) + .build(); + + ollamaAPI.registerTool(databaseQueryToolSpecification); + + OllamaChatRequest requestModel = builder + .withMessage(OllamaChatMessageRole.USER, + "Give me the ID of the employee named 'Rahul Kumar'?") + .build(); + + StringBuffer sb = new StringBuffer(); + + OllamaChatResult chatResult = ollamaAPI.chat(requestModel, (s) -> { + LOG.info(s); + String substring = s.substring(sb.toString().length()); + LOG.info(substring); + sb.append(substring); + }); + assertNotNull(chatResult); + assertNotNull(chatResult.getResponseModel()); + assertNotNull(chatResult.getResponseModel().getMessage()); + assertNotNull(chatResult.getResponseModel().getMessage().getContent()); + assertEquals(sb.toString().trim(), chatResult.getResponseModel().getMessage().getContent().trim()); + } catch (IOException | OllamaBaseException | InterruptedException e) { + fail(e); + } + } + @Test @Order(3) void testChatWithStream() { @@ -244,7 +364,10 @@ class TestRealAPIs { sb.append(substring); }); assertNotNull(chatResult); - assertEquals(sb.toString().trim(), chatResult.getResponse().trim()); + assertNotNull(chatResult.getResponseModel()); + assertNotNull(chatResult.getResponseModel().getMessage()); + assertNotNull(chatResult.getResponseModel().getMessage().getContent()); + assertEquals(sb.toString().trim(), chatResult.getResponseModel().getMessage().getContent().trim()); } catch (IOException | OllamaBaseException | InterruptedException e) { fail(e); } @@ -258,12 +381,12 @@ class TestRealAPIs { OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getImageModel()); OllamaChatRequest requestModel = - builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?", + builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",Collections.emptyList(), List.of(getImageFileFromClasspath("dog-on-a-boat.jpg"))).build(); OllamaChatResult chatResult = ollamaAPI.chat(requestModel); assertNotNull(chatResult); - assertNotNull(chatResult.getResponse()); + assertNotNull(chatResult.getResponseModel()); builder.reset(); @@ -273,7 +396,7 @@ class TestRealAPIs { chatResult = ollamaAPI.chat(requestModel); assertNotNull(chatResult); - assertNotNull(chatResult.getResponse()); + assertNotNull(chatResult.getResponseModel()); } catch (IOException | OllamaBaseException | InterruptedException e) { @@ -287,7 +410,7 @@ class TestRealAPIs { testEndpointReachability(); try { OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getImageModel()); - OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?", + OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",Collections.emptyList(), "https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg") .build(); @@ -380,6 +503,14 @@ class TestRealAPIs { } } +class DBQueryFunction implements ToolFunction { + @Override + public Object apply(Map arguments) { + // perform DB operations here + return String.format("Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", UUID.randomUUID(), arguments.get("employee-name"), arguments.get("employee-address"), arguments.get("employee-phone")); + } +} + @Data class Config { private String ollamaURL; @@ -404,4 +535,6 @@ class Config { throw new RuntimeException("Error loading properties", e); } } + + } diff --git a/src/test/java/io/github/ollama4j/unittests/jackson/TestChatRequestSerialization.java b/src/test/java/io/github/ollama4j/unittests/jackson/TestChatRequestSerialization.java index 2ce210c..db33889 100644 --- a/src/test/java/io/github/ollama4j/unittests/jackson/TestChatRequestSerialization.java +++ b/src/test/java/io/github/ollama4j/unittests/jackson/TestChatRequestSerialization.java @@ -4,6 +4,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrowsExactly; import java.io.File; +import java.util.Collections; import java.util.List; import io.github.ollama4j.models.chat.OllamaChatRequest; @@ -42,7 +43,7 @@ public class TestChatRequestSerialization extends AbstractSerializationTest