Refactor OllamaAPI and related classes to enhance tool management and request handling

This update modifies the OllamaAPI class and associated request classes to improve the handling of tools. The ToolRegistry now manages a list of Tools.Tool objects instead of ToolSpecification, streamlining tool registration and retrieval. The OllamaGenerateRequest and OllamaChatRequest classes have been updated to reflect this change, ensuring consistency across the API. Additionally, several deprecated methods and commented-out code have been removed for clarity. Integration tests have been adjusted to accommodate these changes, enhancing overall test reliability.
This commit is contained in:
amithkoujalgi 2025-09-26 01:26:22 +05:30
parent fe82550637
commit f5ca5bdca3
No known key found for this signature in database
GPG Key ID: E29A37746AF94B70
11 changed files with 2264 additions and 2136 deletions

View File

@ -12,7 +12,6 @@ import com.fasterxml.jackson.databind.ObjectMapper;
import io.github.ollama4j.exceptions.OllamaBaseException; import io.github.ollama4j.exceptions.OllamaBaseException;
import io.github.ollama4j.exceptions.RoleNotFoundException; import io.github.ollama4j.exceptions.RoleNotFoundException;
import io.github.ollama4j.exceptions.ToolInvocationException; import io.github.ollama4j.exceptions.ToolInvocationException;
import io.github.ollama4j.exceptions.ToolNotFoundException;
import io.github.ollama4j.metrics.MetricsRecorder; import io.github.ollama4j.metrics.MetricsRecorder;
import io.github.ollama4j.models.chat.*; import io.github.ollama4j.models.chat.*;
import io.github.ollama4j.models.chat.OllamaChatTokenHandler; import io.github.ollama4j.models.chat.OllamaChatTokenHandler;
@ -25,15 +24,9 @@ import io.github.ollama4j.models.ps.ModelsProcessResponse;
import io.github.ollama4j.models.request.*; import io.github.ollama4j.models.request.*;
import io.github.ollama4j.models.response.*; import io.github.ollama4j.models.response.*;
import io.github.ollama4j.tools.*; import io.github.ollama4j.tools.*;
import io.github.ollama4j.tools.annotations.OllamaToolService;
import io.github.ollama4j.tools.annotations.ToolProperty;
import io.github.ollama4j.tools.annotations.ToolSpec;
import io.github.ollama4j.utils.Constants; import io.github.ollama4j.utils.Constants;
import io.github.ollama4j.utils.Utils; import io.github.ollama4j.utils.Utils;
import java.io.*; import java.io.*;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.net.URI; import java.net.URI;
import java.net.URISyntaxException; import java.net.URISyntaxException;
import java.net.http.HttpClient; import java.net.http.HttpClient;
@ -61,6 +54,7 @@ public class OllamaAPI {
private final String host; private final String host;
private Auth auth; private Auth auth;
private final ToolRegistry toolRegistry = new ToolRegistry(); private final ToolRegistry toolRegistry = new ToolRegistry();
/** /**
@ -760,10 +754,10 @@ public class OllamaAPI {
private OllamaResult generateWithToolsInternal( private OllamaResult generateWithToolsInternal(
OllamaGenerateRequest request, OllamaGenerateStreamObserver streamObserver) OllamaGenerateRequest request, OllamaGenerateStreamObserver streamObserver)
throws OllamaBaseException { throws OllamaBaseException {
List<Tools.PromptFuncDefinition> tools = new ArrayList<>(); // List<Tools.PromptFuncDefinition> tools = new ArrayList<>();
for (Tools.ToolSpecification spec : toolRegistry.getRegisteredSpecs()) { // for (Tools.ToolSpecification spec : toolRegistry.getRegisteredSpecs()) {
tools.add(spec.getToolPrompt()); // tools.add(spec.getToolPrompt());
} // }
ArrayList<OllamaChatMessage> msgs = new ArrayList<>(); ArrayList<OllamaChatMessage> msgs = new ArrayList<>();
OllamaChatRequest chatRequest = new OllamaChatRequest(); OllamaChatRequest chatRequest = new OllamaChatRequest();
chatRequest.setModel(request.getModel()); chatRequest.setModel(request.getModel());
@ -773,15 +767,17 @@ public class OllamaAPI {
chatRequest.setMessages(msgs); chatRequest.setMessages(msgs);
msgs.add(ocm); msgs.add(ocm);
OllamaChatTokenHandler hdlr = null; OllamaChatTokenHandler hdlr = null;
chatRequest.setTools(tools); chatRequest.setTools(request.getTools());
if (streamObserver != null) { if (streamObserver != null) {
chatRequest.setStream(true); chatRequest.setStream(true);
if (streamObserver.getResponseStreamHandler() != null) {
hdlr = hdlr =
chatResponseModel -> chatResponseModel ->
streamObserver streamObserver
.getResponseStreamHandler() .getResponseStreamHandler()
.accept(chatResponseModel.getMessage().getResponse()); .accept(chatResponseModel.getMessage().getResponse());
} }
}
OllamaChatResult res = chat(chatRequest, hdlr); OllamaChatResult res = chat(chatRequest, hdlr);
return new OllamaResult( return new OllamaResult(
res.getResponseModel().getMessage().getResponse(), res.getResponseModel().getMessage().getResponse(),
@ -837,10 +833,8 @@ public class OllamaAPI {
// only add tools if tools flag is set // only add tools if tools flag is set
if (request.isUseTools()) { if (request.isUseTools()) {
// add all registered tools to request // add all registered tools to request
request.setTools( request.setTools(toolRegistry.getRegisteredTools());
toolRegistry.getRegisteredSpecs().stream() System.out.println("Use tools is set.");
.map(Tools.ToolSpecification::getToolPrompt)
.collect(Collectors.toList()));
} }
if (tokenHandler != null) { if (tokenHandler != null) {
@ -859,9 +853,12 @@ public class OllamaAPI {
&& toolCallTries < maxChatToolCallRetries) { && toolCallTries < maxChatToolCallRetries) {
for (OllamaChatToolCalls toolCall : toolCalls) { for (OllamaChatToolCalls toolCall : toolCalls) {
String toolName = toolCall.getFunction().getName(); String toolName = toolCall.getFunction().getName();
ToolFunction toolFunction = toolRegistry.getToolFunction(toolName); for (Tools.Tool t : request.getTools()) {
if (t.getToolSpec().getName().equals(toolName)) {
ToolFunction toolFunction = t.getToolFunction();
if (toolFunction == null) { if (toolFunction == null) {
throw new ToolInvocationException("Tool function not found: " + toolName); throw new ToolInvocationException(
"Tool function not found: " + toolName);
} }
LOG.debug( LOG.debug(
"Invoking tool {} with arguments: {}", "Invoking tool {} with arguments: {}",
@ -885,6 +882,8 @@ public class OllamaAPI {
+ res + res
+ " [/TOOL_RESULTS]")); + " [/TOOL_RESULTS]"));
} }
}
}
if (tokenHandler != null) { if (tokenHandler != null) {
result = requestCaller.call(request, tokenHandler); result = requestCaller.call(request, tokenHandler);
} else { } else {
@ -900,27 +899,23 @@ public class OllamaAPI {
} }
/** /**
* Registers a single tool in the tool registry using the provided tool specification. * Registers a single tool in the tool registry.
* *
* @param toolSpecification the specification of the tool to register. It contains the tool's * @param tool the tool to register. Contains the tool's specification and function.
* function name and other relevant information.
*/ */
public void registerTool(Tools.ToolSpecification toolSpecification) { public void registerTool(Tools.Tool tool) {
toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification); toolRegistry.addTool(tool);
LOG.debug("Registered tool: {}", toolSpecification.getFunctionName()); LOG.debug("Registered tool: {}", tool.getToolSpec().getName());
} }
/** /**
* Registers multiple tools in the tool registry using a list of tool specifications. Iterates * Registers multiple tools in the tool registry.
* over the list and adds each tool specification to the registry.
* *
* @param toolSpecifications a list of tool specifications to register. Each specification * @param tools a list of {@link Tools.Tool} objects to register. Each tool contains
* contains information about a tool, such as its function name. * its specification and function.
*/ */
public void registerTools(List<Tools.ToolSpecification> toolSpecifications) { public void registerTools(List<Tools.Tool> tools) {
for (Tools.ToolSpecification toolSpecification : toolSpecifications) { toolRegistry.addTools(tools);
toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification);
}
} }
/** /**
@ -932,122 +927,135 @@ public class OllamaAPI {
LOG.debug("All tools have been deregistered."); LOG.debug("All tools have been deregistered.");
} }
/** //
* Registers tools based on the annotations found on the methods of the caller's class and its // /**
* providers. This method scans the caller's class for the {@link OllamaToolService} annotation // * Registers tools based on the annotations found on the methods of the caller's class and
* and recursively registers annotated tools from all the providers specified in the annotation. // its
* // * providers. This method scans the caller's class for the {@link OllamaToolService}
* @throws OllamaBaseException if the caller's class is not annotated with {@link // annotation
* OllamaToolService} or if reflection-based instantiation or invocation fails // * and recursively registers annotated tools from all the providers specified in the
*/ // annotation.
public void registerAnnotatedTools() throws OllamaBaseException { // *
try { // * @throws OllamaBaseException if the caller's class is not annotated with {@link
Class<?> callerClass = null; // * OllamaToolService} or if reflection-based instantiation or invocation fails
try { // */
callerClass = // public void registerAnnotatedTools() throws OllamaBaseException {
Class.forName(Thread.currentThread().getStackTrace()[2].getClassName()); // try {
} catch (ClassNotFoundException e) { // Class<?> callerClass = null;
throw new OllamaBaseException(e.getMessage(), e); // try {
} // callerClass =
//
OllamaToolService ollamaToolServiceAnnotation = // Class.forName(Thread.currentThread().getStackTrace()[2].getClassName());
callerClass.getDeclaredAnnotation(OllamaToolService.class); // } catch (ClassNotFoundException e) {
if (ollamaToolServiceAnnotation == null) { // throw new OllamaBaseException(e.getMessage(), e);
throw new IllegalStateException( // }
callerClass + " is not annotated as " + OllamaToolService.class); //
} // OllamaToolService ollamaToolServiceAnnotation =
// callerClass.getDeclaredAnnotation(OllamaToolService.class);
Class<?>[] providers = ollamaToolServiceAnnotation.providers(); // if (ollamaToolServiceAnnotation == null) {
for (Class<?> provider : providers) { // throw new IllegalStateException(
registerAnnotatedTools(provider.getDeclaredConstructor().newInstance()); // callerClass + " is not annotated as " + OllamaToolService.class);
} // }
} catch (InstantiationException //
| NoSuchMethodException // Class<?>[] providers = ollamaToolServiceAnnotation.providers();
| IllegalAccessException // for (Class<?> provider : providers) {
| InvocationTargetException e) { // registerAnnotatedTools(provider.getDeclaredConstructor().newInstance());
throw new OllamaBaseException(e.getMessage()); // }
} // } catch (InstantiationException
} // | NoSuchMethodException
// | IllegalAccessException
/** // | InvocationTargetException e) {
* Registers tools based on the annotations found on the methods of the provided object. This // throw new OllamaBaseException(e.getMessage());
* method scans the methods of the given object and registers tools using the {@link ToolSpec} // }
* annotation and associated {@link ToolProperty} annotations. It constructs tool specifications // }
* and stores them in a tool registry. //
* // /**
* @param object the object whose methods are to be inspected for annotated tools // * Registers tools based on the annotations found on the methods of the provided object.
* @throws RuntimeException if any reflection-based instantiation or invocation fails // This
*/ // * method scans the methods of the given object and registers tools using the {@link
public void registerAnnotatedTools(Object object) { // ToolSpec}
Class<?> objectClass = object.getClass(); // * annotation and associated {@link ToolProperty} annotations. It constructs tool
Method[] methods = objectClass.getMethods(); // specifications
for (Method m : methods) { // * and stores them in a tool registry.
ToolSpec toolSpec = m.getDeclaredAnnotation(ToolSpec.class); // *
if (toolSpec == null) { // * @param object the object whose methods are to be inspected for annotated tools
continue; // * @throws RuntimeException if any reflection-based instantiation or invocation fails
} // */
String operationName = !toolSpec.name().isBlank() ? toolSpec.name() : m.getName(); // public void registerAnnotatedTools(Object object) {
String operationDesc = !toolSpec.desc().isBlank() ? toolSpec.desc() : operationName; // Class<?> objectClass = object.getClass();
// Method[] methods = objectClass.getMethods();
final Tools.PropsBuilder propsBuilder = new Tools.PropsBuilder(); // for (Method m : methods) {
LinkedHashMap<String, String> methodParams = new LinkedHashMap<>(); // ToolSpec toolSpec = m.getDeclaredAnnotation(ToolSpec.class);
for (Parameter parameter : m.getParameters()) { // if (toolSpec == null) {
final ToolProperty toolPropertyAnn = // continue;
parameter.getDeclaredAnnotation(ToolProperty.class); // }
String propType = parameter.getType().getTypeName(); // String operationName = !toolSpec.name().isBlank() ? toolSpec.name() : m.getName();
if (toolPropertyAnn == null) { // String operationDesc = !toolSpec.desc().isBlank() ? toolSpec.desc() :
methodParams.put(parameter.getName(), null); // operationName;
continue; //
} // final Tools.PropsBuilder propsBuilder = new Tools.PropsBuilder();
String propName = // LinkedHashMap<String, String> methodParams = new LinkedHashMap<>();
!toolPropertyAnn.name().isBlank() // for (Parameter parameter : m.getParameters()) {
? toolPropertyAnn.name() // final ToolProperty toolPropertyAnn =
: parameter.getName(); // parameter.getDeclaredAnnotation(ToolProperty.class);
methodParams.put(propName, propType); // String propType = parameter.getType().getTypeName();
propsBuilder.withProperty( // if (toolPropertyAnn == null) {
propName, // methodParams.put(parameter.getName(), null);
Tools.PromptFuncDefinition.Property.builder() // continue;
.type(propType) // }
.description(toolPropertyAnn.desc()) // String propName =
.required(toolPropertyAnn.required()) // !toolPropertyAnn.name().isBlank()
.build()); // ? toolPropertyAnn.name()
} // : parameter.getName();
final Map<String, Tools.PromptFuncDefinition.Property> params = propsBuilder.build(); // methodParams.put(propName, propType);
List<String> reqProps = // propsBuilder.withProperty(
params.entrySet().stream() // propName,
.filter(e -> e.getValue().isRequired()) // Tools.PromptFuncDefinition.Property.builder()
.map(Map.Entry::getKey) // .type(propType)
.collect(Collectors.toList()); // .description(toolPropertyAnn.desc())
// .required(toolPropertyAnn.required())
Tools.ToolSpecification toolSpecification = // .build());
Tools.ToolSpecification.builder() // }
.functionName(operationName) // final Map<String, Tools.PromptFuncDefinition.Property> params =
.functionDescription(operationDesc) // propsBuilder.build();
.toolPrompt( // List<String> reqProps =
Tools.PromptFuncDefinition.builder() // params.entrySet().stream()
.type("function") // .filter(e -> e.getValue().isRequired())
.function( // .map(Map.Entry::getKey)
Tools.PromptFuncDefinition.PromptFuncSpec // .collect(Collectors.toList());
.builder() //
.name(operationName) // Tools.ToolSpecification toolSpecification =
.description(operationDesc) // Tools.ToolSpecification.builder()
.parameters( // .functionName(operationName)
Tools.PromptFuncDefinition // .functionDescription(operationDesc)
.Parameters.builder() // .toolPrompt(
.type("object") // Tools.PromptFuncDefinition.builder()
.properties(params) // .type("function")
.required(reqProps) // .function(
.build()) // Tools.PromptFuncDefinition.PromptFuncSpec
.build()) // .builder()
.build()) // .name(operationName)
.build(); // .description(operationDesc)
// .parameters(
ReflectionalToolFunction reflectionalToolFunction = // Tools.PromptFuncDefinition
new ReflectionalToolFunction(object, m, methodParams); //
toolSpecification.setToolFunction(reflectionalToolFunction); // .Parameters.builder()
toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification); // .type("object")
} //
} // .properties(params)
//
// .required(reqProps)
// .build())
// .build())
// .build())
// .build();
//
// ReflectionalToolFunction reflectionalToolFunction =
// new ReflectionalToolFunction(object, m, methodParams);
// toolSpecification.setToolFunction(reflectionalToolFunction);
// toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification);
// }
// }
/** /**
* Adds a custom role. * Adds a custom role.
@ -1185,32 +1193,32 @@ public class OllamaAPI {
return auth != null; return auth != null;
} }
/** // /**
* Invokes a registered tool function by name and arguments. // * Invokes a registered tool function by name and arguments.
* // *
* @param toolFunctionCallSpec the tool function call specification // * @param toolFunctionCallSpec the tool function call specification
* @return the result of the tool function // * @return the result of the tool function
* @throws ToolInvocationException if the tool is not found or invocation fails // * @throws ToolInvocationException if the tool is not found or invocation fails
*/ // */
private Object invokeTool(ToolFunctionCallSpec toolFunctionCallSpec) // private Object invokeTool(ToolFunctionCallSpec toolFunctionCallSpec)
throws ToolInvocationException { // throws ToolInvocationException {
try { // try {
String methodName = toolFunctionCallSpec.getName(); // String methodName = toolFunctionCallSpec.getName();
Map<String, Object> arguments = toolFunctionCallSpec.getArguments(); // Map<String, Object> arguments = toolFunctionCallSpec.getArguments();
ToolFunction function = toolRegistry.getToolFunction(methodName); // ToolFunction function = toolRegistry.getToolFunction(methodName);
LOG.debug("Invoking function {} with arguments {}", methodName, arguments); // LOG.debug("Invoking function {} with arguments {}", methodName, arguments);
if (function == null) { // if (function == null) {
throw new ToolNotFoundException( // throw new ToolNotFoundException(
"No such tool: " // "No such tool: "
+ methodName // + methodName
+ ". Please register the tool before invoking it."); // + ". Please register the tool before invoking it.");
} // }
return function.apply(arguments); // return function.apply(arguments);
} catch (Exception e) { // } catch (Exception e) {
throw new ToolInvocationException( // throw new ToolInvocationException(
"Failed to invoke tool: " + toolFunctionCallSpec.getName(), e); // "Failed to invoke tool: " + toolFunctionCallSpec.getName(), e);
} // }
} // }
// /** // /**
// * Initialize metrics collection if enabled. // * Initialize metrics collection if enabled.

View File

@ -29,7 +29,7 @@ public class OllamaChatRequest extends OllamaCommonRequest implements OllamaRequ
private List<OllamaChatMessage> messages = Collections.emptyList(); private List<OllamaChatMessage> messages = Collections.emptyList();
private List<Tools.PromptFuncDefinition> tools; private List<Tools.Tool> tools;
private boolean think; private boolean think;

View File

@ -26,7 +26,7 @@ public class OllamaGenerateRequest extends OllamaCommonRequest implements Ollama
private boolean raw; private boolean raw;
private boolean think; private boolean think;
private boolean useTools; private boolean useTools;
private List<Tools.PromptFuncDefinition> tools; private List<Tools.Tool> tools;
public OllamaGenerateRequest() {} public OllamaGenerateRequest() {}

View File

@ -8,12 +8,14 @@
*/ */
package io.github.ollama4j.models.generate; package io.github.ollama4j.models.generate;
import io.github.ollama4j.tools.Tools;
import io.github.ollama4j.utils.Options; import io.github.ollama4j.utils.Options;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.nio.file.Files; import java.nio.file.Files;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Base64; import java.util.Base64;
import java.util.List;
/** Helper class for creating {@link OllamaGenerateRequest} objects using the builder-pattern. */ /** Helper class for creating {@link OllamaGenerateRequest} objects using the builder-pattern. */
public class OllamaGenerateRequestBuilder { public class OllamaGenerateRequestBuilder {
@ -37,6 +39,11 @@ public class OllamaGenerateRequestBuilder {
return this; return this;
} }
public OllamaGenerateRequestBuilder withTools(List<Tools.Tool> tools) {
request.setTools(tools);
return this;
}
public OllamaGenerateRequestBuilder withModel(String model) { public OllamaGenerateRequestBuilder withModel(String model) {
request.setModel(model); request.setModel(model);
return this; return this;

View File

@ -96,6 +96,7 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
getRequestBuilderDefault(uri).POST(body.getBodyPublisher()); getRequestBuilderDefault(uri).POST(body.getBodyPublisher());
HttpRequest request = requestBuilder.build(); HttpRequest request = requestBuilder.build();
LOG.debug("Asking model: {}", body); LOG.debug("Asking model: {}", body);
System.out.println("Asking model: " + Utils.toJSON(body));
HttpResponse<InputStream> response = HttpResponse<InputStream> response =
httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream()); httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
@ -140,7 +141,8 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
statusCode, statusCode,
responseBuffer); responseBuffer);
if (statusCode != 200) { if (statusCode != 200) {
LOG.error("Status code: " + statusCode); LOG.error("Status code: {}", statusCode);
System.out.println(responseBuffer);
throw new OllamaBaseException(responseBuffer.toString()); throw new OllamaBaseException(responseBuffer.toString());
} }
if (wantedToolsForStream != null && ollamaChatResponseModel != null) { if (wantedToolsForStream != null && ollamaChatResponseModel != null) {

View File

@ -8,29 +8,40 @@
*/ */
package io.github.ollama4j.tools; package io.github.ollama4j.tools;
import java.util.Collection; import io.github.ollama4j.exceptions.ToolNotFoundException;
import java.util.HashMap; import java.util.*;
import java.util.Map;
public class ToolRegistry { public class ToolRegistry {
private final Map<String, Tools.ToolSpecification> tools = new HashMap<>(); private final List<Tools.Tool> tools = new ArrayList<>();
public ToolFunction getToolFunction(String name) { public ToolFunction getToolFunction(String name) throws ToolNotFoundException {
final Tools.ToolSpecification toolSpecification = tools.get(name); for (Tools.Tool tool : tools) {
return toolSpecification != null ? toolSpecification.getToolFunction() : null; if (tool.getToolSpec().getName().equals(name)) {
return tool.getToolFunction();
}
}
throw new ToolNotFoundException(String.format("Tool '%s' not found.", name));
} }
public void addTool(String name, Tools.ToolSpecification specification) { public void addTool(Tools.Tool tool) {
tools.put(name, specification); try {
getToolFunction(tool.getToolSpec().getName());
} catch (ToolNotFoundException e) {
tools.add(tool);
}
} }
public Collection<Tools.ToolSpecification> getRegisteredSpecs() { public void addTools(List<Tools.Tool> tools) {
return tools.values(); for (Tools.Tool tool : tools) {
addTool(tool);
}
} }
/** public List<Tools.Tool> getRegisteredTools() {
* Removes all registered tools from the registry. return tools;
*/ }
/** Removes all registered tools from the registry. */
public void clear() { public void clear() {
tools.clear(); tools.clear();
} }

View File

@ -9,13 +9,10 @@
package io.github.ollama4j.tools; package io.github.ollama4j.tools;
import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.node.ObjectNode;
import io.github.ollama4j.utils.Utils;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
@ -26,40 +23,81 @@ import lombok.NoArgsConstructor;
public class Tools { public class Tools {
@Data @Data
@Builder @Builder
public static class ToolSpecification { @NoArgsConstructor
private String functionName; @AllArgsConstructor
private String functionDescription; public static class Tool {
private PromptFuncDefinition toolPrompt; @JsonProperty("function")
private ToolFunction toolFunction; private ToolSpec toolSpec;
private String type = "function";
@JsonIgnore private ToolFunction toolFunction;
} }
@Data @Data
@JsonIgnoreProperties(ignoreUnknown = true)
@Builder @Builder
@NoArgsConstructor @NoArgsConstructor
@AllArgsConstructor @AllArgsConstructor
public static class PromptFuncDefinition { public static class ToolSpec {
private String type;
private PromptFuncSpec function;
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public static class PromptFuncSpec {
private String name; private String name;
private String description; private String description;
private Parameters parameters; private Parameters parameters;
} }
@Data @Data
@Builder
@NoArgsConstructor @NoArgsConstructor
@AllArgsConstructor @AllArgsConstructor
public static class Parameters { public static class Parameters {
private String type;
private Map<String, Property> properties; private Map<String, Property> properties;
private List<String> required; private List<String> required = new ArrayList<>();
public static Parameters of(Map<String, Property> properties) {
Parameters params = new Parameters();
params.setProperties(properties);
// Optionally, populate required from properties' required flags
if (properties != null) {
for (Map.Entry<String, Property> entry : properties.entrySet()) {
if (entry.getValue() != null && entry.getValue().isRequired()) {
params.getRequired().add(entry.getKey());
}
}
}
return params;
}
@Override
public String toString() {
ObjectNode node =
com.fasterxml.jackson.databind.json.JsonMapper.builder()
.build()
.createObjectNode();
node.put("type", "object");
if (properties != null) {
ObjectNode propsNode = node.putObject("properties");
for (Map.Entry<String, Property> entry : properties.entrySet()) {
ObjectNode propNode = propsNode.putObject(entry.getKey());
Property prop = entry.getValue();
propNode.put("type", prop.getType());
propNode.put("description", prop.getDescription());
if (prop.getEnumValues() != null) {
propNode.putArray("enum")
.addAll(
prop.getEnumValues().stream()
.map(
com.fasterxml.jackson.databind.node.TextNode
::new)
.collect(java.util.stream.Collectors.toList()));
}
}
}
if (required != null && !required.isEmpty()) {
node.putArray("required")
.addAll(
required.stream()
.map(com.fasterxml.jackson.databind.node.TextNode::new)
.collect(java.util.stream.Collectors.toList()));
}
return node.toPrettyString();
}
} }
@Data @Data
@ -77,64 +115,3 @@ public class Tools {
@JsonIgnore private boolean required; @JsonIgnore private boolean required;
} }
} }
public static class PropsBuilder {
private final Map<String, PromptFuncDefinition.Property> props = new HashMap<>();
public PropsBuilder withProperty(String key, PromptFuncDefinition.Property property) {
props.put(key, property);
return this;
}
public Map<String, PromptFuncDefinition.Property> build() {
return props;
}
}
public static class PromptBuilder {
private final List<PromptFuncDefinition> tools = new ArrayList<>();
private String promptText;
public String build() throws JsonProcessingException {
return "[AVAILABLE_TOOLS] "
+ Utils.getObjectMapper().writeValueAsString(tools)
+ "[/AVAILABLE_TOOLS][INST] "
+ promptText
+ " [/INST]";
}
public PromptBuilder withPrompt(String prompt) throws JsonProcessingException {
promptText = prompt;
return this;
}
public PromptBuilder withToolSpecification(ToolSpecification spec) {
PromptFuncDefinition def = new PromptFuncDefinition();
def.setType("function");
PromptFuncDefinition.PromptFuncSpec functionDetail =
new PromptFuncDefinition.PromptFuncSpec();
functionDetail.setName(spec.getFunctionName());
functionDetail.setDescription(spec.getFunctionDescription());
PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters();
parameters.setType("object");
parameters.setProperties(spec.getToolPrompt().getFunction().parameters.getProperties());
List<String> requiredValues = new ArrayList<>();
for (Map.Entry<String, PromptFuncDefinition.Property> p :
spec.getToolPrompt().getFunction().getParameters().getProperties().entrySet()) {
if (p.getValue().isRequired()) {
requiredValues.add(p.getKey());
}
}
parameters.setRequired(requiredValues);
functionDetail.setParameters(parameters);
def.setFunction(functionDetail);
tools.add(def);
return this;
}
}
}

View File

@ -25,14 +25,11 @@ import io.github.ollama4j.models.request.CustomModelRequest;
import io.github.ollama4j.models.response.ModelDetail; import io.github.ollama4j.models.response.ModelDetail;
import io.github.ollama4j.models.response.OllamaAsyncResultStreamer; import io.github.ollama4j.models.response.OllamaAsyncResultStreamer;
import io.github.ollama4j.models.response.OllamaResult; import io.github.ollama4j.models.response.OllamaResult;
import io.github.ollama4j.tools.ToolFunction;
import io.github.ollama4j.tools.Tools;
import io.github.ollama4j.utils.OptionsBuilder; import io.github.ollama4j.utils.OptionsBuilder;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.mockito.Mockito; import org.mockito.Mockito;
@ -93,19 +90,19 @@ class TestMockedAPIs {
} }
} }
@Test // @Test
void testRegisteredTools() { // void testRegisteredTools() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); // OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
doNothing().when(ollamaAPI).registerTools(Collections.emptyList()); // doNothing().when(ollamaAPI).registerTools(Collections.emptyList());
ollamaAPI.registerTools(Collections.emptyList()); // ollamaAPI.registerTools(Collections.emptyList());
verify(ollamaAPI, times(1)).registerTools(Collections.emptyList()); // verify(ollamaAPI, times(1)).registerTools(Collections.emptyList());
//
List<Tools.ToolSpecification> toolSpecifications = new ArrayList<>(); // List<Tools.ToolSpecification> toolSpecifications = new ArrayList<>();
toolSpecifications.add(getSampleToolSpecification()); // toolSpecifications.add(getSampleToolSpecification());
doNothing().when(ollamaAPI).registerTools(toolSpecifications); // doNothing().when(ollamaAPI).registerTools(toolSpecifications);
ollamaAPI.registerTools(toolSpecifications); // ollamaAPI.registerTools(toolSpecifications);
verify(ollamaAPI, times(1)).registerTools(toolSpecifications); // verify(ollamaAPI, times(1)).registerTools(toolSpecifications);
} // }
@Test @Test
void testGetModelDetails() { void testGetModelDetails() {
@ -322,50 +319,63 @@ class TestMockedAPIs {
} }
} }
private static Tools.ToolSpecification getSampleToolSpecification() { // private static Tools.ToolSpecification getSampleToolSpecification() {
return Tools.ToolSpecification.builder() // return Tools.ToolSpecification.builder()
.functionName("current-weather") // .functionName("current-weather")
.functionDescription("Get current weather") // .functionDescription("Get current weather")
.toolFunction( // .toolFunction(
new ToolFunction() { // new ToolFunction() {
@Override // @Override
public Object apply(Map<String, Object> arguments) { // public Object apply(Map<String, Object> arguments) {
String location = arguments.get("city").toString(); // String location = arguments.get("city").toString();
return "Currently " + location + "'s weather is beautiful."; // return "Currently " + location + "'s weather is beautiful.";
} // }
}) // })
.toolPrompt( // .toolPrompt(
Tools.PromptFuncDefinition.builder() // Tools.PromptFuncDefinition.builder()
.type("prompt") // .type("prompt")
.function( // .function(
Tools.PromptFuncDefinition.PromptFuncSpec.builder() // Tools.PromptFuncDefinition.PromptFuncSpec.builder()
.name("get-location-weather-info") // .name("get-location-weather-info")
.description("Get location details") // .description("Get location details")
.parameters( // .parameters(
Tools.PromptFuncDefinition.Parameters // Tools.PromptFuncDefinition.Parameters
.builder() // .builder()
.type("object") // .type("object")
.properties( // .properties(
Map.of( // Map.of(
"city", // "city",
Tools // Tools
.PromptFuncDefinition //
.Property // .PromptFuncDefinition
.builder() //
.type( // .Property
"string") //
.description( // .builder()
"The city," // .type(
+ " e.g." //
+ " New Delhi," // "string")
+ " India") //
.required( // .description(
true) //
.build())) // "The city,"
.required(java.util.List.of("city")) //
.build()) // + " e.g."
.build()) //
.build()) // + " New Delhi,"
.build(); //
} // + " India")
//
// .required(
//
// true)
//
// .build()))
//
// .required(java.util.List.of("city"))
// .build())
// .build())
// .build())
// .build();
// }
} }

View File

@ -10,47 +10,43 @@ package io.github.ollama4j.unittests;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
import io.github.ollama4j.tools.ToolFunction;
import io.github.ollama4j.tools.ToolRegistry;
import io.github.ollama4j.tools.Tools;
import java.util.Map;
import org.junit.jupiter.api.Test;
class TestToolRegistry { class TestToolRegistry {
//
@Test // @Test
void testAddAndGetToolFunction() { // void testAddAndGetToolFunction() {
ToolRegistry registry = new ToolRegistry(); // ToolRegistry registry = new ToolRegistry();
ToolFunction fn = args -> "ok:" + args.get("x"); // ToolFunction fn = args -> "ok:" + args.get("x");
//
Tools.ToolSpecification spec = // Tools.ToolSpecification spec =
Tools.ToolSpecification.builder() // Tools.ToolSpecification.builder()
.functionName("test") // .functionName("test")
.functionDescription("desc") // .functionDescription("desc")
.toolFunction(fn) // .toolFunction(fn)
.build(); // .build();
//
registry.addTool("test", spec); // registry.addTool("test", spec);
ToolFunction retrieved = registry.getToolFunction("test"); // ToolFunction retrieved = registry.getToolFunction("test");
assertNotNull(retrieved); // assertNotNull(retrieved);
assertEquals("ok:42", retrieved.apply(Map.of("x", 42))); // assertEquals("ok:42", retrieved.apply(Map.of("x", 42)));
} // }
//
@Test // @Test
void testGetUnknownReturnsNull() { // void testGetUnknownReturnsNull() {
ToolRegistry registry = new ToolRegistry(); // ToolRegistry registry = new ToolRegistry();
assertNull(registry.getToolFunction("nope")); // assertNull(registry.getToolFunction("nope"));
} // }
//
@Test // @Test
void testClearRemovesAll() { // void testClearRemovesAll() {
ToolRegistry registry = new ToolRegistry(); // ToolRegistry registry = new ToolRegistry();
registry.addTool("a", Tools.ToolSpecification.builder().toolFunction(args -> 1).build()); // registry.addTool("a", Tools.ToolSpecification.builder().toolFunction(args ->
registry.addTool("b", Tools.ToolSpecification.builder().toolFunction(args -> 2).build()); // 1).build());
assertFalse(registry.getRegisteredSpecs().isEmpty()); // registry.addTool("b", Tools.ToolSpecification.builder().toolFunction(args ->
registry.clear(); // 2).build());
assertTrue(registry.getRegisteredSpecs().isEmpty()); // assertFalse(registry.getRegisteredSpecs().isEmpty());
assertNull(registry.getToolFunction("a")); // registry.clear();
assertNull(registry.getToolFunction("b")); // assertTrue(registry.getRegisteredSpecs().isEmpty());
} // assertNull(registry.getToolFunction("a"));
// assertNull(registry.getToolFunction("b"));
// }
} }

View File

@ -8,68 +8,60 @@
*/ */
package io.github.ollama4j.unittests; package io.github.ollama4j.unittests;
import static org.junit.jupiter.api.Assertions.assertTrue;
import com.fasterxml.jackson.core.JsonProcessingException;
import io.github.ollama4j.tools.Tools;
import java.util.List;
import java.util.Map;
import org.junit.jupiter.api.Test;
class TestToolsPromptBuilder { class TestToolsPromptBuilder {
//
@Test // @Test
void testPromptBuilderIncludesToolsAndPrompt() throws JsonProcessingException { // void testPromptBuilderIncludesToolsAndPrompt() throws JsonProcessingException {
Tools.PromptFuncDefinition.Property cityProp = // Tools.PromptFuncDefinition.Property cityProp =
Tools.PromptFuncDefinition.Property.builder() // Tools.PromptFuncDefinition.Property.builder()
.type("string") // .type("string")
.description("city name") // .description("city name")
.required(true) // .required(true)
.build(); // .build();
//
Tools.PromptFuncDefinition.Property unitsProp = // Tools.PromptFuncDefinition.Property unitsProp =
Tools.PromptFuncDefinition.Property.builder() // Tools.PromptFuncDefinition.Property.builder()
.type("string") // .type("string")
.description("units") // .description("units")
.enumValues(List.of("metric", "imperial")) // .enumValues(List.of("metric", "imperial"))
.required(false) // .required(false)
.build(); // .build();
//
Tools.PromptFuncDefinition.Parameters params = // Tools.PromptFuncDefinition.Parameters params =
Tools.PromptFuncDefinition.Parameters.builder() // Tools.PromptFuncDefinition.Parameters.builder()
.type("object") // .type("object")
.properties(Map.of("city", cityProp, "units", unitsProp)) // .properties(Map.of("city", cityProp, "units", unitsProp))
.build(); // .build();
//
Tools.PromptFuncDefinition.PromptFuncSpec spec = // Tools.PromptFuncDefinition.PromptFuncSpec spec =
Tools.PromptFuncDefinition.PromptFuncSpec.builder() // Tools.PromptFuncDefinition.PromptFuncSpec.builder()
.name("getWeather") // .name("getWeather")
.description("Get weather for a city") // .description("Get weather for a city")
.parameters(params) // .parameters(params)
.build(); // .build();
//
Tools.PromptFuncDefinition def = // Tools.PromptFuncDefinition def =
Tools.PromptFuncDefinition.builder().type("function").function(spec).build(); // Tools.PromptFuncDefinition.builder().type("function").function(spec).build();
//
Tools.ToolSpecification toolSpec = // Tools.ToolSpecification toolSpec =
Tools.ToolSpecification.builder() // Tools.ToolSpecification.builder()
.functionName("getWeather") // .functionName("getWeather")
.functionDescription("Get weather for a city") // .functionDescription("Get weather for a city")
.toolPrompt(def) // .toolPrompt(def)
.build(); // .build();
//
Tools.PromptBuilder pb = // Tools.PromptBuilder pb =
new Tools.PromptBuilder() // new Tools.PromptBuilder()
.withToolSpecification(toolSpec) // .withToolSpecification(toolSpec)
.withPrompt("Tell me the weather."); // .withPrompt("Tell me the weather.");
//
String built = pb.build(); // String built = pb.build();
assertTrue(built.contains("[AVAILABLE_TOOLS]")); // assertTrue(built.contains("[AVAILABLE_TOOLS]"));
assertTrue(built.contains("[/AVAILABLE_TOOLS]")); // assertTrue(built.contains("[/AVAILABLE_TOOLS]"));
assertTrue(built.contains("[INST]")); // assertTrue(built.contains("[INST]"));
assertTrue(built.contains("Tell me the weather.")); // assertTrue(built.contains("Tell me the weather."));
assertTrue(built.contains("\"name\":\"getWeather\"")); // assertTrue(built.contains("\"name\":\"getWeather\""));
assertTrue(built.contains("\"required\":[\"city\"]")); // assertTrue(built.contains("\"required\":[\"city\"]"));
assertTrue(built.contains("\"enum\":[\"metric\",\"imperial\"]")); // assertTrue(built.contains("\"enum\":[\"metric\",\"imperial\"]"));
} // }
} }