From 81689be194c236038c47da5007d72a9ced3e4281 Mon Sep 17 00:00:00 2001 From: "koujalgi.amith@gmail.com" Date: Sun, 14 Jul 2024 11:23:36 +0530 Subject: [PATCH] Refactored tools API Signed-off-by: koujalgi.amith@gmail.com --- .../docs/apis-generate/generate-with-tools.md | 214 +++++++++++++----- .../ollama4j/core/OllamaAPI.java | 72 +++--- .../exceptions/ToolInvocationException.java | 8 + .../exceptions/ToolNotFoundException.java | 8 + .../core/tools/OllamaToolsResult.java | 21 +- ...DynamicFunction.java => ToolFunction.java} | 2 +- ...ToolDef.java => ToolFunctionCallSpec.java} | 4 +- .../ollama4j/core/tools/ToolRegistry.java | 7 +- .../tools/{MistralTools.java => Tools.java} | 40 +--- 9 files changed, 241 insertions(+), 135 deletions(-) create mode 100644 src/main/java/io/github/amithkoujalgi/ollama4j/core/exceptions/ToolInvocationException.java create mode 100644 src/main/java/io/github/amithkoujalgi/ollama4j/core/exceptions/ToolNotFoundException.java rename src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/{DynamicFunction.java => ToolFunction.java} (80%) rename src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/{ToolDef.java => ToolFunctionCallSpec.java} (88%) rename src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/{MistralTools.java => Tools.java} (69%) diff --git a/docs/docs/apis-generate/generate-with-tools.md b/docs/docs/apis-generate/generate-with-tools.md index 6b7cca6..86c2d83 100644 --- a/docs/docs/apis-generate/generate-with-tools.md +++ b/docs/docs/apis-generate/generate-with-tools.md @@ -29,8 +29,8 @@ You could do that with ease with the `function calling` capabilities of the mode ### Create Functions -This function takes the arguments `location` and `fuelType` and performs an operation with these arguments and returns a -value. +This function takes the arguments `location` and `fuelType` and performs an operation with these arguments and returns +fuel price value. ```java public static String getCurrentFuelPrice(Map arguments) { @@ -40,8 +40,8 @@ public static String getCurrentFuelPrice(Map arguments) { } ``` -This function takes the argument `city` and performs an operation with the argument and returns a -value. +This function takes the argument `city` and performs an operation with the argument and returns the weather for a +location. ```java public static String getCurrentWeather(Map arguments) { @@ -50,6 +50,19 @@ public static String getCurrentWeather(Map arguments) { } ``` +This function takes the argument `employee-name` and performs an operation with the argument and returns employee +details. + +```java +class DBQueryFunction implements ToolFunction { + @Override + public Object apply(Map arguments) { + // perform DB operations here + return String.format("Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", UUID.randomUUID(), arguments.get("employee-name").toString(), arguments.get("employee-address").toString(), arguments.get("employee-phone").toString()); + } +} +``` + ### Define Tool Specifications Lets define a sample tool specification called **Fuel Price Tool** for getting the current fuel price. @@ -58,13 +71,13 @@ Lets define a sample tool specification called **Fuel Price Tool** for getting t - Associate the `getCurrentFuelPrice` function you defined earlier with `SampleTools::getCurrentFuelPrice`. ```java -MistralTools.ToolSpecification fuelPriceToolSpecification = MistralTools.ToolSpecification.builder() +Tools.ToolSpecification fuelPriceToolSpecification = Tools.ToolSpecification.builder() .functionName("current-fuel-price") - .functionDesc("Get current fuel price") - .props( - new MistralTools.PropsBuilder() - .withProperty("location", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The city, e.g. New Delhi, India").required(true).build()) - .withProperty("fuelType", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The fuel type.").enumValues(Arrays.asList("petrol", "diesel")).required(true).build()) + .functionDescription("Get current fuel price") + .properties( + new Tools.PropsBuilder() + .withProperty("location", Tools.PromptFuncDefinition.Property.builder().type("string").description("The city, e.g. New Delhi, India").required(true).build()) + .withProperty("fuelType", Tools.PromptFuncDefinition.Property.builder().type("string").description("The fuel type.").enumValues(Arrays.asList("petrol", "diesel")).required(true).build()) .build() ) .toolDefinition(SampleTools::getCurrentFuelPrice) @@ -77,18 +90,38 @@ Lets also define a sample tool specification called **Weather Tool** for getting - Associate the `getCurrentWeather` function you defined earlier with `SampleTools::getCurrentWeather`. ```java -MistralTools.ToolSpecification weatherToolSpecification = MistralTools.ToolSpecification.builder() +Tools.ToolSpecification weatherToolSpecification = Tools.ToolSpecification.builder() .functionName("current-weather") - .functionDesc("Get current weather") - .props( - new MistralTools.PropsBuilder() - .withProperty("city", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The city, e.g. New Delhi, India").required(true).build()) + .functionDescription("Get current weather") + .properties( + new Tools.PropsBuilder() + .withProperty("city", Tools.PromptFuncDefinition.Property.builder().type("string").description("The city, e.g. New Delhi, India").required(true).build()) .build() ) .toolDefinition(SampleTools::getCurrentWeather) .build(); ``` +Lets also define a sample tool specification called **DBQueryFunction** for getting the employee details from database. + +- Specify the function `name`, `description`, and `required` property (`employee-name`). +- Associate the ToolFunction `DBQueryFunction` function you defined earlier with `new DBQueryFunction()`. + +```java +Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder() + .functionName("get-employee-details") + .functionDescription("Get employee details from the database") + .properties( + new Tools.PropsBuilder() + .withProperty("employee-name", Tools.PromptFuncDefinition.Property.builder().type("string").description("The name of the employee, e.g. John Doe").required(true).build()) + .withProperty("employee-address", Tools.PromptFuncDefinition.Property.builder().type("string").description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India").required(true).build()) + .withProperty("employee-phone", Tools.PromptFuncDefinition.Property.builder().type("string").description("The phone number of the employee. Always return a random value. e.g. 9911002233").required(true).build()) + .build() + ) + .toolDefinition(new DBQueryFunction()) + .build(); +``` + ### Register the Tools Register the defined tools (`fuel price` and `weather`) with the OllamaAPI. @@ -103,14 +136,14 @@ ollamaAPI.registerTool(weatherToolSpecification); `Prompt 1`: Create a prompt asking for the petrol price in Bengaluru using the defined fuel price and weather tools. ```shell -String prompt1 = new MistralTools.PromptBuilder() - .withToolSpecification(fuelPriceToolSpecification) - .withToolSpecification(weatherToolSpecification) - .withPrompt("What is the petrol price in Bengaluru?") - .build(); -OllamaToolsResult toolsResult = ollamaAPI.generateWithTools(model, prompt1, false, new OptionsBuilder().build()); -for (Map.Entry r : toolsResult.getToolResults().entrySet()) { - System.out.printf("[Response from tool '%s']: %s%n", r.getKey().getName(), r.getValue().toString()); +String prompt1 = new Tools.PromptBuilder() + .withToolSpecification(fuelPriceToolSpecification) + .withToolSpecification(weatherToolSpecification) + .withPrompt("What is the petrol price in Bengaluru?") + .build(); +OllamaToolsResult toolsResult = ollamaAPI.generateWithTools(model, prompt1, new OptionsBuilder().build()); +for (OllamaToolsResult.ToolResult r : toolsResult.getToolResults()) { + System.out.printf("[Result of executing tool '%s']: %s%n", r.getFunctionName(), r.getResult().toString()); } ``` @@ -120,21 +153,21 @@ You will get a response similar to: ::::tip[LLM Response] -[Response from tool 'current-fuel-price']: Current price of petrol in Bengaluru is Rs.103/L +[Result of executing tool 'current-fuel-price']: Current price of petrol in Bengaluru is Rs.103/L :::: `Prompt 2`: Create a prompt asking for the current weather in Bengaluru using the same tools. ```shell -String prompt2 = new MistralTools.PromptBuilder() - .withToolSpecification(fuelPriceToolSpecification) - .withToolSpecification(weatherToolSpecification) - .withPrompt("What is the current weather in Bengaluru?") - .build(); -OllamaToolsResult toolsResult = ollamaAPI.generateWithTools(model, prompt2, false, new OptionsBuilder().build()); -for (Map.Entry r : toolsResult.getToolResults().entrySet()) { - System.out.printf("[Response from tool '%s']: %s%n", r.getKey().getName(), r.getValue().toString()); +String prompt2 = new Tools.PromptBuilder() + .withToolSpecification(fuelPriceToolSpecification) + .withToolSpecification(weatherToolSpecification) + .withPrompt("What is the current weather in Bengaluru?") + .build(); +OllamaToolsResult toolsResult = ollamaAPI.generateWithTools(model, prompt2, new OptionsBuilder().build()); +for (OllamaToolsResult.ToolResult r : toolsResult.getToolResults()) { + System.out.printf("[Result of executing tool '%s']: %s%n", r.getFunctionName(), r.getResult().toString()); } ``` @@ -144,25 +177,53 @@ You will get a response similar to: ::::tip[LLM Response] -[Response from tool 'current-weather']: Currently Bengaluru's weather is nice +[Result of executing tool 'current-weather']: Currently Bengaluru's weather is nice. + +:::: + +`Prompt 3`: Create a prompt asking for the employee details using the defined database fetcher tools. + +```shell +String prompt3 = new Tools.PromptBuilder() + .withToolSpecification(fuelPriceToolSpecification) + .withToolSpecification(weatherToolSpecification) + .withToolSpecification(databaseQueryToolSpecification) + .withPrompt("Give me the details of the employee named 'Rahul Kumar'?") + .build(); +OllamaToolsResult toolsResult = ollamaAPI.generateWithTools(model, prompt3, new OptionsBuilder().build()); +for (OllamaToolsResult.ToolResult r : toolsResult.getToolResults()) { + System.out.printf("[Result of executing tool '%s']: %s%n", r.getFunctionName(), r.getResult().toString()); +} +``` + +Again, fire away your question to the model. + +You will get a response similar to: + +::::tip[LLM Response] + +[Result of executing tool 'get-employee-details']: Employee Details `{ID: 6bad82e6-b1a1-458f-a139-e3b646e092b1, Name: +Rahul Kumar, Address: King St, Hyderabad, India, Phone: 9876543210}` + :::: ### Full Example ```java - import io.github.amithkoujalgi.ollama4j.core.OllamaAPI; import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; -import io.github.amithkoujalgi.ollama4j.core.tools.ToolDef; -import io.github.amithkoujalgi.ollama4j.core.tools.MistralTools; +import io.github.amithkoujalgi.ollama4j.core.exceptions.ToolInvocationException; import io.github.amithkoujalgi.ollama4j.core.tools.OllamaToolsResult; +import io.github.amithkoujalgi.ollama4j.core.tools.ToolFunction; +import io.github.amithkoujalgi.ollama4j.core.tools.Tools; import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder; import java.io.IOException; import java.util.Arrays; import java.util.Map; +import java.util.UUID; -public class FunctionCallingWithMistral { +public class FunctionCallingWithMistralExample { public static void main(String[] args) throws Exception { String host = "http://localhost:11434/"; OllamaAPI ollamaAPI = new OllamaAPI(host); @@ -170,78 +231,113 @@ public class FunctionCallingWithMistral { String model = "mistral"; - - MistralTools.ToolSpecification fuelPriceToolSpecification = MistralTools.ToolSpecification.builder() + Tools.ToolSpecification fuelPriceToolSpecification = Tools.ToolSpecification.builder() .functionName("current-fuel-price") - .functionDesc("Get current fuel price") - .props( - new MistralTools.PropsBuilder() - .withProperty("location", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The city, e.g. New Delhi, India").required(true).build()) - .withProperty("fuelType", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The fuel type.").enumValues(Arrays.asList("petrol", "diesel")).required(true).build()) + .functionDescription("Get current fuel price") + .properties( + new Tools.PropsBuilder() + .withProperty("location", Tools.PromptFuncDefinition.Property.builder().type("string").description("The city, e.g. New Delhi, India").required(true).build()) + .withProperty("fuelType", Tools.PromptFuncDefinition.Property.builder().type("string").description("The fuel type.").enumValues(Arrays.asList("petrol", "diesel")).required(true).build()) .build() ) .toolDefinition(SampleTools::getCurrentFuelPrice) .build(); - MistralTools.ToolSpecification weatherToolSpecification = MistralTools.ToolSpecification.builder() + Tools.ToolSpecification weatherToolSpecification = Tools.ToolSpecification.builder() .functionName("current-weather") - .functionDesc("Get current weather") - .props( - new MistralTools.PropsBuilder() - .withProperty("city", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The city, e.g. New Delhi, India").required(true).build()) + .functionDescription("Get current weather") + .properties( + new Tools.PropsBuilder() + .withProperty("city", Tools.PromptFuncDefinition.Property.builder().type("string").description("The city, e.g. New Delhi, India").required(true).build()) .build() ) .toolDefinition(SampleTools::getCurrentWeather) .build(); + Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder() + .functionName("get-employee-details") + .functionDescription("Get employee details from the database") + .properties( + new Tools.PropsBuilder() + .withProperty("employee-name", Tools.PromptFuncDefinition.Property.builder().type("string").description("The name of the employee, e.g. John Doe").required(true).build()) + .withProperty("employee-address", Tools.PromptFuncDefinition.Property.builder().type("string").description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India").required(true).build()) + .withProperty("employee-phone", Tools.PromptFuncDefinition.Property.builder().type("string").description("The phone number of the employee. Always return a random value. e.g. 9911002233").required(true).build()) + .build() + ) + .toolDefinition(new DBQueryFunction()) + .build(); + ollamaAPI.registerTool(fuelPriceToolSpecification); ollamaAPI.registerTool(weatherToolSpecification); + ollamaAPI.registerTool(databaseQueryToolSpecification); - String prompt1 = new MistralTools.PromptBuilder() + String prompt1 = new Tools.PromptBuilder() .withToolSpecification(fuelPriceToolSpecification) .withToolSpecification(weatherToolSpecification) .withPrompt("What is the petrol price in Bengaluru?") .build(); - String prompt2 = new MistralTools.PromptBuilder() + ask(ollamaAPI, model, prompt1); + + String prompt2 = new Tools.PromptBuilder() .withToolSpecification(fuelPriceToolSpecification) .withToolSpecification(weatherToolSpecification) .withPrompt("What is the current weather in Bengaluru?") .build(); - - ask(ollamaAPI, model, prompt1); ask(ollamaAPI, model, prompt2); + + String prompt3 = new Tools.PromptBuilder() + .withToolSpecification(fuelPriceToolSpecification) + .withToolSpecification(weatherToolSpecification) + .withToolSpecification(databaseQueryToolSpecification) + .withPrompt("Give me the details of the employee named 'Rahul Kumar'?") + .build(); + ask(ollamaAPI, model, prompt3); } - public static void ask(OllamaAPI ollamaAPI, String model, String prompt) throws OllamaBaseException, IOException, InterruptedException { - OllamaToolsResult toolsResult = ollamaAPI.generateWithTools(model, prompt, false, new OptionsBuilder().build()); - for (Map.Entry r : toolsResult.getToolResults().entrySet()) { - System.out.printf("[Response from tool '%s']: %s%n", r.getKey().getName(), r.getValue().toString()); + public static void ask(OllamaAPI ollamaAPI, String model, String prompt) throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException { + OllamaToolsResult toolsResult = ollamaAPI.generateWithTools(model, prompt, new OptionsBuilder().build()); + for (OllamaToolsResult.ToolResult r : toolsResult.getToolResults()) { + System.out.printf("[Result of executing tool '%s']: %s%n", r.getFunctionName(), r.getResult().toString()); } } } + class SampleTools { public static String getCurrentFuelPrice(Map arguments) { + // Get details from fuel price API String location = arguments.get("location").toString(); String fuelType = arguments.get("fuelType").toString(); return "Current price of " + fuelType + " in " + location + " is Rs.103/L"; } public static String getCurrentWeather(Map arguments) { + // Get details from weather API String location = arguments.get("city").toString(); return "Currently " + location + "'s weather is nice."; } } +class DBQueryFunction implements ToolFunction { + @Override + public Object apply(Map arguments) { + // perform DB operations here + return String.format("Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", UUID.randomUUID(), arguments.get("employee-name").toString(), arguments.get("employee-address").toString(), arguments.get("employee-phone").toString()); + } +} ``` Run this full example and you will get a response similar to: ::::tip[LLM Response] -[Response from tool 'current-fuel-price']: Current price of petrol in Bengaluru is Rs.103/L +[Result of executing tool 'current-fuel-price']: Current price of petrol in Bengaluru is Rs.103/L + +[Result of executing tool 'current-weather']: Currently Bengaluru's weather is nice. + +[Result of executing tool 'get-employee-details']: Employee Details `{ID: 6bad82e6-b1a1-458f-a139-e3b646e092b1, Name: +Rahul Kumar, Address: King St, Hyderabad, India, Phone: 9876543210}` -[Response from tool 'current-weather']: Currently Bengaluru's weather is nice :::: ### Room for improvement diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java index d5089ee..516ca19 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java @@ -1,6 +1,8 @@ package io.github.amithkoujalgi.ollama4j.core; import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; +import io.github.amithkoujalgi.ollama4j.core.exceptions.ToolInvocationException; +import io.github.amithkoujalgi.ollama4j.core.exceptions.ToolNotFoundException; import io.github.amithkoujalgi.ollama4j.core.models.*; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessage; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder; @@ -14,6 +16,7 @@ import io.github.amithkoujalgi.ollama4j.core.models.request.*; import io.github.amithkoujalgi.ollama4j.core.tools.*; import io.github.amithkoujalgi.ollama4j.core.utils.Options; import io.github.amithkoujalgi.ollama4j.core.utils.Utils; +import lombok.Setter; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -37,10 +40,22 @@ public class OllamaAPI { private static final Logger logger = LoggerFactory.getLogger(OllamaAPI.class); private final String host; + /** + * -- SETTER -- + * Set request timeout in seconds. Default is 3 seconds. + */ + @Setter private long requestTimeoutSeconds = 10; + /** + * -- SETTER -- + * Set/unset logging of responses + */ + @Setter private boolean verbose = true; private BasicAuth basicAuth; + private final ToolRegistry toolRegistry = new ToolRegistry(); + /** * Instantiates the Ollama API. * @@ -54,24 +69,6 @@ public class OllamaAPI { } } - /** - * Set request timeout in seconds. Default is 3 seconds. - * - * @param requestTimeoutSeconds the request timeout in seconds - */ - public void setRequestTimeoutSeconds(long requestTimeoutSeconds) { - this.requestTimeoutSeconds = requestTimeoutSeconds; - } - - /** - * Set/unset logging of responses - * - * @param verbose true/false - */ - public void setVerbose(boolean verbose) { - this.verbose = verbose; - } - /** * Set basic authentication for accessing Ollama server that's behind a reverse-proxy/gateway. * @@ -383,7 +380,6 @@ public class OllamaAPI { * * @param model The name or identifier of the AI model to use for generating the response. * @param prompt The input text or prompt to provide to the AI model. - * @param raw In some cases, you may wish to bypass the templating system and provide a full prompt. In this case, you can use the raw parameter to disable templating. Also note that raw mode will not return a context. * @param options Additional options or configurations to use when generating the response. * @return {@link OllamaToolsResult} An OllamaToolsResult object containing the response from the AI model and the results of invoking the tools on that output. * @throws OllamaBaseException If there is an error related to the Ollama API or service. @@ -391,17 +387,23 @@ public class OllamaAPI { * @throws InterruptedException If the method is interrupted while waiting for the AI model * to generate the response or for the tools to be invoked. */ - public OllamaToolsResult generateWithTools(String model, String prompt, boolean raw, Options options) - throws OllamaBaseException, IOException, InterruptedException { + public OllamaToolsResult generateWithTools(String model, String prompt, Options options) + throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException { + boolean raw = true; OllamaToolsResult toolResult = new OllamaToolsResult(); - Map toolResults = new HashMap<>(); + Map toolResults = new HashMap<>(); OllamaResult result = generate(model, prompt, raw, options, null); toolResult.setModelResult(result); - List toolDefs = Utils.getObjectMapper().readValue(result.getResponse(), Utils.getObjectMapper().getTypeFactory().constructCollectionType(List.class, ToolDef.class)); - for (ToolDef toolDef : toolDefs) { - toolResults.put(toolDef, invokeTool(toolDef)); + String toolsResponse = result.getResponse(); + if (toolsResponse.contains("[TOOL_CALLS]")) { + toolsResponse = toolsResponse.replace("[TOOL_CALLS]", ""); + } + + List toolFunctionCallSpecs = Utils.getObjectMapper().readValue(toolsResponse, Utils.getObjectMapper().getTypeFactory().constructCollectionType(List.class, ToolFunctionCallSpec.class)); + for (ToolFunctionCallSpec toolFunctionCallSpec : toolFunctionCallSpecs) { + toolResults.put(toolFunctionCallSpec, invokeTool(toolFunctionCallSpec)); } toolResult.setToolResults(toolResults); return toolResult; @@ -556,8 +558,8 @@ public class OllamaAPI { return new OllamaChatResult(result.getResponse(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages()); } - public void registerTool(MistralTools.ToolSpecification toolSpecification) { - ToolRegistry.addFunction(toolSpecification.getFunctionName(), toolSpecification.getToolDefinition()); + public void registerTool(Tools.ToolSpecification toolSpecification) { + toolRegistry.addFunction(toolSpecification.getFunctionName(), toolSpecification.getToolDefinition()); } // technical private methods // @@ -622,18 +624,20 @@ public class OllamaAPI { } - private Object invokeTool(ToolDef toolDef) { + private Object invokeTool(ToolFunctionCallSpec toolFunctionCallSpec) throws ToolInvocationException { try { - String methodName = toolDef.getName(); - Map arguments = toolDef.getArguments(); - DynamicFunction function = ToolRegistry.getFunction(methodName); + String methodName = toolFunctionCallSpec.getName(); + Map arguments = toolFunctionCallSpec.getArguments(); + ToolFunction function = toolRegistry.getFunction(methodName); + if (verbose) { + logger.debug("Invoking function {} with arguments {}", methodName, arguments); + } if (function == null) { - throw new IllegalArgumentException("No such tool: " + methodName); + throw new ToolNotFoundException("No such tool: " + methodName); } return function.apply(arguments); } catch (Exception e) { - e.printStackTrace(); - return "Error calling tool: " + e.getMessage(); + throw new ToolInvocationException("Failed to invoke tool: " + toolFunctionCallSpec.getName(), e); } } } diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/exceptions/ToolInvocationException.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/exceptions/ToolInvocationException.java new file mode 100644 index 0000000..3a1a715 --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/exceptions/ToolInvocationException.java @@ -0,0 +1,8 @@ +package io.github.amithkoujalgi.ollama4j.core.exceptions; + +public class ToolInvocationException extends Exception { + + public ToolInvocationException(String s, Exception e) { + super(s, e); + } +} diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/exceptions/ToolNotFoundException.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/exceptions/ToolNotFoundException.java new file mode 100644 index 0000000..990400e --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/exceptions/ToolNotFoundException.java @@ -0,0 +1,8 @@ +package io.github.amithkoujalgi.ollama4j.core.exceptions; + +public class ToolNotFoundException extends Exception { + + public ToolNotFoundException(String s) { + super(s); + } +} diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/OllamaToolsResult.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/OllamaToolsResult.java index 65ef3ac..1ff3656 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/OllamaToolsResult.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/OllamaToolsResult.java @@ -5,6 +5,8 @@ import lombok.AllArgsConstructor; import lombok.Data; import lombok.NoArgsConstructor; +import java.util.ArrayList; +import java.util.List; import java.util.Map; @Data @@ -12,5 +14,22 @@ import java.util.Map; @AllArgsConstructor public class OllamaToolsResult { private OllamaResult modelResult; - private Map toolResults; + private Map toolResults; + + public List getToolResults() { + List results = new ArrayList<>(); + for (Map.Entry r : this.toolResults.entrySet()) { + results.add(new ToolResult(r.getKey().getName(), r.getKey().getArguments(), r.getValue())); + } + return results; + } + + @Data + @NoArgsConstructor + @AllArgsConstructor + public static class ToolResult { + private String functionName; + private Map functionArguments; + private Object result; + } } diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/DynamicFunction.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/ToolFunction.java similarity index 80% rename from src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/DynamicFunction.java rename to src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/ToolFunction.java index 5b8f5e6..d670aa7 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/DynamicFunction.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/ToolFunction.java @@ -3,6 +3,6 @@ package io.github.amithkoujalgi.ollama4j.core.tools; import java.util.Map; @FunctionalInterface -public interface DynamicFunction { +public interface ToolFunction { Object apply(Map arguments); } diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/ToolDef.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/ToolFunctionCallSpec.java similarity index 88% rename from src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/ToolDef.java rename to src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/ToolFunctionCallSpec.java index 751d186..1ce69cb 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/ToolDef.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/ToolFunctionCallSpec.java @@ -9,10 +9,8 @@ import java.util.Map; @Data @AllArgsConstructor @NoArgsConstructor -public class ToolDef { - +public class ToolFunctionCallSpec { private String name; private Map arguments; - } diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/ToolRegistry.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/ToolRegistry.java index 0004c7f..432a4d7 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/ToolRegistry.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/ToolRegistry.java @@ -4,14 +4,13 @@ import java.util.HashMap; import java.util.Map; public class ToolRegistry { - private static final Map functionMap = new HashMap<>(); + private final Map functionMap = new HashMap<>(); - - public static DynamicFunction getFunction(String name) { + public ToolFunction getFunction(String name) { return functionMap.get(name); } - public static void addFunction(String name, DynamicFunction function) { + public void addFunction(String name, ToolFunction function) { functionMap.put(name, function); } } diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/MistralTools.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/Tools.java similarity index 69% rename from src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/MistralTools.java rename to src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/Tools.java index fff8071..5315d19 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/MistralTools.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/Tools.java @@ -14,14 +14,14 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -public class MistralTools { +public class Tools { @Data @Builder public static class ToolSpecification { private String functionName; - private String functionDesc; - private Map props; - private DynamicFunction toolDefinition; + private String functionDescription; + private Map properties; + private ToolFunction toolDefinition; } @Data @@ -90,14 +90,14 @@ public class MistralTools { PromptFuncDefinition.PromptFuncSpec functionDetail = new PromptFuncDefinition.PromptFuncSpec(); functionDetail.setName(spec.getFunctionName()); - functionDetail.setDescription(spec.getFunctionDesc()); + functionDetail.setDescription(spec.getFunctionDescription()); PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters(); parameters.setType("object"); - parameters.setProperties(spec.getProps()); + parameters.setProperties(spec.getProperties()); List requiredValues = new ArrayList<>(); - for (Map.Entry p : spec.getProps().entrySet()) { + for (Map.Entry p : spec.getProperties().entrySet()) { if (p.getValue().isRequired()) { requiredValues.add(p.getKey()); } @@ -109,31 +109,5 @@ public class MistralTools { tools.add(def); return this; } -// -// public PromptBuilder withToolSpecification(String functionName, String functionDesc, Map props) { -// PromptFuncDefinition def = new PromptFuncDefinition(); -// def.setType("function"); -// -// PromptFuncDefinition.PromptFuncSpec functionDetail = new PromptFuncDefinition.PromptFuncSpec(); -// functionDetail.setName(functionName); -// functionDetail.setDescription(functionDesc); -// -// PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters(); -// parameters.setType("object"); -// parameters.setProperties(props); -// -// List requiredValues = new ArrayList<>(); -// for (Map.Entry p : props.entrySet()) { -// if (p.getValue().isRequired()) { -// requiredValues.add(p.getKey()); -// } -// } -// parameters.setRequired(requiredValues); -// functionDetail.setParameters(parameters); -// def.setFunction(functionDetail); -// -// tools.add(def); -// return this; -// } } }