diff --git a/src/main/java/io/github/ollama4j/OllamaAPI.java b/src/main/java/io/github/ollama4j/OllamaAPI.java index 0c89888..3421b1c 100644 --- a/src/main/java/io/github/ollama4j/OllamaAPI.java +++ b/src/main/java/io/github/ollama4j/OllamaAPI.java @@ -774,11 +774,15 @@ public class OllamaAPI { } else { result = requestCaller.callSync(request); } + + //add registered Tools to Request + + return new OllamaChatResult(result.getResponse(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages()); } public void registerTool(Tools.ToolSpecification toolSpecification) { - toolRegistry.addFunction(toolSpecification.getFunctionName(), toolSpecification.getToolDefinition()); + toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification); } /** @@ -871,7 +875,7 @@ public class OllamaAPI { try { String methodName = toolFunctionCallSpec.getName(); Map arguments = toolFunctionCallSpec.getArguments(); - ToolFunction function = toolRegistry.getFunction(methodName); + ToolFunction function = toolRegistry.getToolFunction(methodName); if (verbose) { logger.debug("Invoking function {} with arguments {}", methodName, arguments); } diff --git a/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequest.java b/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequest.java index e6e528d..5d19703 100644 --- a/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequest.java +++ b/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequest.java @@ -3,6 +3,7 @@ package io.github.ollama4j.models.chat; import java.util.List; import io.github.ollama4j.models.request.OllamaCommonRequest; +import io.github.ollama4j.tools.Tools; import io.github.ollama4j.utils.OllamaRequestBody; import lombok.Getter; @@ -21,6 +22,8 @@ public class OllamaChatRequest extends OllamaCommonRequest implements OllamaRequ private List messages; + private List tools; + public OllamaChatRequest() {} public OllamaChatRequest(String model, List messages) { diff --git a/src/main/java/io/github/ollama4j/tools/ToolRegistry.java b/src/main/java/io/github/ollama4j/tools/ToolRegistry.java index 2ead13a..bb504c6 100644 --- a/src/main/java/io/github/ollama4j/tools/ToolRegistry.java +++ b/src/main/java/io/github/ollama4j/tools/ToolRegistry.java @@ -4,13 +4,14 @@ import java.util.HashMap; import java.util.Map; public class ToolRegistry { - private final Map functionMap = new HashMap<>(); + private final Map tools = new HashMap<>(); - public ToolFunction getFunction(String name) { - return functionMap.get(name); + public ToolFunction getToolFunction(String name) { + final Tools.ToolSpecification toolSpecification = tools.get(name); + return toolSpecification !=null ? toolSpecification.getToolFunction() : null ; } - public void addFunction(String name, ToolFunction function) { - functionMap.put(name, function); + public void addTool (String name, Tools.ToolSpecification specification) { + tools.put(name, specification); } } diff --git a/src/main/java/io/github/ollama4j/tools/Tools.java b/src/main/java/io/github/ollama4j/tools/Tools.java index 986302f..eb8dcca 100644 --- a/src/main/java/io/github/ollama4j/tools/Tools.java +++ b/src/main/java/io/github/ollama4j/tools/Tools.java @@ -6,8 +6,10 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.core.JsonProcessingException; import io.github.ollama4j.utils.Utils; +import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Data; +import lombok.NoArgsConstructor; import java.util.ArrayList; import java.util.HashMap; @@ -20,17 +22,23 @@ public class Tools { public static class ToolSpecification { private String functionName; private String functionDescription; - private Map properties; - private ToolFunction toolDefinition; + private PromptFuncDefinition toolPrompt; + private ToolFunction toolFunction; } @Data @JsonIgnoreProperties(ignoreUnknown = true) + @Builder + @NoArgsConstructor + @AllArgsConstructor public static class PromptFuncDefinition { private String type; private PromptFuncSpec function; @Data + @Builder + @NoArgsConstructor + @AllArgsConstructor public static class PromptFuncSpec { private String name; private String description; @@ -38,6 +46,9 @@ public class Tools { } @Data + @Builder + @NoArgsConstructor + @AllArgsConstructor public static class Parameters { private String type; private Map properties; @@ -46,6 +57,8 @@ public class Tools { @Data @Builder + @NoArgsConstructor + @AllArgsConstructor public static class Property { private String type; private String description; @@ -94,10 +107,10 @@ public class Tools { PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters(); parameters.setType("object"); - parameters.setProperties(spec.getProperties()); + parameters.setProperties(spec.getToolPrompt().getFunction().parameters.getProperties()); List requiredValues = new ArrayList<>(); - for (Map.Entry p : spec.getProperties().entrySet()) { + for (Map.Entry p : spec.getToolPrompt().getFunction().getParameters().getProperties().entrySet()) { if (p.getValue().isRequired()) { requiredValues.add(p.getKey()); }