Refactored tools API

Signed-off-by: koujalgi.amith@gmail.com <koujalgi.amith@gmail.com>
This commit is contained in:
koujalgi.amith@gmail.com 2024-07-14 11:23:36 +05:30
parent fd93036d08
commit 81689be194
9 changed files with 241 additions and 135 deletions

View File

@ -29,8 +29,8 @@ You could do that with ease with the `function calling` capabilities of the mode
### Create Functions ### Create Functions
This function takes the arguments `location` and `fuelType` and performs an operation with these arguments and returns a This function takes the arguments `location` and `fuelType` and performs an operation with these arguments and returns
value. fuel price value.
```java ```java
public static String getCurrentFuelPrice(Map<String, Object> arguments) { public static String getCurrentFuelPrice(Map<String, Object> arguments) {
@ -40,8 +40,8 @@ public static String getCurrentFuelPrice(Map<String, Object> arguments) {
} }
``` ```
This function takes the argument `city` and performs an operation with the argument and returns a This function takes the argument `city` and performs an operation with the argument and returns the weather for a
value. location.
```java ```java
public static String getCurrentWeather(Map<String, Object> arguments) { public static String getCurrentWeather(Map<String, Object> arguments) {
@ -50,6 +50,19 @@ public static String getCurrentWeather(Map<String, Object> arguments) {
} }
``` ```
This function takes the argument `employee-name` and performs an operation with the argument and returns employee
details.
```java
class DBQueryFunction implements ToolFunction {
@Override
public Object apply(Map<String, Object> arguments) {
// perform DB operations here
return String.format("Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", UUID.randomUUID(), arguments.get("employee-name").toString(), arguments.get("employee-address").toString(), arguments.get("employee-phone").toString());
}
}
```
### Define Tool Specifications ### Define Tool Specifications
Lets define a sample tool specification called **Fuel Price Tool** for getting the current fuel price. Lets define a sample tool specification called **Fuel Price Tool** for getting the current fuel price.
@ -58,13 +71,13 @@ Lets define a sample tool specification called **Fuel Price Tool** for getting t
- Associate the `getCurrentFuelPrice` function you defined earlier with `SampleTools::getCurrentFuelPrice`. - Associate the `getCurrentFuelPrice` function you defined earlier with `SampleTools::getCurrentFuelPrice`.
```java ```java
MistralTools.ToolSpecification fuelPriceToolSpecification = MistralTools.ToolSpecification.builder() Tools.ToolSpecification fuelPriceToolSpecification = Tools.ToolSpecification.builder()
.functionName("current-fuel-price") .functionName("current-fuel-price")
.functionDesc("Get current fuel price") .functionDescription("Get current fuel price")
.props( .properties(
new MistralTools.PropsBuilder() new Tools.PropsBuilder()
.withProperty("location", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The city, e.g. New Delhi, India").required(true).build()) .withProperty("location", Tools.PromptFuncDefinition.Property.builder().type("string").description("The city, e.g. New Delhi, India").required(true).build())
.withProperty("fuelType", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The fuel type.").enumValues(Arrays.asList("petrol", "diesel")).required(true).build()) .withProperty("fuelType", Tools.PromptFuncDefinition.Property.builder().type("string").description("The fuel type.").enumValues(Arrays.asList("petrol", "diesel")).required(true).build())
.build() .build()
) )
.toolDefinition(SampleTools::getCurrentFuelPrice) .toolDefinition(SampleTools::getCurrentFuelPrice)
@ -77,18 +90,38 @@ Lets also define a sample tool specification called **Weather Tool** for getting
- Associate the `getCurrentWeather` function you defined earlier with `SampleTools::getCurrentWeather`. - Associate the `getCurrentWeather` function you defined earlier with `SampleTools::getCurrentWeather`.
```java ```java
MistralTools.ToolSpecification weatherToolSpecification = MistralTools.ToolSpecification.builder() Tools.ToolSpecification weatherToolSpecification = Tools.ToolSpecification.builder()
.functionName("current-weather") .functionName("current-weather")
.functionDesc("Get current weather") .functionDescription("Get current weather")
.props( .properties(
new MistralTools.PropsBuilder() new Tools.PropsBuilder()
.withProperty("city", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The city, e.g. New Delhi, India").required(true).build()) .withProperty("city", Tools.PromptFuncDefinition.Property.builder().type("string").description("The city, e.g. New Delhi, India").required(true).build())
.build() .build()
) )
.toolDefinition(SampleTools::getCurrentWeather) .toolDefinition(SampleTools::getCurrentWeather)
.build(); .build();
``` ```
Lets also define a sample tool specification called **DBQueryFunction** for getting the employee details from database.
- Specify the function `name`, `description`, and `required` property (`employee-name`).
- Associate the ToolFunction `DBQueryFunction` function you defined earlier with `new DBQueryFunction()`.
```java
Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder()
.functionName("get-employee-details")
.functionDescription("Get employee details from the database")
.properties(
new Tools.PropsBuilder()
.withProperty("employee-name", Tools.PromptFuncDefinition.Property.builder().type("string").description("The name of the employee, e.g. John Doe").required(true).build())
.withProperty("employee-address", Tools.PromptFuncDefinition.Property.builder().type("string").description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India").required(true).build())
.withProperty("employee-phone", Tools.PromptFuncDefinition.Property.builder().type("string").description("The phone number of the employee. Always return a random value. e.g. 9911002233").required(true).build())
.build()
)
.toolDefinition(new DBQueryFunction())
.build();
```
### Register the Tools ### Register the Tools
Register the defined tools (`fuel price` and `weather`) with the OllamaAPI. Register the defined tools (`fuel price` and `weather`) with the OllamaAPI.
@ -103,14 +136,14 @@ ollamaAPI.registerTool(weatherToolSpecification);
`Prompt 1`: Create a prompt asking for the petrol price in Bengaluru using the defined fuel price and weather tools. `Prompt 1`: Create a prompt asking for the petrol price in Bengaluru using the defined fuel price and weather tools.
```shell ```shell
String prompt1 = new MistralTools.PromptBuilder() String prompt1 = new Tools.PromptBuilder()
.withToolSpecification(fuelPriceToolSpecification) .withToolSpecification(fuelPriceToolSpecification)
.withToolSpecification(weatherToolSpecification) .withToolSpecification(weatherToolSpecification)
.withPrompt("What is the petrol price in Bengaluru?") .withPrompt("What is the petrol price in Bengaluru?")
.build(); .build();
OllamaToolsResult toolsResult = ollamaAPI.generateWithTools(model, prompt1, false, new OptionsBuilder().build()); OllamaToolsResult toolsResult = ollamaAPI.generateWithTools(model, prompt1, new OptionsBuilder().build());
for (Map.Entry<ToolDef, Object> r : toolsResult.getToolResults().entrySet()) { for (OllamaToolsResult.ToolResult r : toolsResult.getToolResults()) {
System.out.printf("[Response from tool '%s']: %s%n", r.getKey().getName(), r.getValue().toString()); System.out.printf("[Result of executing tool '%s']: %s%n", r.getFunctionName(), r.getResult().toString());
} }
``` ```
@ -120,21 +153,21 @@ You will get a response similar to:
::::tip[LLM Response] ::::tip[LLM Response]
[Response from tool 'current-fuel-price']: Current price of petrol in Bengaluru is Rs.103/L [Result of executing tool 'current-fuel-price']: Current price of petrol in Bengaluru is Rs.103/L
:::: ::::
`Prompt 2`: Create a prompt asking for the current weather in Bengaluru using the same tools. `Prompt 2`: Create a prompt asking for the current weather in Bengaluru using the same tools.
```shell ```shell
String prompt2 = new MistralTools.PromptBuilder() String prompt2 = new Tools.PromptBuilder()
.withToolSpecification(fuelPriceToolSpecification) .withToolSpecification(fuelPriceToolSpecification)
.withToolSpecification(weatherToolSpecification) .withToolSpecification(weatherToolSpecification)
.withPrompt("What is the current weather in Bengaluru?") .withPrompt("What is the current weather in Bengaluru?")
.build(); .build();
OllamaToolsResult toolsResult = ollamaAPI.generateWithTools(model, prompt2, false, new OptionsBuilder().build()); OllamaToolsResult toolsResult = ollamaAPI.generateWithTools(model, prompt2, new OptionsBuilder().build());
for (Map.Entry<ToolDef, Object> r : toolsResult.getToolResults().entrySet()) { for (OllamaToolsResult.ToolResult r : toolsResult.getToolResults()) {
System.out.printf("[Response from tool '%s']: %s%n", r.getKey().getName(), r.getValue().toString()); System.out.printf("[Result of executing tool '%s']: %s%n", r.getFunctionName(), r.getResult().toString());
} }
``` ```
@ -144,25 +177,53 @@ You will get a response similar to:
::::tip[LLM Response] ::::tip[LLM Response]
[Response from tool 'current-weather']: Currently Bengaluru's weather is nice [Result of executing tool 'current-weather']: Currently Bengaluru's weather is nice.
::::
`Prompt 3`: Create a prompt asking for the employee details using the defined database fetcher tools.
```shell
String prompt3 = new Tools.PromptBuilder()
.withToolSpecification(fuelPriceToolSpecification)
.withToolSpecification(weatherToolSpecification)
.withToolSpecification(databaseQueryToolSpecification)
.withPrompt("Give me the details of the employee named 'Rahul Kumar'?")
.build();
OllamaToolsResult toolsResult = ollamaAPI.generateWithTools(model, prompt3, new OptionsBuilder().build());
for (OllamaToolsResult.ToolResult r : toolsResult.getToolResults()) {
System.out.printf("[Result of executing tool '%s']: %s%n", r.getFunctionName(), r.getResult().toString());
}
```
Again, fire away your question to the model.
You will get a response similar to:
::::tip[LLM Response]
[Result of executing tool 'get-employee-details']: Employee Details `{ID: 6bad82e6-b1a1-458f-a139-e3b646e092b1, Name:
Rahul Kumar, Address: King St, Hyderabad, India, Phone: 9876543210}`
:::: ::::
### Full Example ### Full Example
```java ```java
import io.github.amithkoujalgi.ollama4j.core.OllamaAPI; import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.tools.ToolDef; import io.github.amithkoujalgi.ollama4j.core.exceptions.ToolInvocationException;
import io.github.amithkoujalgi.ollama4j.core.tools.MistralTools;
import io.github.amithkoujalgi.ollama4j.core.tools.OllamaToolsResult; import io.github.amithkoujalgi.ollama4j.core.tools.OllamaToolsResult;
import io.github.amithkoujalgi.ollama4j.core.tools.ToolFunction;
import io.github.amithkoujalgi.ollama4j.core.tools.Tools;
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder; import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
import java.util.Map; import java.util.Map;
import java.util.UUID;
public class FunctionCallingWithMistral { public class FunctionCallingWithMistralExample {
public static void main(String[] args) throws Exception { public static void main(String[] args) throws Exception {
String host = "http://localhost:11434/"; String host = "http://localhost:11434/";
OllamaAPI ollamaAPI = new OllamaAPI(host); OllamaAPI ollamaAPI = new OllamaAPI(host);
@ -170,78 +231,113 @@ public class FunctionCallingWithMistral {
String model = "mistral"; String model = "mistral";
Tools.ToolSpecification fuelPriceToolSpecification = Tools.ToolSpecification.builder()
MistralTools.ToolSpecification fuelPriceToolSpecification = MistralTools.ToolSpecification.builder()
.functionName("current-fuel-price") .functionName("current-fuel-price")
.functionDesc("Get current fuel price") .functionDescription("Get current fuel price")
.props( .properties(
new MistralTools.PropsBuilder() new Tools.PropsBuilder()
.withProperty("location", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The city, e.g. New Delhi, India").required(true).build()) .withProperty("location", Tools.PromptFuncDefinition.Property.builder().type("string").description("The city, e.g. New Delhi, India").required(true).build())
.withProperty("fuelType", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The fuel type.").enumValues(Arrays.asList("petrol", "diesel")).required(true).build()) .withProperty("fuelType", Tools.PromptFuncDefinition.Property.builder().type("string").description("The fuel type.").enumValues(Arrays.asList("petrol", "diesel")).required(true).build())
.build() .build()
) )
.toolDefinition(SampleTools::getCurrentFuelPrice) .toolDefinition(SampleTools::getCurrentFuelPrice)
.build(); .build();
MistralTools.ToolSpecification weatherToolSpecification = MistralTools.ToolSpecification.builder() Tools.ToolSpecification weatherToolSpecification = Tools.ToolSpecification.builder()
.functionName("current-weather") .functionName("current-weather")
.functionDesc("Get current weather") .functionDescription("Get current weather")
.props( .properties(
new MistralTools.PropsBuilder() new Tools.PropsBuilder()
.withProperty("city", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The city, e.g. New Delhi, India").required(true).build()) .withProperty("city", Tools.PromptFuncDefinition.Property.builder().type("string").description("The city, e.g. New Delhi, India").required(true).build())
.build() .build()
) )
.toolDefinition(SampleTools::getCurrentWeather) .toolDefinition(SampleTools::getCurrentWeather)
.build(); .build();
Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder()
.functionName("get-employee-details")
.functionDescription("Get employee details from the database")
.properties(
new Tools.PropsBuilder()
.withProperty("employee-name", Tools.PromptFuncDefinition.Property.builder().type("string").description("The name of the employee, e.g. John Doe").required(true).build())
.withProperty("employee-address", Tools.PromptFuncDefinition.Property.builder().type("string").description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India").required(true).build())
.withProperty("employee-phone", Tools.PromptFuncDefinition.Property.builder().type("string").description("The phone number of the employee. Always return a random value. e.g. 9911002233").required(true).build())
.build()
)
.toolDefinition(new DBQueryFunction())
.build();
ollamaAPI.registerTool(fuelPriceToolSpecification); ollamaAPI.registerTool(fuelPriceToolSpecification);
ollamaAPI.registerTool(weatherToolSpecification); ollamaAPI.registerTool(weatherToolSpecification);
ollamaAPI.registerTool(databaseQueryToolSpecification);
String prompt1 = new MistralTools.PromptBuilder() String prompt1 = new Tools.PromptBuilder()
.withToolSpecification(fuelPriceToolSpecification) .withToolSpecification(fuelPriceToolSpecification)
.withToolSpecification(weatherToolSpecification) .withToolSpecification(weatherToolSpecification)
.withPrompt("What is the petrol price in Bengaluru?") .withPrompt("What is the petrol price in Bengaluru?")
.build(); .build();
String prompt2 = new MistralTools.PromptBuilder() ask(ollamaAPI, model, prompt1);
String prompt2 = new Tools.PromptBuilder()
.withToolSpecification(fuelPriceToolSpecification) .withToolSpecification(fuelPriceToolSpecification)
.withToolSpecification(weatherToolSpecification) .withToolSpecification(weatherToolSpecification)
.withPrompt("What is the current weather in Bengaluru?") .withPrompt("What is the current weather in Bengaluru?")
.build(); .build();
ask(ollamaAPI, model, prompt1);
ask(ollamaAPI, model, prompt2); ask(ollamaAPI, model, prompt2);
String prompt3 = new Tools.PromptBuilder()
.withToolSpecification(fuelPriceToolSpecification)
.withToolSpecification(weatherToolSpecification)
.withToolSpecification(databaseQueryToolSpecification)
.withPrompt("Give me the details of the employee named 'Rahul Kumar'?")
.build();
ask(ollamaAPI, model, prompt3);
} }
public static void ask(OllamaAPI ollamaAPI, String model, String prompt) throws OllamaBaseException, IOException, InterruptedException { public static void ask(OllamaAPI ollamaAPI, String model, String prompt) throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
OllamaToolsResult toolsResult = ollamaAPI.generateWithTools(model, prompt, false, new OptionsBuilder().build()); OllamaToolsResult toolsResult = ollamaAPI.generateWithTools(model, prompt, new OptionsBuilder().build());
for (Map.Entry<ToolDef, Object> r : toolsResult.getToolResults().entrySet()) { for (OllamaToolsResult.ToolResult r : toolsResult.getToolResults()) {
System.out.printf("[Response from tool '%s']: %s%n", r.getKey().getName(), r.getValue().toString()); System.out.printf("[Result of executing tool '%s']: %s%n", r.getFunctionName(), r.getResult().toString());
} }
} }
} }
class SampleTools { class SampleTools {
public static String getCurrentFuelPrice(Map<String, Object> arguments) { public static String getCurrentFuelPrice(Map<String, Object> arguments) {
// Get details from fuel price API
String location = arguments.get("location").toString(); String location = arguments.get("location").toString();
String fuelType = arguments.get("fuelType").toString(); String fuelType = arguments.get("fuelType").toString();
return "Current price of " + fuelType + " in " + location + " is Rs.103/L"; return "Current price of " + fuelType + " in " + location + " is Rs.103/L";
} }
public static String getCurrentWeather(Map<String, Object> arguments) { public static String getCurrentWeather(Map<String, Object> arguments) {
// Get details from weather API
String location = arguments.get("city").toString(); String location = arguments.get("city").toString();
return "Currently " + location + "'s weather is nice."; return "Currently " + location + "'s weather is nice.";
} }
} }
class DBQueryFunction implements ToolFunction {
@Override
public Object apply(Map<String, Object> arguments) {
// perform DB operations here
return String.format("Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", UUID.randomUUID(), arguments.get("employee-name").toString(), arguments.get("employee-address").toString(), arguments.get("employee-phone").toString());
}
}
``` ```
Run this full example and you will get a response similar to: Run this full example and you will get a response similar to:
::::tip[LLM Response] ::::tip[LLM Response]
[Response from tool 'current-fuel-price']: Current price of petrol in Bengaluru is Rs.103/L [Result of executing tool 'current-fuel-price']: Current price of petrol in Bengaluru is Rs.103/L
[Result of executing tool 'current-weather']: Currently Bengaluru's weather is nice.
[Result of executing tool 'get-employee-details']: Employee Details `{ID: 6bad82e6-b1a1-458f-a139-e3b646e092b1, Name:
Rahul Kumar, Address: King St, Hyderabad, India, Phone: 9876543210}`
[Response from tool 'current-weather']: Currently Bengaluru's weather is nice
:::: ::::
### Room for improvement ### Room for improvement

View File

@ -1,6 +1,8 @@
package io.github.amithkoujalgi.ollama4j.core; package io.github.amithkoujalgi.ollama4j.core;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.exceptions.ToolInvocationException;
import io.github.amithkoujalgi.ollama4j.core.exceptions.ToolNotFoundException;
import io.github.amithkoujalgi.ollama4j.core.models.*; import io.github.amithkoujalgi.ollama4j.core.models.*;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessage; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessage;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder;
@ -14,6 +16,7 @@ import io.github.amithkoujalgi.ollama4j.core.models.request.*;
import io.github.amithkoujalgi.ollama4j.core.tools.*; import io.github.amithkoujalgi.ollama4j.core.tools.*;
import io.github.amithkoujalgi.ollama4j.core.utils.Options; import io.github.amithkoujalgi.ollama4j.core.utils.Options;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils; import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
import lombok.Setter;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@ -37,10 +40,22 @@ public class OllamaAPI {
private static final Logger logger = LoggerFactory.getLogger(OllamaAPI.class); private static final Logger logger = LoggerFactory.getLogger(OllamaAPI.class);
private final String host; private final String host;
/**
* -- SETTER --
* Set request timeout in seconds. Default is 3 seconds.
*/
@Setter
private long requestTimeoutSeconds = 10; private long requestTimeoutSeconds = 10;
/**
* -- SETTER --
* Set/unset logging of responses
*/
@Setter
private boolean verbose = true; private boolean verbose = true;
private BasicAuth basicAuth; private BasicAuth basicAuth;
private final ToolRegistry toolRegistry = new ToolRegistry();
/** /**
* Instantiates the Ollama API. * Instantiates the Ollama API.
* *
@ -54,24 +69,6 @@ public class OllamaAPI {
} }
} }
/**
* Set request timeout in seconds. Default is 3 seconds.
*
* @param requestTimeoutSeconds the request timeout in seconds
*/
public void setRequestTimeoutSeconds(long requestTimeoutSeconds) {
this.requestTimeoutSeconds = requestTimeoutSeconds;
}
/**
* Set/unset logging of responses
*
* @param verbose true/false
*/
public void setVerbose(boolean verbose) {
this.verbose = verbose;
}
/** /**
* Set basic authentication for accessing Ollama server that's behind a reverse-proxy/gateway. * Set basic authentication for accessing Ollama server that's behind a reverse-proxy/gateway.
* *
@ -383,7 +380,6 @@ public class OllamaAPI {
* *
* @param model The name or identifier of the AI model to use for generating the response. * @param model The name or identifier of the AI model to use for generating the response.
* @param prompt The input text or prompt to provide to the AI model. * @param prompt The input text or prompt to provide to the AI model.
* @param raw In some cases, you may wish to bypass the templating system and provide a full prompt. In this case, you can use the raw parameter to disable templating. Also note that raw mode will not return a context.
* @param options Additional options or configurations to use when generating the response. * @param options Additional options or configurations to use when generating the response.
* @return {@link OllamaToolsResult} An OllamaToolsResult object containing the response from the AI model and the results of invoking the tools on that output. * @return {@link OllamaToolsResult} An OllamaToolsResult object containing the response from the AI model and the results of invoking the tools on that output.
* @throws OllamaBaseException If there is an error related to the Ollama API or service. * @throws OllamaBaseException If there is an error related to the Ollama API or service.
@ -391,17 +387,23 @@ public class OllamaAPI {
* @throws InterruptedException If the method is interrupted while waiting for the AI model * @throws InterruptedException If the method is interrupted while waiting for the AI model
* to generate the response or for the tools to be invoked. * to generate the response or for the tools to be invoked.
*/ */
public OllamaToolsResult generateWithTools(String model, String prompt, boolean raw, Options options) public OllamaToolsResult generateWithTools(String model, String prompt, Options options)
throws OllamaBaseException, IOException, InterruptedException { throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
boolean raw = true;
OllamaToolsResult toolResult = new OllamaToolsResult(); OllamaToolsResult toolResult = new OllamaToolsResult();
Map<ToolDef, Object> toolResults = new HashMap<>(); Map<ToolFunctionCallSpec, Object> toolResults = new HashMap<>();
OllamaResult result = generate(model, prompt, raw, options, null); OllamaResult result = generate(model, prompt, raw, options, null);
toolResult.setModelResult(result); toolResult.setModelResult(result);
List<ToolDef> toolDefs = Utils.getObjectMapper().readValue(result.getResponse(), Utils.getObjectMapper().getTypeFactory().constructCollectionType(List.class, ToolDef.class)); String toolsResponse = result.getResponse();
for (ToolDef toolDef : toolDefs) { if (toolsResponse.contains("[TOOL_CALLS]")) {
toolResults.put(toolDef, invokeTool(toolDef)); toolsResponse = toolsResponse.replace("[TOOL_CALLS]", "");
}
List<ToolFunctionCallSpec> toolFunctionCallSpecs = Utils.getObjectMapper().readValue(toolsResponse, Utils.getObjectMapper().getTypeFactory().constructCollectionType(List.class, ToolFunctionCallSpec.class));
for (ToolFunctionCallSpec toolFunctionCallSpec : toolFunctionCallSpecs) {
toolResults.put(toolFunctionCallSpec, invokeTool(toolFunctionCallSpec));
} }
toolResult.setToolResults(toolResults); toolResult.setToolResults(toolResults);
return toolResult; return toolResult;
@ -556,8 +558,8 @@ public class OllamaAPI {
return new OllamaChatResult(result.getResponse(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages()); return new OllamaChatResult(result.getResponse(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages());
} }
public void registerTool(MistralTools.ToolSpecification toolSpecification) { public void registerTool(Tools.ToolSpecification toolSpecification) {
ToolRegistry.addFunction(toolSpecification.getFunctionName(), toolSpecification.getToolDefinition()); toolRegistry.addFunction(toolSpecification.getFunctionName(), toolSpecification.getToolDefinition());
} }
// technical private methods // // technical private methods //
@ -622,18 +624,20 @@ public class OllamaAPI {
} }
private Object invokeTool(ToolDef toolDef) { private Object invokeTool(ToolFunctionCallSpec toolFunctionCallSpec) throws ToolInvocationException {
try { try {
String methodName = toolDef.getName(); String methodName = toolFunctionCallSpec.getName();
Map<String, Object> arguments = toolDef.getArguments(); Map<String, Object> arguments = toolFunctionCallSpec.getArguments();
DynamicFunction function = ToolRegistry.getFunction(methodName); ToolFunction function = toolRegistry.getFunction(methodName);
if (verbose) {
logger.debug("Invoking function {} with arguments {}", methodName, arguments);
}
if (function == null) { if (function == null) {
throw new IllegalArgumentException("No such tool: " + methodName); throw new ToolNotFoundException("No such tool: " + methodName);
} }
return function.apply(arguments); return function.apply(arguments);
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); throw new ToolInvocationException("Failed to invoke tool: " + toolFunctionCallSpec.getName(), e);
return "Error calling tool: " + e.getMessage();
} }
} }
} }

View File

@ -0,0 +1,8 @@
package io.github.amithkoujalgi.ollama4j.core.exceptions;
public class ToolInvocationException extends Exception {
public ToolInvocationException(String s, Exception e) {
super(s, e);
}
}

View File

@ -0,0 +1,8 @@
package io.github.amithkoujalgi.ollama4j.core.exceptions;
public class ToolNotFoundException extends Exception {
public ToolNotFoundException(String s) {
super(s);
}
}

View File

@ -5,6 +5,8 @@ import lombok.AllArgsConstructor;
import lombok.Data; import lombok.Data;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import java.util.ArrayList;
import java.util.List;
import java.util.Map; import java.util.Map;
@Data @Data
@ -12,5 +14,22 @@ import java.util.Map;
@AllArgsConstructor @AllArgsConstructor
public class OllamaToolsResult { public class OllamaToolsResult {
private OllamaResult modelResult; private OllamaResult modelResult;
private Map<ToolDef, Object> toolResults; private Map<ToolFunctionCallSpec, Object> toolResults;
public List<ToolResult> getToolResults() {
List<ToolResult> results = new ArrayList<>();
for (Map.Entry<ToolFunctionCallSpec, Object> r : this.toolResults.entrySet()) {
results.add(new ToolResult(r.getKey().getName(), r.getKey().getArguments(), r.getValue()));
}
return results;
}
@Data
@NoArgsConstructor
@AllArgsConstructor
public static class ToolResult {
private String functionName;
private Map<String, Object> functionArguments;
private Object result;
}
} }

View File

@ -3,6 +3,6 @@ package io.github.amithkoujalgi.ollama4j.core.tools;
import java.util.Map; import java.util.Map;
@FunctionalInterface @FunctionalInterface
public interface DynamicFunction { public interface ToolFunction {
Object apply(Map<String, Object> arguments); Object apply(Map<String, Object> arguments);
} }

View File

@ -9,10 +9,8 @@ import java.util.Map;
@Data @Data
@AllArgsConstructor @AllArgsConstructor
@NoArgsConstructor @NoArgsConstructor
public class ToolDef { public class ToolFunctionCallSpec {
private String name; private String name;
private Map<String, Object> arguments; private Map<String, Object> arguments;
} }

View File

@ -4,14 +4,13 @@ import java.util.HashMap;
import java.util.Map; import java.util.Map;
public class ToolRegistry { public class ToolRegistry {
private static final Map<String, DynamicFunction> functionMap = new HashMap<>(); private final Map<String, ToolFunction> functionMap = new HashMap<>();
public ToolFunction getFunction(String name) {
public static DynamicFunction getFunction(String name) {
return functionMap.get(name); return functionMap.get(name);
} }
public static void addFunction(String name, DynamicFunction function) { public void addFunction(String name, ToolFunction function) {
functionMap.put(name, function); functionMap.put(name, function);
} }
} }

View File

@ -14,14 +14,14 @@ import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
public class MistralTools { public class Tools {
@Data @Data
@Builder @Builder
public static class ToolSpecification { public static class ToolSpecification {
private String functionName; private String functionName;
private String functionDesc; private String functionDescription;
private Map<String, PromptFuncDefinition.Property> props; private Map<String, PromptFuncDefinition.Property> properties;
private DynamicFunction toolDefinition; private ToolFunction toolDefinition;
} }
@Data @Data
@ -90,14 +90,14 @@ public class MistralTools {
PromptFuncDefinition.PromptFuncSpec functionDetail = new PromptFuncDefinition.PromptFuncSpec(); PromptFuncDefinition.PromptFuncSpec functionDetail = new PromptFuncDefinition.PromptFuncSpec();
functionDetail.setName(spec.getFunctionName()); functionDetail.setName(spec.getFunctionName());
functionDetail.setDescription(spec.getFunctionDesc()); functionDetail.setDescription(spec.getFunctionDescription());
PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters(); PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters();
parameters.setType("object"); parameters.setType("object");
parameters.setProperties(spec.getProps()); parameters.setProperties(spec.getProperties());
List<String> requiredValues = new ArrayList<>(); List<String> requiredValues = new ArrayList<>();
for (Map.Entry<String, PromptFuncDefinition.Property> p : spec.getProps().entrySet()) { for (Map.Entry<String, PromptFuncDefinition.Property> p : spec.getProperties().entrySet()) {
if (p.getValue().isRequired()) { if (p.getValue().isRequired()) {
requiredValues.add(p.getKey()); requiredValues.add(p.getKey());
} }
@ -109,31 +109,5 @@ public class MistralTools {
tools.add(def); tools.add(def);
return this; return this;
} }
//
// public PromptBuilder withToolSpecification(String functionName, String functionDesc, Map<String, PromptFuncDefinition.Property> props) {
// PromptFuncDefinition def = new PromptFuncDefinition();
// def.setType("function");
//
// PromptFuncDefinition.PromptFuncSpec functionDetail = new PromptFuncDefinition.PromptFuncSpec();
// functionDetail.setName(functionName);
// functionDetail.setDescription(functionDesc);
//
// PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters();
// parameters.setType("object");
// parameters.setProperties(props);
//
// List<String> requiredValues = new ArrayList<>();
// for (Map.Entry<String, PromptFuncDefinition.Property> p : props.entrySet()) {
// if (p.getValue().isRequired()) {
// requiredValues.add(p.getKey());
// }
// }
// parameters.setRequired(requiredValues);
// functionDetail.setParameters(parameters);
// def.setFunction(functionDetail);
//
// tools.add(def);
// return this;
// }
} }
} }