diff --git a/docs/docs/apis-generate/generate-with-tools.md b/docs/docs/apis-generate/generate-with-tools.md index e0e5794..db50c88 100644 --- a/docs/docs/apis-generate/generate-with-tools.md +++ b/docs/docs/apis-generate/generate-with-tools.md @@ -464,21 +464,155 @@ A typical final response of the above could be: This tool calling can also be done using the streaming API. -### Potential Improvements +### Using Annotation based Tool Registration -Instead of explicitly registering `ollamaAPI.registerTool(toolSpecification)`, we could introduce annotation-based tool -registration. For example: +Instead of explicitly registering each tool, ollama4j supports declarative tool specification and registration via java +Annotations and reflection calling. + +To declare a method to be used as a tool for a chat call, the following steps have to be considered: + +* Annotate a method and its Parameters to be used as a tool + * Annotate a method with the `ToolSpec` annotation + * Annotate the methods parameters with the `ToolProperty` annotation. Only the following datatypes are supported for now: + * `java.lang.String` + * `java.lang.Integer` + * `java.lang.Boolean` + * `java.math.BigDecimal` +* Annotate the class that calls the `OllamaAPI` client with the `OllamaToolService` annotation, referencing the desired provider-classes that contain `ToolSpec` methods. +* Before calling the `OllamaAPI` chat request, call the method `OllamaAPI.registerAnnotatedTools()` method to add tools to the chat. + +#### Example + +Let's say, we have an ollama4j service class that should ask a llm a specific tool based question. + +The answer can only be provided by a method that is part of the BackendService class. To provide a tool for the llm, the following annotations can be used: ```java +public class BackendService{ + + public BackendService(){} -@ToolSpec(name = "current-fuel-price", desc = "Get current fuel price") -public String getCurrentFuelPrice(Map arguments) { - String location = arguments.get("location").toString(); - String fuelType = arguments.get("fuelType").toString(); - return "Current price of " + fuelType + " in " + location + " is Rs.103/L"; + @ToolSpec(desc = "Computes the most important constant all around the globe!") + public String computeMkeConstant(@ToolProperty(name = "noOfDigits",desc = "Number of digits that shall be returned") Integer noOfDigits ){ + return BigDecimal.valueOf((long)(Math.random()*1000000L),noOfDigits).toString(); + } } ``` +The caller API can then be written as: +```java +import io.github.ollama4j.tools.annotations.OllamaToolService; + +@OllamaToolService(providers = BackendService.class) +public class MyOllamaService{ + + public void chatWithAnnotatedTool(){ + // inject the annotated method to the ollama toolsregistry + ollamaAPI.registerAnnotatedTools(); + + OllamaChatRequest requestModel = builder + .withMessage(OllamaChatMessageRole.USER, + "Compute the most important constant in the world using 5 digits") + .build(); + + OllamaChatResult chatResult = ollamaAPI.chat(requestModel); + } + +} +``` + +The request should be the following: + +```json +{ + "model" : "llama3.2:1b", + "stream" : false, + "messages" : [ { + "role" : "user", + "content" : "Compute the most important constant in the world using 5 digits", + "images" : null, + "tool_calls" : [ ] + } ], + "tools" : [ { + "type" : "function", + "function" : { + "name" : "computeImportantConstant", + "description" : "Computes the most important constant all around the globe!", + "parameters" : { + "type" : "object", + "properties" : { + "noOfDigits" : { + "type" : "java.lang.Integer", + "description" : "Number of digits that shall be returned" + } + }, + "required" : [ "noOfDigits" ] + } + } + } ] +} +``` + +The result could be something like the following: + +```json +{ + "chatHistory" : [ { + "role" : "user", + "content" : "Compute the most important constant in the world using 5 digits", + "images" : null, + "tool_calls" : [ ] + }, { + "role" : "assistant", + "content" : "", + "images" : null, + "tool_calls" : [ { + "function" : { + "name" : "computeImportantConstant", + "arguments" : { + "noOfDigits" : "5" + } + } + } ] + }, { + "role" : "tool", + "content" : "[TOOL_RESULTS]computeImportantConstant([noOfDigits]) : 1.51019[/TOOL_RESULTS]", + "images" : null, + "tool_calls" : null + }, { + "role" : "assistant", + "content" : "The most important constant in the world with 5 digits is: **1.51019**", + "images" : null, + "tool_calls" : null + } ], + "responseModel" : { + "model" : "llama3.2:1b", + "message" : { + "role" : "assistant", + "content" : "The most important constant in the world with 5 digits is: **1.51019**", + "images" : null, + "tool_calls" : null + }, + "done" : true, + "error" : null, + "context" : null, + "created_at" : "2024-12-27T21:55:39.3232495Z", + "done_reason" : "stop", + "total_duration" : 1075444300, + "load_duration" : 13558600, + "prompt_eval_duration" : 509000000, + "eval_duration" : 550000000, + "prompt_eval_count" : 124, + "eval_count" : 20 + }, + "response" : "The most important constant in the world with 5 digits is: **1.51019**", + "responseTime" : 1075444300, + "httpStatusCode" : 200 +} +``` + +### Potential Improvements + Instead of passing a map of args `Map arguments` to the tool functions, we could support passing specific args separately with their data types. For example: diff --git a/src/main/java/io/github/ollama4j/OllamaAPI.java b/src/main/java/io/github/ollama4j/OllamaAPI.java index 90dcd35..29cb467 100644 --- a/src/main/java/io/github/ollama4j/OllamaAPI.java +++ b/src/main/java/io/github/ollama4j/OllamaAPI.java @@ -15,11 +15,17 @@ import io.github.ollama4j.models.ps.ModelsProcessResponse; import io.github.ollama4j.models.request.*; import io.github.ollama4j.models.response.*; import io.github.ollama4j.tools.*; +import io.github.ollama4j.tools.annotations.OllamaToolService; +import io.github.ollama4j.tools.annotations.ToolProperty; +import io.github.ollama4j.tools.annotations.ToolSpec; import io.github.ollama4j.utils.Options; import io.github.ollama4j.utils.Utils; import lombok.Setter; import java.io.*; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Parameter; import java.net.URI; import java.net.URISyntaxException; import java.net.http.HttpClient; @@ -603,6 +609,15 @@ public class OllamaAPI { OllamaToolsResult toolResult = new OllamaToolsResult(); Map toolResults = new HashMap<>(); + if(!prompt.startsWith("[AVAILABLE_TOOLS]")){ + final Tools.PromptBuilder promptBuilder = new Tools.PromptBuilder(); + for(Tools.ToolSpecification spec : toolRegistry.getRegisteredSpecs()) { + promptBuilder.withToolSpecification(spec); + } + promptBuilder.withPrompt(prompt); + prompt = promptBuilder.build(); + } + OllamaResult result = generate(model, prompt, raw, options, null); toolResult.setModelResult(result); @@ -811,6 +826,92 @@ public class OllamaAPI { toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification); } + public void registerAnnotatedTools() { + Class callerClass = null; + try { + callerClass = Class.forName(Thread.currentThread().getStackTrace()[2].getClassName()); + } catch (ClassNotFoundException e) { + throw new RuntimeException(e); + } + + OllamaToolService ollamaToolServiceAnnotation = callerClass.getDeclaredAnnotation(OllamaToolService.class); + if(ollamaToolServiceAnnotation == null) { + throw new IllegalStateException(callerClass + " is not annotated as " + OllamaToolService.class); + } + + Class[] providers = ollamaToolServiceAnnotation.providers(); + + for(Class provider : providers){ + Method[] methods = provider.getMethods(); + for(Method m : methods) { + ToolSpec toolSpec = m.getDeclaredAnnotation(ToolSpec.class); + if(toolSpec == null){ + continue; + } + String operationName = !toolSpec.name().isBlank() ? toolSpec.name() : m.getName(); + String operationDesc = !toolSpec.desc().isBlank() ? toolSpec.desc() : operationName; + + final Tools.PropsBuilder propsBuilder = new Tools.PropsBuilder(); + LinkedHashMap methodParams = new LinkedHashMap<>(); + for (Parameter parameter : m.getParameters()) { + final ToolProperty toolPropertyAnn = parameter.getDeclaredAnnotation(ToolProperty.class); + String propType = parameter.getType().getTypeName(); + if(toolPropertyAnn == null) { + methodParams.put(parameter.getName(),null); + continue; + } + String propName = !toolPropertyAnn.name().isBlank() ? toolPropertyAnn.name() : parameter.getName(); + methodParams.put(propName,propType); + propsBuilder.withProperty(propName,Tools.PromptFuncDefinition.Property.builder() + .type(propType) + .description(toolPropertyAnn.desc()) + .required(toolPropertyAnn.required()) + .build()); + } + final Map params = propsBuilder.build(); + List reqProps = params.entrySet().stream() + .filter(e -> e.getValue().isRequired()) + .map(Map.Entry::getKey) + .collect(Collectors.toList()); + + Tools.ToolSpecification toolSpecification = Tools.ToolSpecification.builder() + .functionName(operationName) + .functionDescription(operationDesc) + .toolPrompt( + Tools.PromptFuncDefinition.builder().type("function").function( + Tools.PromptFuncDefinition.PromptFuncSpec.builder() + .name(operationName) + .description(operationDesc) + .parameters( + Tools.PromptFuncDefinition.Parameters.builder() + .type("object") + .properties( + params + ) + .required(reqProps) + .build() + ).build() + ).build() + ) + .build(); + + try { + ReflectionalToolFunction reflectionalToolFunction = + new ReflectionalToolFunction(provider.getDeclaredConstructor().newInstance() + ,m + ,methodParams); + + toolSpecification.setToolFunction(reflectionalToolFunction); + toolRegistry.addTool(toolSpecification.getFunctionName(),toolSpecification); + } catch (InstantiationException | IllegalAccessException | InvocationTargetException | + NoSuchMethodException e) { + throw new RuntimeException(e); + } + } + } + + } + /** * Adds a custom role. * diff --git a/src/main/java/io/github/ollama4j/tools/ReflectionalToolFunction.java b/src/main/java/io/github/ollama4j/tools/ReflectionalToolFunction.java new file mode 100644 index 0000000..66d078b --- /dev/null +++ b/src/main/java/io/github/ollama4j/tools/ReflectionalToolFunction.java @@ -0,0 +1,54 @@ +package io.github.ollama4j.tools; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.Setter; + +import java.lang.reflect.Method; +import java.math.BigDecimal; +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * Specification of a {@link ToolFunction} that provides the implementation via java reflection calling. + */ +@Setter +@Getter +@AllArgsConstructor +public class ReflectionalToolFunction implements ToolFunction{ + + private Object functionHolder; + private Method function; + private LinkedHashMap propertyDefinition; + + @Override + public Object apply(Map arguments) { + LinkedHashMap argumentsCopy = new LinkedHashMap<>(this.propertyDefinition); + for (Map.Entry param : this.propertyDefinition.entrySet()){ + argumentsCopy.replace(param.getKey(),typeCast(arguments.get(param.getKey()),param.getValue())); + } + try { + return function.invoke(functionHolder, argumentsCopy.values().toArray()); + } catch (Exception e) { + throw new RuntimeException("Failed to invoke tool: " + function.getName(), e); + } + } + + private Object typeCast(Object inputValue, String className) { + if(className == null || inputValue == null) { + return null; + } + String inputValueString = inputValue.toString(); + switch (className) { + case "java.lang.Integer": + return Integer.parseInt(inputValueString); + case "java.lang.Boolean": + return Boolean.valueOf(inputValueString); + case "java.math.BigDecimal": + return new BigDecimal(inputValueString); + default: + return inputValueString; + } + } + +} diff --git a/src/main/java/io/github/ollama4j/tools/annotations/OllamaToolService.java b/src/main/java/io/github/ollama4j/tools/annotations/OllamaToolService.java new file mode 100644 index 0000000..5118430 --- /dev/null +++ b/src/main/java/io/github/ollama4j/tools/annotations/OllamaToolService.java @@ -0,0 +1,23 @@ +package io.github.ollama4j.tools.annotations; + +import io.github.ollama4j.OllamaAPI; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Annotates a class that calls {@link io.github.ollama4j.OllamaAPI} such that the Method + * {@link OllamaAPI#registerAnnotatedTools()} can be used to auto-register all provided classes (resp. all + * contained Methods of the provider classes annotated with {@link ToolSpec}). + */ +@Target(ElementType.TYPE) +@Retention(RetentionPolicy.RUNTIME) +public @interface OllamaToolService { + + /** + * @return Classes with no-arg constructor that will be used for tool-registration. + */ + Class[] providers(); +} diff --git a/src/main/java/io/github/ollama4j/tools/annotations/ToolProperty.java b/src/main/java/io/github/ollama4j/tools/annotations/ToolProperty.java new file mode 100644 index 0000000..28d9acc --- /dev/null +++ b/src/main/java/io/github/ollama4j/tools/annotations/ToolProperty.java @@ -0,0 +1,32 @@ +package io.github.ollama4j.tools.annotations; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Annotates a Method Parameter in a {@link ToolSpec} annotated Method. A parameter annotated with this annotation will + * be part of the tool description that is sent to the llm for tool-calling. + */ +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.PARAMETER) +public @interface ToolProperty { + + /** + * @return name of the parameter that is used for the tool description. Has to be set as depending on the caller, + * method name backtracking is not possible with reflection. + */ + String name(); + + /** + * @return a detailed description of the parameter. This is used by the llm called to specify, which property has + * to be set by the llm and how this should be filled. + */ + String desc(); + + /** + * @return tells the llm that it has to set a value for this property. + */ + boolean required() default true; +} diff --git a/src/main/java/io/github/ollama4j/tools/annotations/ToolSpec.java b/src/main/java/io/github/ollama4j/tools/annotations/ToolSpec.java new file mode 100644 index 0000000..7f99768 --- /dev/null +++ b/src/main/java/io/github/ollama4j/tools/annotations/ToolSpec.java @@ -0,0 +1,28 @@ +package io.github.ollama4j.tools.annotations; + +import io.github.ollama4j.OllamaAPI; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Annotates Methods of classes that should be registered as tools by {@link OllamaAPI#registerAnnotatedTools()} + * automatically. + */ +@Target(ElementType.METHOD) +@Retention(RetentionPolicy.RUNTIME) +public @interface ToolSpec { + + /** + * @return tool-name that the method should be used as. Defaults to the methods name. + */ + String name() default ""; + + /** + * @return a detailed description of the method that can be interpreted by the llm, whether it should call the tool + * or not. + */ + String desc(); +} diff --git a/src/test/java/io/github/ollama4j/integrationtests/TestRealAPIs.java b/src/test/java/io/github/ollama4j/integrationtests/TestRealAPIs.java index 668a5dc..9ddfcd5 100644 --- a/src/test/java/io/github/ollama4j/integrationtests/TestRealAPIs.java +++ b/src/test/java/io/github/ollama4j/integrationtests/TestRealAPIs.java @@ -7,8 +7,11 @@ import io.github.ollama4j.models.response.ModelDetail; import io.github.ollama4j.models.response.OllamaResult; import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestBuilder; import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel; +import io.github.ollama4j.samples.AnnotatedTool; +import io.github.ollama4j.tools.OllamaToolCallsFunction; import io.github.ollama4j.tools.ToolFunction; import io.github.ollama4j.tools.Tools; +import io.github.ollama4j.tools.annotations.OllamaToolService; import io.github.ollama4j.utils.OptionsBuilder; import lombok.Data; import org.junit.jupiter.api.BeforeEach; @@ -27,6 +30,8 @@ import java.util.*; import static org.junit.jupiter.api.Assertions.*; +@OllamaToolService(providers = {AnnotatedTool.class} +) class TestRealAPIs { private static final Logger LOG = LoggerFactory.getLogger(TestRealAPIs.class); @@ -229,7 +234,7 @@ class TestRealAPIs { @Test @Order(3) - void testChatWithTools() { + void testChatWithExplicitToolDefinition() { testEndpointReachability(); try { ollamaAPI.setVerbose(true); @@ -275,9 +280,10 @@ class TestRealAPIs { assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(),chatResult.getResponseModel().getMessage().getRole().getRoleName()); List toolCalls = chatResult.getChatHistory().get(1).getToolCalls(); assertEquals(1, toolCalls.size()); - assertEquals("get-employee-details",toolCalls.get(0).getFunction().getName()); - assertEquals(1, toolCalls.get(0).getFunction().getArguments().size()); - Object employeeName = toolCalls.get(0).getFunction().getArguments().get("employee-name"); + OllamaToolCallsFunction function = toolCalls.get(0).getFunction(); + assertEquals("get-employee-details", function.getName()); + assertEquals(1, function.getArguments().size()); + Object employeeName = function.getArguments().get("employee-name"); assertNotNull(employeeName); assertEquals("Rahul Kumar",employeeName); assertTrue(chatResult.getChatHistory().size()>2); @@ -288,6 +294,82 @@ class TestRealAPIs { } } + @Test + @Order(3) + void testChatWithAnnotatedToolsAndSingleParam() { + testEndpointReachability(); + try { + ollamaAPI.setVerbose(true); + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel()); + + ollamaAPI.registerAnnotatedTools(); + + OllamaChatRequest requestModel = builder + .withMessage(OllamaChatMessageRole.USER, + "Compute the most important constant in the world using 5 digits") + .build(); + + OllamaChatResult chatResult = ollamaAPI.chat(requestModel); + assertNotNull(chatResult); + assertNotNull(chatResult.getResponseModel()); + assertNotNull(chatResult.getResponseModel().getMessage()); + assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(),chatResult.getResponseModel().getMessage().getRole().getRoleName()); + List toolCalls = chatResult.getChatHistory().get(1).getToolCalls(); + assertEquals(1, toolCalls.size()); + OllamaToolCallsFunction function = toolCalls.get(0).getFunction(); + assertEquals("computeImportantConstant", function.getName()); + assertEquals(1, function.getArguments().size()); + Object noOfDigits = function.getArguments().get("noOfDigits"); + assertNotNull(noOfDigits); + assertEquals("5",noOfDigits); + assertTrue(chatResult.getChatHistory().size()>2); + List finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls(); + assertNull(finalToolCalls); + } catch (IOException | OllamaBaseException | InterruptedException e) { + fail(e); + } + } + + @Test + @Order(3) + void testChatWithAnnotatedToolsAndMultipleParams() { + testEndpointReachability(); + try { + ollamaAPI.setVerbose(true); + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel()); + + ollamaAPI.registerAnnotatedTools(); + + OllamaChatRequest requestModel = builder + .withMessage(OllamaChatMessageRole.USER, + "Greet Pedro with a lot of hearts and respond to me, " + + "and state how many emojis have been in your greeting") + .build(); + + OllamaChatResult chatResult = ollamaAPI.chat(requestModel); + assertNotNull(chatResult); + assertNotNull(chatResult.getResponseModel()); + assertNotNull(chatResult.getResponseModel().getMessage()); + assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(),chatResult.getResponseModel().getMessage().getRole().getRoleName()); + List toolCalls = chatResult.getChatHistory().get(1).getToolCalls(); + assertEquals(1, toolCalls.size()); + OllamaToolCallsFunction function = toolCalls.get(0).getFunction(); + assertEquals("sayHello", function.getName()); + assertEquals(2, function.getArguments().size()); + Object name = function.getArguments().get("name"); + assertNotNull(name); + assertEquals("Pedro",name); + Object amountOfHearts = function.getArguments().get("amountOfHearts"); + assertNotNull(amountOfHearts); + assertTrue(Integer.parseInt(amountOfHearts.toString()) > 1); + assertTrue(chatResult.getChatHistory().size()>2); + List finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls(); + assertNull(finalToolCalls); + } catch (IOException | OllamaBaseException | InterruptedException e) { + fail(e); + } + } + @Test @Order(3) void testChatWithToolsAndStream() { diff --git a/src/test/java/io/github/ollama4j/samples/AnnotatedTool.java b/src/test/java/io/github/ollama4j/samples/AnnotatedTool.java new file mode 100644 index 0000000..8202e77 --- /dev/null +++ b/src/test/java/io/github/ollama4j/samples/AnnotatedTool.java @@ -0,0 +1,21 @@ +package io.github.ollama4j.samples; + +import io.github.ollama4j.tools.annotations.ToolProperty; +import io.github.ollama4j.tools.annotations.ToolSpec; + +import java.math.BigDecimal; + +public class AnnotatedTool { + + @ToolSpec(desc = "Computes the most important constant all around the globe!") + public String computeImportantConstant(@ToolProperty(name = "noOfDigits",desc = "Number of digits that shall be returned") Integer noOfDigits ){ + return BigDecimal.valueOf((long)(Math.random()*1000000L),noOfDigits).toString(); + } + + @ToolSpec(desc = "Says hello to a friend!") + public String sayHello(@ToolProperty(name = "name",desc = "Name of the friend") String name, Integer someRandomProperty, @ToolProperty(name="amountOfHearts",desc = "amount of heart emojis that should be used", required = false) Integer amountOfHearts) { + String hearts = amountOfHearts!=null ? "♡".repeat(amountOfHearts) : ""; + return "Hello " + name +" ("+someRandomProperty+") " + hearts; + } + +}