Add support for registering object instances instead of only through the @OllamaToolService annotation

This commit is contained in:
Sebastiaan de Schaetzen 2025-01-24 13:38:47 +01:00
parent f27bea11d5
commit b2b3febdaa
3 changed files with 93 additions and 75 deletions

View File

@ -521,6 +521,23 @@ public class MyOllamaService{
} }
``` ```
Or, if one needs to provide an object instance directly:
```java
public class MyOllamaService{
public void chatWithAnnotatedTool(){
ollamaAPI.registerAnnotatedTools(new BackendService());
OllamaChatRequest requestModel = builder
.withMessage(OllamaChatMessageRole.USER,
"Compute the most important constant in the world using 5 digits")
.build();
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
}
}
```
The request should be the following: The request should be the following:
```json ```json

View File

@ -826,88 +826,89 @@ public class OllamaAPI {
toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification); toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification);
} }
public void registerAnnotatedTools() { public void registerAnnotatedTools() {
Class<?> callerClass = null;
try { try {
callerClass = Class.forName(Thread.currentThread().getStackTrace()[2].getClassName()); Class<?> callerClass = null;
} catch (ClassNotFoundException e) { try {
callerClass = Class.forName(Thread.currentThread().getStackTrace()[2].getClassName());
} catch (ClassNotFoundException e) {
throw new RuntimeException(e);
}
OllamaToolService ollamaToolServiceAnnotation = callerClass.getDeclaredAnnotation(OllamaToolService.class);
if (ollamaToolServiceAnnotation == null) {
throw new IllegalStateException(callerClass + " is not annotated as " + OllamaToolService.class);
}
Class<?>[] providers = ollamaToolServiceAnnotation.providers();
for (Class<?> provider : providers) {
registerAnnotatedTools(provider.getDeclaredConstructor().newInstance());
}
} catch (InstantiationException | NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
}
OllamaToolService ollamaToolServiceAnnotation = callerClass.getDeclaredAnnotation(OllamaToolService.class); public void registerAnnotatedTools(Object object) {
if(ollamaToolServiceAnnotation == null) { Class<?> objectClass = object.getClass();
throw new IllegalStateException(callerClass + " is not annotated as " + OllamaToolService.class); Method[] methods = objectClass.getMethods();
} for(Method m : methods) {
ToolSpec toolSpec = m.getDeclaredAnnotation(ToolSpec.class);
if(toolSpec == null){
continue;
}
String operationName = !toolSpec.name().isBlank() ? toolSpec.name() : m.getName();
String operationDesc = !toolSpec.desc().isBlank() ? toolSpec.desc() : operationName;
Class<?>[] providers = ollamaToolServiceAnnotation.providers(); final Tools.PropsBuilder propsBuilder = new Tools.PropsBuilder();
LinkedHashMap<String,String> methodParams = new LinkedHashMap<>();
for(Class<?> provider : providers){ for (Parameter parameter : m.getParameters()) {
Method[] methods = provider.getMethods(); final ToolProperty toolPropertyAnn = parameter.getDeclaredAnnotation(ToolProperty.class);
for(Method m : methods) { String propType = parameter.getType().getTypeName();
ToolSpec toolSpec = m.getDeclaredAnnotation(ToolSpec.class); if(toolPropertyAnn == null) {
if(toolSpec == null){ methodParams.put(parameter.getName(),null);
continue; continue;
} }
String operationName = !toolSpec.name().isBlank() ? toolSpec.name() : m.getName(); String propName = !toolPropertyAnn.name().isBlank() ? toolPropertyAnn.name() : parameter.getName();
String operationDesc = !toolSpec.desc().isBlank() ? toolSpec.desc() : operationName; methodParams.put(propName,propType);
propsBuilder.withProperty(propName,Tools.PromptFuncDefinition.Property.builder()
final Tools.PropsBuilder propsBuilder = new Tools.PropsBuilder(); .type(propType)
LinkedHashMap<String,String> methodParams = new LinkedHashMap<>(); .description(toolPropertyAnn.desc())
for (Parameter parameter : m.getParameters()) { .required(toolPropertyAnn.required())
final ToolProperty toolPropertyAnn = parameter.getDeclaredAnnotation(ToolProperty.class); .build());
String propType = parameter.getType().getTypeName();
if(toolPropertyAnn == null) {
methodParams.put(parameter.getName(),null);
continue;
}
String propName = !toolPropertyAnn.name().isBlank() ? toolPropertyAnn.name() : parameter.getName();
methodParams.put(propName,propType);
propsBuilder.withProperty(propName,Tools.PromptFuncDefinition.Property.builder()
.type(propType)
.description(toolPropertyAnn.desc())
.required(toolPropertyAnn.required())
.build());
}
final Map<String, Tools.PromptFuncDefinition.Property> params = propsBuilder.build();
List<String> reqProps = params.entrySet().stream()
.filter(e -> e.getValue().isRequired())
.map(Map.Entry::getKey)
.collect(Collectors.toList());
Tools.ToolSpecification toolSpecification = Tools.ToolSpecification.builder()
.functionName(operationName)
.functionDescription(operationDesc)
.toolPrompt(
Tools.PromptFuncDefinition.builder().type("function").function(
Tools.PromptFuncDefinition.PromptFuncSpec.builder()
.name(operationName)
.description(operationDesc)
.parameters(
Tools.PromptFuncDefinition.Parameters.builder()
.type("object")
.properties(
params
)
.required(reqProps)
.build()
).build()
).build()
)
.build();
try {
ReflectionalToolFunction reflectionalToolFunction =
new ReflectionalToolFunction(provider.getDeclaredConstructor().newInstance()
,m
,methodParams);
toolSpecification.setToolFunction(reflectionalToolFunction);
toolRegistry.addTool(toolSpecification.getFunctionName(),toolSpecification);
} catch (InstantiationException | IllegalAccessException | InvocationTargetException |
NoSuchMethodException e) {
throw new RuntimeException(e);
}
} }
final Map<String, Tools.PromptFuncDefinition.Property> params = propsBuilder.build();
List<String> reqProps = params.entrySet().stream()
.filter(e -> e.getValue().isRequired())
.map(Map.Entry::getKey)
.collect(Collectors.toList());
Tools.ToolSpecification toolSpecification = Tools.ToolSpecification.builder()
.functionName(operationName)
.functionDescription(operationDesc)
.toolPrompt(
Tools.PromptFuncDefinition.builder().type("function").function(
Tools.PromptFuncDefinition.PromptFuncSpec.builder()
.name(operationName)
.description(operationDesc)
.parameters(
Tools.PromptFuncDefinition.Parameters.builder()
.type("object")
.properties(
params
)
.required(reqProps)
.build()
).build()
).build()
)
.build();
ReflectionalToolFunction reflectionalToolFunction =
new ReflectionalToolFunction(object, m, methodParams);
toolSpecification.setToolFunction(reflectionalToolFunction);
toolRegistry.addTool(toolSpecification.getFunctionName(),toolSpecification);
} }
} }

View File

@ -338,7 +338,7 @@ class TestRealAPIs {
ollamaAPI.setVerbose(true); ollamaAPI.setVerbose(true);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel()); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
ollamaAPI.registerAnnotatedTools(); ollamaAPI.registerAnnotatedTools(new AnnotatedTool());
OllamaChatRequest requestModel = builder OllamaChatRequest requestModel = builder
.withMessage(OllamaChatMessageRole.USER, .withMessage(OllamaChatMessageRole.USER,