Adds first approach to annotation based tool callings using basic java reflection

This commit is contained in:
Markus Klenke
2024-12-27 22:20:34 +01:00
parent 35f5f34196
commit 5e6971cc4a
7 changed files with 304 additions and 4 deletions

View File

@@ -15,11 +15,17 @@ import io.github.ollama4j.models.ps.ModelsProcessResponse;
import io.github.ollama4j.models.request.*;
import io.github.ollama4j.models.response.*;
import io.github.ollama4j.tools.*;
import io.github.ollama4j.tools.annotations.OllamaToolService;
import io.github.ollama4j.tools.annotations.ToolProperty;
import io.github.ollama4j.tools.annotations.ToolSpec;
import io.github.ollama4j.utils.Options;
import io.github.ollama4j.utils.Utils;
import lombok.Setter;
import java.io.*;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.http.HttpClient;
@@ -603,6 +609,15 @@ public class OllamaAPI {
OllamaToolsResult toolResult = new OllamaToolsResult();
Map<ToolFunctionCallSpec, Object> toolResults = new HashMap<>();
if(!prompt.startsWith("[AVAILABLE_TOOLS]")){
final Tools.PromptBuilder promptBuilder = new Tools.PromptBuilder();
for(Tools.ToolSpecification spec : toolRegistry.getRegisteredSpecs()) {
promptBuilder.withToolSpecification(spec);
}
promptBuilder.withPrompt(prompt);
prompt = promptBuilder.build();
}
OllamaResult result = generate(model, prompt, raw, options, null);
toolResult.setModelResult(result);
@@ -811,6 +826,94 @@ public class OllamaAPI {
toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification);
}
public void registerAnnotatedTools() {
Class<?> callerClass = null;
try {
callerClass = Class.forName(Thread.currentThread().getStackTrace()[2].getClassName());
} catch (ClassNotFoundException e) {
throw new RuntimeException(e);
}
OllamaToolService ollamaToolServiceAnnotation = callerClass.getDeclaredAnnotation(OllamaToolService.class);
if(ollamaToolServiceAnnotation == null) {
throw new IllegalStateException(callerClass + " is not annotated as " + OllamaToolService.class);
}
Class<?>[] providers = ollamaToolServiceAnnotation.providers();
for(Class<?> provider : providers){
System.err.println("Provider: " + provider.getName());
Method[] methods = provider.getMethods();
for(Method m : methods) {
ToolSpec toolSpec = m.getDeclaredAnnotation(ToolSpec.class);
if(toolSpec == null){
continue;
}
String operationName = !toolSpec.name().isBlank() ? toolSpec.name() : m.getName();
String operationDesc = !toolSpec.desc().isBlank() ? toolSpec.desc() : operationName;
System.err.println("Method: " + operationName);
final Tools.PropsBuilder propsBuilder = new Tools.PropsBuilder();
LinkedHashMap<String,String> methodParams = new LinkedHashMap<>();
for (Parameter parameter : m.getParameters()) {
final ToolProperty toolPropertyAnn = parameter.getDeclaredAnnotation(ToolProperty.class);
String propType = parameter.getType().getTypeName();
if(toolPropertyAnn == null) {
methodParams.put(parameter.getName(),null);
continue;
}
String propName = !toolPropertyAnn.name().isBlank() ? toolPropertyAnn.name() : parameter.getName();
methodParams.put(propName,propType);
propsBuilder.withProperty(propName,Tools.PromptFuncDefinition.Property.builder()
.type(propType)
.description(toolPropertyAnn.desc())
.required(toolPropertyAnn.required())
.build());
}
final Map<String, Tools.PromptFuncDefinition.Property> params = propsBuilder.build();
List<String> reqProps = params.entrySet().stream()
.filter(e -> e.getValue().isRequired())
.map(Map.Entry::getKey)
.collect(Collectors.toList());
Tools.ToolSpecification toolSpecification = Tools.ToolSpecification.builder()
.functionName(operationName)
.functionDescription(operationDesc)
.toolPrompt(
Tools.PromptFuncDefinition.builder().type("function").function(
Tools.PromptFuncDefinition.PromptFuncSpec.builder()
.name(operationName)
.description(operationDesc)
.parameters(
Tools.PromptFuncDefinition.Parameters.builder()
.type("object")
.properties(
params
)
.required(reqProps)
.build()
).build()
).build()
)
.build();
try {
ReflectionalToolFunction reflectionalToolFunction =
new ReflectionalToolFunction(provider.getDeclaredConstructor().newInstance()
,m
,methodParams);
toolSpecification.setToolFunction(reflectionalToolFunction);
toolRegistry.addTool(toolSpecification.getFunctionName(),toolSpecification);
} catch (InstantiationException | IllegalAccessException | InvocationTargetException |
NoSuchMethodException e) {
throw new RuntimeException(e);
}
}
}
}
/**
* Adds a custom role.
*

View File

@@ -0,0 +1,49 @@
package io.github.ollama4j.tools;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.Setter;
import java.lang.reflect.Method;
import java.util.LinkedHashMap;
import java.util.Map;
@Setter
@Getter
@AllArgsConstructor
public class ReflectionalToolFunction implements ToolFunction{
private Object functionHolder;
private Method function;
private LinkedHashMap<String,String> propertyDefinition;
@Override
public Object apply(Map<String, Object> arguments) {
LinkedHashMap<String, Object> argumentsCopy = new LinkedHashMap<>(this.propertyDefinition);
for (Map.Entry<String,String> param : this.propertyDefinition.entrySet()){
argumentsCopy.replace(param.getKey(),typeCast(arguments.get(param.getKey()),param.getValue()));
}
try {
return function.invoke(functionHolder, argumentsCopy.values().toArray());
} catch (Exception e) {
throw new RuntimeException("Failed to invoke tool: " + function.getName(), e);
}
}
private Object typeCast(Object inputValue, String className) {
if(className == null || inputValue == null) {
return null;
}
String inputValueString = inputValue.toString();
if("java.lang.Integer".equals(className)){
return Integer.parseInt(inputValueString);
}
if("java.lang.Boolean".equals(className)){
return Boolean.valueOf(inputValueString);
}
else {
return inputValueString;
}
}
}

View File

@@ -0,0 +1,13 @@
package io.github.ollama4j.tools.annotations;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
public @interface OllamaToolService {
Class<?>[] providers();
}

View File

@@ -0,0 +1,17 @@
package io.github.ollama4j.tools.annotations;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.PARAMETER)
public @interface ToolProperty {
String name();
String desc();
boolean required() default true;
}

View File

@@ -0,0 +1,15 @@
package io.github.ollama4j.tools.annotations;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface ToolSpec {
String name() default "";
String desc();
}