diff --git a/docs/docs/apis-generate/generate-with-tools.md b/docs/docs/apis-generate/generate-with-tools.md index db50c88..a40969f 100644 --- a/docs/docs/apis-generate/generate-with-tools.md +++ b/docs/docs/apis-generate/generate-with-tools.md @@ -521,6 +521,23 @@ public class MyOllamaService{ } ``` +Or, if one needs to provide an object instance directly: +```java +public class MyOllamaService{ + + public void chatWithAnnotatedTool(){ + ollamaAPI.registerAnnotatedTools(new BackendService()); + OllamaChatRequest requestModel = builder + .withMessage(OllamaChatMessageRole.USER, + "Compute the most important constant in the world using 5 digits") + .build(); + + OllamaChatResult chatResult = ollamaAPI.chat(requestModel); + } + +} +``` + The request should be the following: ```json @@ -622,4 +639,4 @@ public String getCurrentFuelPrice(String location, String fuelType) { } ``` -Updating async/chat APIs with support for tool-based generation. \ No newline at end of file +Updating async/chat APIs with support for tool-based generation. diff --git a/src/main/java/io/github/ollama4j/OllamaAPI.java b/src/main/java/io/github/ollama4j/OllamaAPI.java index 29cb467..cbde59e 100644 --- a/src/main/java/io/github/ollama4j/OllamaAPI.java +++ b/src/main/java/io/github/ollama4j/OllamaAPI.java @@ -826,88 +826,89 @@ public class OllamaAPI { toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification); } + public void registerAnnotatedTools() { - Class callerClass = null; try { - callerClass = Class.forName(Thread.currentThread().getStackTrace()[2].getClassName()); - } catch (ClassNotFoundException e) { + Class callerClass = null; + try { + callerClass = Class.forName(Thread.currentThread().getStackTrace()[2].getClassName()); + } catch (ClassNotFoundException e) { + throw new RuntimeException(e); + } + + OllamaToolService ollamaToolServiceAnnotation = callerClass.getDeclaredAnnotation(OllamaToolService.class); + if (ollamaToolServiceAnnotation == null) { + throw new IllegalStateException(callerClass + " is not annotated as " + OllamaToolService.class); + } + + Class[] providers = ollamaToolServiceAnnotation.providers(); + for (Class provider : providers) { + registerAnnotatedTools(provider.getDeclaredConstructor().newInstance()); + } + } catch (InstantiationException | NoSuchMethodException | IllegalAccessException | InvocationTargetException e) { throw new RuntimeException(e); } + } - OllamaToolService ollamaToolServiceAnnotation = callerClass.getDeclaredAnnotation(OllamaToolService.class); - if(ollamaToolServiceAnnotation == null) { - throw new IllegalStateException(callerClass + " is not annotated as " + OllamaToolService.class); - } + public void registerAnnotatedTools(Object object) { + Class objectClass = object.getClass(); + Method[] methods = objectClass.getMethods(); + for(Method m : methods) { + ToolSpec toolSpec = m.getDeclaredAnnotation(ToolSpec.class); + if(toolSpec == null){ + continue; + } + String operationName = !toolSpec.name().isBlank() ? toolSpec.name() : m.getName(); + String operationDesc = !toolSpec.desc().isBlank() ? toolSpec.desc() : operationName; - Class[] providers = ollamaToolServiceAnnotation.providers(); - - for(Class provider : providers){ - Method[] methods = provider.getMethods(); - for(Method m : methods) { - ToolSpec toolSpec = m.getDeclaredAnnotation(ToolSpec.class); - if(toolSpec == null){ + final Tools.PropsBuilder propsBuilder = new Tools.PropsBuilder(); + LinkedHashMap methodParams = new LinkedHashMap<>(); + for (Parameter parameter : m.getParameters()) { + final ToolProperty toolPropertyAnn = parameter.getDeclaredAnnotation(ToolProperty.class); + String propType = parameter.getType().getTypeName(); + if(toolPropertyAnn == null) { + methodParams.put(parameter.getName(),null); continue; } - String operationName = !toolSpec.name().isBlank() ? toolSpec.name() : m.getName(); - String operationDesc = !toolSpec.desc().isBlank() ? toolSpec.desc() : operationName; - - final Tools.PropsBuilder propsBuilder = new Tools.PropsBuilder(); - LinkedHashMap methodParams = new LinkedHashMap<>(); - for (Parameter parameter : m.getParameters()) { - final ToolProperty toolPropertyAnn = parameter.getDeclaredAnnotation(ToolProperty.class); - String propType = parameter.getType().getTypeName(); - if(toolPropertyAnn == null) { - methodParams.put(parameter.getName(),null); - continue; - } - String propName = !toolPropertyAnn.name().isBlank() ? toolPropertyAnn.name() : parameter.getName(); - methodParams.put(propName,propType); - propsBuilder.withProperty(propName,Tools.PromptFuncDefinition.Property.builder() - .type(propType) - .description(toolPropertyAnn.desc()) - .required(toolPropertyAnn.required()) - .build()); - } - final Map params = propsBuilder.build(); - List reqProps = params.entrySet().stream() - .filter(e -> e.getValue().isRequired()) - .map(Map.Entry::getKey) - .collect(Collectors.toList()); - - Tools.ToolSpecification toolSpecification = Tools.ToolSpecification.builder() - .functionName(operationName) - .functionDescription(operationDesc) - .toolPrompt( - Tools.PromptFuncDefinition.builder().type("function").function( - Tools.PromptFuncDefinition.PromptFuncSpec.builder() - .name(operationName) - .description(operationDesc) - .parameters( - Tools.PromptFuncDefinition.Parameters.builder() - .type("object") - .properties( - params - ) - .required(reqProps) - .build() - ).build() - ).build() - ) - .build(); - - try { - ReflectionalToolFunction reflectionalToolFunction = - new ReflectionalToolFunction(provider.getDeclaredConstructor().newInstance() - ,m - ,methodParams); - - toolSpecification.setToolFunction(reflectionalToolFunction); - toolRegistry.addTool(toolSpecification.getFunctionName(),toolSpecification); - } catch (InstantiationException | IllegalAccessException | InvocationTargetException | - NoSuchMethodException e) { - throw new RuntimeException(e); - } + String propName = !toolPropertyAnn.name().isBlank() ? toolPropertyAnn.name() : parameter.getName(); + methodParams.put(propName,propType); + propsBuilder.withProperty(propName,Tools.PromptFuncDefinition.Property.builder() + .type(propType) + .description(toolPropertyAnn.desc()) + .required(toolPropertyAnn.required()) + .build()); } + final Map params = propsBuilder.build(); + List reqProps = params.entrySet().stream() + .filter(e -> e.getValue().isRequired()) + .map(Map.Entry::getKey) + .collect(Collectors.toList()); + + Tools.ToolSpecification toolSpecification = Tools.ToolSpecification.builder() + .functionName(operationName) + .functionDescription(operationDesc) + .toolPrompt( + Tools.PromptFuncDefinition.builder().type("function").function( + Tools.PromptFuncDefinition.PromptFuncSpec.builder() + .name(operationName) + .description(operationDesc) + .parameters( + Tools.PromptFuncDefinition.Parameters.builder() + .type("object") + .properties( + params + ) + .required(reqProps) + .build() + ).build() + ).build() + ) + .build(); + + ReflectionalToolFunction reflectionalToolFunction = + new ReflectionalToolFunction(object, m, methodParams); + toolSpecification.setToolFunction(reflectionalToolFunction); + toolRegistry.addTool(toolSpecification.getFunctionName(),toolSpecification); } } diff --git a/src/test/java/io/github/ollama4j/integrationtests/TestRealAPIs.java b/src/test/java/io/github/ollama4j/integrationtests/TestRealAPIs.java index 9ddfcd5..835fa76 100644 --- a/src/test/java/io/github/ollama4j/integrationtests/TestRealAPIs.java +++ b/src/test/java/io/github/ollama4j/integrationtests/TestRealAPIs.java @@ -338,7 +338,7 @@ class TestRealAPIs { ollamaAPI.setVerbose(true); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel()); - ollamaAPI.registerAnnotatedTools(); + ollamaAPI.registerAnnotatedTools(new AnnotatedTool()); OllamaChatRequest requestModel = builder .withMessage(OllamaChatMessageRole.USER,