diff --git a/src/main/java/io/github/ollama4j/OllamaAPI.java b/src/main/java/io/github/ollama4j/OllamaAPI.java index bb7c1f8..f595090 100644 --- a/src/main/java/io/github/ollama4j/OllamaAPI.java +++ b/src/main/java/io/github/ollama4j/OllamaAPI.java @@ -602,7 +602,7 @@ public class OllamaAPI { OllamaResult result = generate(model, prompt, raw, options, null); toolResult.setModelResult(result); - String toolsResponse = result.getResponse(); + String toolsResponse = result.getContent(); if (toolsResponse.contains("[TOOL_CALLS]")) { toolsResponse = toolsResponse.replace("[TOOL_CALLS]", ""); } @@ -768,6 +768,10 @@ public class OllamaAPI { public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose); OllamaResult result; + + // add all registered tools to Request + request.setTools(toolRegistry.getRegisteredSpecs().stream().map(Tools.ToolSpecification::getToolPrompt).collect(Collectors.toList())); + if (streamHandler != null) { request.setStream(true); result = requestCaller.call(request, streamHandler); @@ -775,10 +779,7 @@ public class OllamaAPI { result = requestCaller.callSync(request); } - // add all registered tools to Request - request.setTools(toolRegistry.getRegisteredSpecs().stream().map(Tools.ToolSpecification::getToolPrompt).collect(Collectors.toList())); - - return new OllamaChatResult(result.getResponse(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages()); + return new OllamaChatResult(result.getContent(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages()); } public void registerTool(Tools.ToolSpecification toolSpecification) { diff --git a/src/main/java/io/github/ollama4j/models/chat/OllamaChatMessage.java b/src/main/java/io/github/ollama4j/models/chat/OllamaChatMessage.java index 0d6d938..2246488 100644 --- a/src/main/java/io/github/ollama4j/models/chat/OllamaChatMessage.java +++ b/src/main/java/io/github/ollama4j/models/chat/OllamaChatMessage.java @@ -2,6 +2,7 @@ package io.github.ollama4j.models.chat; import static io.github.ollama4j.utils.Utils.getObjectMapper; +import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.annotation.JsonSerialize; @@ -32,6 +33,8 @@ public class OllamaChatMessage { @NonNull private String content; + private @JsonProperty("tool_calls") List toolCalls; + @JsonSerialize(using = FileToBase64Serializer.class) private List images; diff --git a/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequestBuilder.java b/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequestBuilder.java index c9882d0..3546ba8 100644 --- a/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequestBuilder.java +++ b/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequestBuilder.java @@ -38,7 +38,7 @@ public class OllamaChatRequestBuilder { request = new OllamaChatRequest(request.getModel(), new ArrayList<>()); } - public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List images) { + public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List toolCalls,List images) { List messages = this.request.getMessages(); List binaryImages = images.stream().map(file -> { @@ -50,11 +50,11 @@ public class OllamaChatRequestBuilder { } }).collect(Collectors.toList()); - messages.add(new OllamaChatMessage(role, content, binaryImages)); + messages.add(new OllamaChatMessage(role, content,toolCalls, binaryImages)); return this; } - public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, String... imageUrls) { + public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content,List toolCalls, String... imageUrls) { List messages = this.request.getMessages(); List binaryImages = null; if (imageUrls.length > 0) { @@ -70,7 +70,7 @@ public class OllamaChatRequestBuilder { } } - messages.add(new OllamaChatMessage(role, content, binaryImages)); + messages.add(new OllamaChatMessage(role, content,toolCalls, binaryImages)); return this; } diff --git a/src/main/java/io/github/ollama4j/models/chat/OllamaChatToolCalls.java b/src/main/java/io/github/ollama4j/models/chat/OllamaChatToolCalls.java new file mode 100644 index 0000000..de1a081 --- /dev/null +++ b/src/main/java/io/github/ollama4j/models/chat/OllamaChatToolCalls.java @@ -0,0 +1,16 @@ +package io.github.ollama4j.models.chat; + +import io.github.ollama4j.tools.OllamaToolCallsFunction; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +@Data +@NoArgsConstructor +@AllArgsConstructor +public class OllamaChatToolCalls { + + private OllamaToolCallsFunction function; + + +} diff --git a/src/main/java/io/github/ollama4j/models/response/OllamaResult.java b/src/main/java/io/github/ollama4j/models/response/OllamaResult.java index beb01ec..8465cb6 100644 --- a/src/main/java/io/github/ollama4j/models/response/OllamaResult.java +++ b/src/main/java/io/github/ollama4j/models/response/OllamaResult.java @@ -17,7 +17,7 @@ public class OllamaResult { * * @return String completion/response text */ - private final String response; + private final String content; /** * -- GETTER -- @@ -35,8 +35,8 @@ public class OllamaResult { */ private long responseTime = 0; - public OllamaResult(String response, long responseTime, int httpStatusCode) { - this.response = response; + public OllamaResult(String content, long responseTime, int httpStatusCode) { + this.content = content; this.responseTime = responseTime; this.httpStatusCode = httpStatusCode; } diff --git a/src/main/java/io/github/ollama4j/tools/OllamaToolCallsFunction.java b/src/main/java/io/github/ollama4j/tools/OllamaToolCallsFunction.java new file mode 100644 index 0000000..dfa4d84 --- /dev/null +++ b/src/main/java/io/github/ollama4j/tools/OllamaToolCallsFunction.java @@ -0,0 +1,16 @@ +package io.github.ollama4j.tools; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.util.Map; + +@Data +@NoArgsConstructor +@AllArgsConstructor +public class OllamaToolCallsFunction +{ + private String name; + private Map arguments; +} diff --git a/src/test/java/io/github/ollama4j/integrationtests/TestRealAPIs.java b/src/test/java/io/github/ollama4j/integrationtests/TestRealAPIs.java index 1175b18..702d0a2 100644 --- a/src/test/java/io/github/ollama4j/integrationtests/TestRealAPIs.java +++ b/src/test/java/io/github/ollama4j/integrationtests/TestRealAPIs.java @@ -10,6 +10,8 @@ import io.github.ollama4j.models.chat.OllamaChatRequestBuilder; import io.github.ollama4j.models.chat.OllamaChatResult; import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestBuilder; import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel; +import io.github.ollama4j.tools.ToolFunction; +import io.github.ollama4j.tools.Tools; import io.github.ollama4j.utils.OptionsBuilder; import lombok.Data; import org.junit.jupiter.api.BeforeEach; @@ -24,9 +26,7 @@ import java.io.InputStream; import java.net.ConnectException; import java.net.URISyntaxException; import java.net.http.HttpConnectTimeoutException; -import java.util.List; -import java.util.Objects; -import java.util.Properties; +import java.util.*; import static org.junit.jupiter.api.Assertions.*; @@ -230,18 +230,47 @@ class TestRealAPIs { void testChatWithTools() { testEndpointReachability(); try { + ollamaAPI.setVerbose(true); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel()); - OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, - "You are a silent bot that only says 'NI'. Do not say anything else under any circumstances!") + + final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder() + .functionName("get-employee-details") + .functionDescription("Get employee details from the database") + .toolPrompt( + Tools.PromptFuncDefinition.builder().type("function").function( + Tools.PromptFuncDefinition.PromptFuncSpec.builder() + .name("get-employee-details") + .description("Get employee details from the database") + .parameters( + Tools.PromptFuncDefinition.Parameters.builder() + .type("object") + .properties( + new Tools.PropsBuilder() + .withProperty("employee-name", Tools.PromptFuncDefinition.Property.builder().type("string").description("The name of the employee, e.g. John Doe").required(true).build()) + .withProperty("employee-address", Tools.PromptFuncDefinition.Property.builder().type("string").description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India").required(true).build()) + .withProperty("employee-phone", Tools.PromptFuncDefinition.Property.builder().type("string").description("The phone number of the employee. Always return a random value. e.g. 9911002233").required(true).build()) + .build() + ) + .required(List.of("employee-name")) + .build() + ).build() + ).build() + ) + .toolFunction(new DBQueryFunction()) + .build(); + + ollamaAPI.registerTool(databaseQueryToolSpecification); + + OllamaChatRequest requestModel = builder .withMessage(OllamaChatMessageRole.USER, - "What is the capital of France? And what's France's connection with Mona Lisa?") + "Give me the details of the employee named 'Rahul Kumar'?") .build(); OllamaChatResult chatResult = ollamaAPI.chat(requestModel); + System.err.println("Response: " + chatResult); assertNotNull(chatResult); assertFalse(chatResult.getResponse().isBlank()); - assertTrue(chatResult.getResponse().startsWith("NI")); - assertEquals(3, chatResult.getChatHistory().size()); + assertEquals(2, chatResult.getChatHistory().size()); } catch (IOException | OllamaBaseException | InterruptedException e) { fail(e); } @@ -402,6 +431,14 @@ class TestRealAPIs { } } +class DBQueryFunction implements ToolFunction { + @Override + public Object apply(Map arguments) { + // perform DB operations here + return String.format("Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", UUID.randomUUID(), arguments.get("employee-name").toString(), arguments.get("employee-address").toString(), arguments.get("employee-phone").toString()); + } +} + @Data class Config { private String ollamaURL; @@ -426,4 +463,6 @@ class Config { throw new RuntimeException("Error loading properties", e); } } + + }