Refactored tools API

Signed-off-by: koujalgi.amith@gmail.com <koujalgi.amith@gmail.com>
This commit is contained in:
koujalgi.amith@gmail.com
2024-07-14 11:23:36 +05:30
parent fd93036d08
commit 81689be194
9 changed files with 241 additions and 135 deletions

View File

@@ -1,6 +1,8 @@
package io.github.amithkoujalgi.ollama4j.core;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.exceptions.ToolInvocationException;
import io.github.amithkoujalgi.ollama4j.core.exceptions.ToolNotFoundException;
import io.github.amithkoujalgi.ollama4j.core.models.*;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessage;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder;
@@ -14,6 +16,7 @@ import io.github.amithkoujalgi.ollama4j.core.models.request.*;
import io.github.amithkoujalgi.ollama4j.core.tools.*;
import io.github.amithkoujalgi.ollama4j.core.utils.Options;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
import lombok.Setter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -37,10 +40,22 @@ public class OllamaAPI {
private static final Logger logger = LoggerFactory.getLogger(OllamaAPI.class);
private final String host;
/**
* -- SETTER --
* Set request timeout in seconds. Default is 3 seconds.
*/
@Setter
private long requestTimeoutSeconds = 10;
/**
* -- SETTER --
* Set/unset logging of responses
*/
@Setter
private boolean verbose = true;
private BasicAuth basicAuth;
private final ToolRegistry toolRegistry = new ToolRegistry();
/**
* Instantiates the Ollama API.
*
@@ -54,24 +69,6 @@ public class OllamaAPI {
}
}
/**
* Set request timeout in seconds. Default is 3 seconds.
*
* @param requestTimeoutSeconds the request timeout in seconds
*/
public void setRequestTimeoutSeconds(long requestTimeoutSeconds) {
this.requestTimeoutSeconds = requestTimeoutSeconds;
}
/**
* Set/unset logging of responses
*
* @param verbose true/false
*/
public void setVerbose(boolean verbose) {
this.verbose = verbose;
}
/**
* Set basic authentication for accessing Ollama server that's behind a reverse-proxy/gateway.
*
@@ -383,7 +380,6 @@ public class OllamaAPI {
*
* @param model The name or identifier of the AI model to use for generating the response.
* @param prompt The input text or prompt to provide to the AI model.
* @param raw In some cases, you may wish to bypass the templating system and provide a full prompt. In this case, you can use the raw parameter to disable templating. Also note that raw mode will not return a context.
* @param options Additional options or configurations to use when generating the response.
* @return {@link OllamaToolsResult} An OllamaToolsResult object containing the response from the AI model and the results of invoking the tools on that output.
* @throws OllamaBaseException If there is an error related to the Ollama API or service.
@@ -391,17 +387,23 @@ public class OllamaAPI {
* @throws InterruptedException If the method is interrupted while waiting for the AI model
* to generate the response or for the tools to be invoked.
*/
public OllamaToolsResult generateWithTools(String model, String prompt, boolean raw, Options options)
throws OllamaBaseException, IOException, InterruptedException {
public OllamaToolsResult generateWithTools(String model, String prompt, Options options)
throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
boolean raw = true;
OllamaToolsResult toolResult = new OllamaToolsResult();
Map<ToolDef, Object> toolResults = new HashMap<>();
Map<ToolFunctionCallSpec, Object> toolResults = new HashMap<>();
OllamaResult result = generate(model, prompt, raw, options, null);
toolResult.setModelResult(result);
List<ToolDef> toolDefs = Utils.getObjectMapper().readValue(result.getResponse(), Utils.getObjectMapper().getTypeFactory().constructCollectionType(List.class, ToolDef.class));
for (ToolDef toolDef : toolDefs) {
toolResults.put(toolDef, invokeTool(toolDef));
String toolsResponse = result.getResponse();
if (toolsResponse.contains("[TOOL_CALLS]")) {
toolsResponse = toolsResponse.replace("[TOOL_CALLS]", "");
}
List<ToolFunctionCallSpec> toolFunctionCallSpecs = Utils.getObjectMapper().readValue(toolsResponse, Utils.getObjectMapper().getTypeFactory().constructCollectionType(List.class, ToolFunctionCallSpec.class));
for (ToolFunctionCallSpec toolFunctionCallSpec : toolFunctionCallSpecs) {
toolResults.put(toolFunctionCallSpec, invokeTool(toolFunctionCallSpec));
}
toolResult.setToolResults(toolResults);
return toolResult;
@@ -556,8 +558,8 @@ public class OllamaAPI {
return new OllamaChatResult(result.getResponse(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages());
}
public void registerTool(MistralTools.ToolSpecification toolSpecification) {
ToolRegistry.addFunction(toolSpecification.getFunctionName(), toolSpecification.getToolDefinition());
public void registerTool(Tools.ToolSpecification toolSpecification) {
toolRegistry.addFunction(toolSpecification.getFunctionName(), toolSpecification.getToolDefinition());
}
// technical private methods //
@@ -622,18 +624,20 @@ public class OllamaAPI {
}
private Object invokeTool(ToolDef toolDef) {
private Object invokeTool(ToolFunctionCallSpec toolFunctionCallSpec) throws ToolInvocationException {
try {
String methodName = toolDef.getName();
Map<String, Object> arguments = toolDef.getArguments();
DynamicFunction function = ToolRegistry.getFunction(methodName);
String methodName = toolFunctionCallSpec.getName();
Map<String, Object> arguments = toolFunctionCallSpec.getArguments();
ToolFunction function = toolRegistry.getFunction(methodName);
if (verbose) {
logger.debug("Invoking function {} with arguments {}", methodName, arguments);
}
if (function == null) {
throw new IllegalArgumentException("No such tool: " + methodName);
throw new ToolNotFoundException("No such tool: " + methodName);
}
return function.apply(arguments);
} catch (Exception e) {
e.printStackTrace();
return "Error calling tool: " + e.getMessage();
throw new ToolInvocationException("Failed to invoke tool: " + toolFunctionCallSpec.getName(), e);
}
}
}

View File

@@ -0,0 +1,8 @@
package io.github.amithkoujalgi.ollama4j.core.exceptions;
public class ToolInvocationException extends Exception {
public ToolInvocationException(String s, Exception e) {
super(s, e);
}
}

View File

@@ -0,0 +1,8 @@
package io.github.amithkoujalgi.ollama4j.core.exceptions;
public class ToolNotFoundException extends Exception {
public ToolNotFoundException(String s) {
super(s);
}
}

View File

@@ -5,6 +5,8 @@ import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
@Data
@@ -12,5 +14,22 @@ import java.util.Map;
@AllArgsConstructor
public class OllamaToolsResult {
private OllamaResult modelResult;
private Map<ToolDef, Object> toolResults;
private Map<ToolFunctionCallSpec, Object> toolResults;
public List<ToolResult> getToolResults() {
List<ToolResult> results = new ArrayList<>();
for (Map.Entry<ToolFunctionCallSpec, Object> r : this.toolResults.entrySet()) {
results.add(new ToolResult(r.getKey().getName(), r.getKey().getArguments(), r.getValue()));
}
return results;
}
@Data
@NoArgsConstructor
@AllArgsConstructor
public static class ToolResult {
private String functionName;
private Map<String, Object> functionArguments;
private Object result;
}
}

View File

@@ -3,6 +3,6 @@ package io.github.amithkoujalgi.ollama4j.core.tools;
import java.util.Map;
@FunctionalInterface
public interface DynamicFunction {
public interface ToolFunction {
Object apply(Map<String, Object> arguments);
}

View File

@@ -9,10 +9,8 @@ import java.util.Map;
@Data
@AllArgsConstructor
@NoArgsConstructor
public class ToolDef {
public class ToolFunctionCallSpec {
private String name;
private Map<String, Object> arguments;
}

View File

@@ -4,14 +4,13 @@ import java.util.HashMap;
import java.util.Map;
public class ToolRegistry {
private static final Map<String, DynamicFunction> functionMap = new HashMap<>();
private final Map<String, ToolFunction> functionMap = new HashMap<>();
public static DynamicFunction getFunction(String name) {
public ToolFunction getFunction(String name) {
return functionMap.get(name);
}
public static void addFunction(String name, DynamicFunction function) {
public void addFunction(String name, ToolFunction function) {
functionMap.put(name, function);
}
}

View File

@@ -14,14 +14,14 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class MistralTools {
public class Tools {
@Data
@Builder
public static class ToolSpecification {
private String functionName;
private String functionDesc;
private Map<String, PromptFuncDefinition.Property> props;
private DynamicFunction toolDefinition;
private String functionDescription;
private Map<String, PromptFuncDefinition.Property> properties;
private ToolFunction toolDefinition;
}
@Data
@@ -90,14 +90,14 @@ public class MistralTools {
PromptFuncDefinition.PromptFuncSpec functionDetail = new PromptFuncDefinition.PromptFuncSpec();
functionDetail.setName(spec.getFunctionName());
functionDetail.setDescription(spec.getFunctionDesc());
functionDetail.setDescription(spec.getFunctionDescription());
PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters();
parameters.setType("object");
parameters.setProperties(spec.getProps());
parameters.setProperties(spec.getProperties());
List<String> requiredValues = new ArrayList<>();
for (Map.Entry<String, PromptFuncDefinition.Property> p : spec.getProps().entrySet()) {
for (Map.Entry<String, PromptFuncDefinition.Property> p : spec.getProperties().entrySet()) {
if (p.getValue().isRequired()) {
requiredValues.add(p.getKey());
}
@@ -109,31 +109,5 @@ public class MistralTools {
tools.add(def);
return this;
}
//
// public PromptBuilder withToolSpecification(String functionName, String functionDesc, Map<String, PromptFuncDefinition.Property> props) {
// PromptFuncDefinition def = new PromptFuncDefinition();
// def.setType("function");
//
// PromptFuncDefinition.PromptFuncSpec functionDetail = new PromptFuncDefinition.PromptFuncSpec();
// functionDetail.setName(functionName);
// functionDetail.setDescription(functionDesc);
//
// PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters();
// parameters.setType("object");
// parameters.setProperties(props);
//
// List<String> requiredValues = new ArrayList<>();
// for (Map.Entry<String, PromptFuncDefinition.Property> p : props.entrySet()) {
// if (p.getValue().isRequired()) {
// requiredValues.add(p.getKey());
// }
// }
// parameters.setRequired(requiredValues);
// functionDetail.setParameters(parameters);
// def.setFunction(functionDetail);
//
// tools.add(def);
// return this;
// }
}
}