mirror of
https://github.com/amithkoujalgi/ollama4j.git
synced 2025-05-15 11:57:12 +02:00
Merge pull request #85 from AgentSchmecker/feature/annotationBasedTools
Feature/annotation based tools
This commit is contained in:
commit
a494053263
@ -464,21 +464,155 @@ A typical final response of the above could be:
|
|||||||
|
|
||||||
This tool calling can also be done using the streaming API.
|
This tool calling can also be done using the streaming API.
|
||||||
|
|
||||||
### Potential Improvements
|
### Using Annotation based Tool Registration
|
||||||
|
|
||||||
Instead of explicitly registering `ollamaAPI.registerTool(toolSpecification)`, we could introduce annotation-based tool
|
Instead of explicitly registering each tool, ollama4j supports declarative tool specification and registration via java
|
||||||
registration. For example:
|
Annotations and reflection calling.
|
||||||
|
|
||||||
|
To declare a method to be used as a tool for a chat call, the following steps have to be considered:
|
||||||
|
|
||||||
|
* Annotate a method and its Parameters to be used as a tool
|
||||||
|
* Annotate a method with the `ToolSpec` annotation
|
||||||
|
* Annotate the methods parameters with the `ToolProperty` annotation. Only the following datatypes are supported for now:
|
||||||
|
* `java.lang.String`
|
||||||
|
* `java.lang.Integer`
|
||||||
|
* `java.lang.Boolean`
|
||||||
|
* `java.math.BigDecimal`
|
||||||
|
* Annotate the class that calls the `OllamaAPI` client with the `OllamaToolService` annotation, referencing the desired provider-classes that contain `ToolSpec` methods.
|
||||||
|
* Before calling the `OllamaAPI` chat request, call the method `OllamaAPI.registerAnnotatedTools()` method to add tools to the chat.
|
||||||
|
|
||||||
|
#### Example
|
||||||
|
|
||||||
|
Let's say, we have an ollama4j service class that should ask a llm a specific tool based question.
|
||||||
|
|
||||||
|
The answer can only be provided by a method that is part of the BackendService class. To provide a tool for the llm, the following annotations can be used:
|
||||||
|
|
||||||
```java
|
```java
|
||||||
|
public class BackendService{
|
||||||
|
|
||||||
@ToolSpec(name = "current-fuel-price", desc = "Get current fuel price")
|
public BackendService(){}
|
||||||
public String getCurrentFuelPrice(Map<String, Object> arguments) {
|
|
||||||
String location = arguments.get("location").toString();
|
@ToolSpec(desc = "Computes the most important constant all around the globe!")
|
||||||
String fuelType = arguments.get("fuelType").toString();
|
public String computeMkeConstant(@ToolProperty(name = "noOfDigits",desc = "Number of digits that shall be returned") Integer noOfDigits ){
|
||||||
return "Current price of " + fuelType + " in " + location + " is Rs.103/L";
|
return BigDecimal.valueOf((long)(Math.random()*1000000L),noOfDigits).toString();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
The caller API can then be written as:
|
||||||
|
```java
|
||||||
|
import io.github.ollama4j.tools.annotations.OllamaToolService;
|
||||||
|
|
||||||
|
@OllamaToolService(providers = BackendService.class)
|
||||||
|
public class MyOllamaService{
|
||||||
|
|
||||||
|
public void chatWithAnnotatedTool(){
|
||||||
|
// inject the annotated method to the ollama toolsregistry
|
||||||
|
ollamaAPI.registerAnnotatedTools();
|
||||||
|
|
||||||
|
OllamaChatRequest requestModel = builder
|
||||||
|
.withMessage(OllamaChatMessageRole.USER,
|
||||||
|
"Compute the most important constant in the world using 5 digits")
|
||||||
|
.build();
|
||||||
|
|
||||||
|
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
The request should be the following:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"model" : "llama3.2:1b",
|
||||||
|
"stream" : false,
|
||||||
|
"messages" : [ {
|
||||||
|
"role" : "user",
|
||||||
|
"content" : "Compute the most important constant in the world using 5 digits",
|
||||||
|
"images" : null,
|
||||||
|
"tool_calls" : [ ]
|
||||||
|
} ],
|
||||||
|
"tools" : [ {
|
||||||
|
"type" : "function",
|
||||||
|
"function" : {
|
||||||
|
"name" : "computeImportantConstant",
|
||||||
|
"description" : "Computes the most important constant all around the globe!",
|
||||||
|
"parameters" : {
|
||||||
|
"type" : "object",
|
||||||
|
"properties" : {
|
||||||
|
"noOfDigits" : {
|
||||||
|
"type" : "java.lang.Integer",
|
||||||
|
"description" : "Number of digits that shall be returned"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required" : [ "noOfDigits" ]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} ]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
The result could be something like the following:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"chatHistory" : [ {
|
||||||
|
"role" : "user",
|
||||||
|
"content" : "Compute the most important constant in the world using 5 digits",
|
||||||
|
"images" : null,
|
||||||
|
"tool_calls" : [ ]
|
||||||
|
}, {
|
||||||
|
"role" : "assistant",
|
||||||
|
"content" : "",
|
||||||
|
"images" : null,
|
||||||
|
"tool_calls" : [ {
|
||||||
|
"function" : {
|
||||||
|
"name" : "computeImportantConstant",
|
||||||
|
"arguments" : {
|
||||||
|
"noOfDigits" : "5"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} ]
|
||||||
|
}, {
|
||||||
|
"role" : "tool",
|
||||||
|
"content" : "[TOOL_RESULTS]computeImportantConstant([noOfDigits]) : 1.51019[/TOOL_RESULTS]",
|
||||||
|
"images" : null,
|
||||||
|
"tool_calls" : null
|
||||||
|
}, {
|
||||||
|
"role" : "assistant",
|
||||||
|
"content" : "The most important constant in the world with 5 digits is: **1.51019**",
|
||||||
|
"images" : null,
|
||||||
|
"tool_calls" : null
|
||||||
|
} ],
|
||||||
|
"responseModel" : {
|
||||||
|
"model" : "llama3.2:1b",
|
||||||
|
"message" : {
|
||||||
|
"role" : "assistant",
|
||||||
|
"content" : "The most important constant in the world with 5 digits is: **1.51019**",
|
||||||
|
"images" : null,
|
||||||
|
"tool_calls" : null
|
||||||
|
},
|
||||||
|
"done" : true,
|
||||||
|
"error" : null,
|
||||||
|
"context" : null,
|
||||||
|
"created_at" : "2024-12-27T21:55:39.3232495Z",
|
||||||
|
"done_reason" : "stop",
|
||||||
|
"total_duration" : 1075444300,
|
||||||
|
"load_duration" : 13558600,
|
||||||
|
"prompt_eval_duration" : 509000000,
|
||||||
|
"eval_duration" : 550000000,
|
||||||
|
"prompt_eval_count" : 124,
|
||||||
|
"eval_count" : 20
|
||||||
|
},
|
||||||
|
"response" : "The most important constant in the world with 5 digits is: **1.51019**",
|
||||||
|
"responseTime" : 1075444300,
|
||||||
|
"httpStatusCode" : 200
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Potential Improvements
|
||||||
|
|
||||||
Instead of passing a map of args `Map<String, Object> arguments` to the tool functions, we could support passing
|
Instead of passing a map of args `Map<String, Object> arguments` to the tool functions, we could support passing
|
||||||
specific args separately with their data types. For example:
|
specific args separately with their data types. For example:
|
||||||
|
|
||||||
|
@ -15,11 +15,17 @@ import io.github.ollama4j.models.ps.ModelsProcessResponse;
|
|||||||
import io.github.ollama4j.models.request.*;
|
import io.github.ollama4j.models.request.*;
|
||||||
import io.github.ollama4j.models.response.*;
|
import io.github.ollama4j.models.response.*;
|
||||||
import io.github.ollama4j.tools.*;
|
import io.github.ollama4j.tools.*;
|
||||||
|
import io.github.ollama4j.tools.annotations.OllamaToolService;
|
||||||
|
import io.github.ollama4j.tools.annotations.ToolProperty;
|
||||||
|
import io.github.ollama4j.tools.annotations.ToolSpec;
|
||||||
import io.github.ollama4j.utils.Options;
|
import io.github.ollama4j.utils.Options;
|
||||||
import io.github.ollama4j.utils.Utils;
|
import io.github.ollama4j.utils.Utils;
|
||||||
import lombok.Setter;
|
import lombok.Setter;
|
||||||
|
|
||||||
import java.io.*;
|
import java.io.*;
|
||||||
|
import java.lang.reflect.InvocationTargetException;
|
||||||
|
import java.lang.reflect.Method;
|
||||||
|
import java.lang.reflect.Parameter;
|
||||||
import java.net.URI;
|
import java.net.URI;
|
||||||
import java.net.URISyntaxException;
|
import java.net.URISyntaxException;
|
||||||
import java.net.http.HttpClient;
|
import java.net.http.HttpClient;
|
||||||
@ -603,6 +609,15 @@ public class OllamaAPI {
|
|||||||
OllamaToolsResult toolResult = new OllamaToolsResult();
|
OllamaToolsResult toolResult = new OllamaToolsResult();
|
||||||
Map<ToolFunctionCallSpec, Object> toolResults = new HashMap<>();
|
Map<ToolFunctionCallSpec, Object> toolResults = new HashMap<>();
|
||||||
|
|
||||||
|
if(!prompt.startsWith("[AVAILABLE_TOOLS]")){
|
||||||
|
final Tools.PromptBuilder promptBuilder = new Tools.PromptBuilder();
|
||||||
|
for(Tools.ToolSpecification spec : toolRegistry.getRegisteredSpecs()) {
|
||||||
|
promptBuilder.withToolSpecification(spec);
|
||||||
|
}
|
||||||
|
promptBuilder.withPrompt(prompt);
|
||||||
|
prompt = promptBuilder.build();
|
||||||
|
}
|
||||||
|
|
||||||
OllamaResult result = generate(model, prompt, raw, options, null);
|
OllamaResult result = generate(model, prompt, raw, options, null);
|
||||||
toolResult.setModelResult(result);
|
toolResult.setModelResult(result);
|
||||||
|
|
||||||
@ -811,6 +826,92 @@ public class OllamaAPI {
|
|||||||
toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification);
|
toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void registerAnnotatedTools() {
|
||||||
|
Class<?> callerClass = null;
|
||||||
|
try {
|
||||||
|
callerClass = Class.forName(Thread.currentThread().getStackTrace()[2].getClassName());
|
||||||
|
} catch (ClassNotFoundException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
|
||||||
|
OllamaToolService ollamaToolServiceAnnotation = callerClass.getDeclaredAnnotation(OllamaToolService.class);
|
||||||
|
if(ollamaToolServiceAnnotation == null) {
|
||||||
|
throw new IllegalStateException(callerClass + " is not annotated as " + OllamaToolService.class);
|
||||||
|
}
|
||||||
|
|
||||||
|
Class<?>[] providers = ollamaToolServiceAnnotation.providers();
|
||||||
|
|
||||||
|
for(Class<?> provider : providers){
|
||||||
|
Method[] methods = provider.getMethods();
|
||||||
|
for(Method m : methods) {
|
||||||
|
ToolSpec toolSpec = m.getDeclaredAnnotation(ToolSpec.class);
|
||||||
|
if(toolSpec == null){
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
String operationName = !toolSpec.name().isBlank() ? toolSpec.name() : m.getName();
|
||||||
|
String operationDesc = !toolSpec.desc().isBlank() ? toolSpec.desc() : operationName;
|
||||||
|
|
||||||
|
final Tools.PropsBuilder propsBuilder = new Tools.PropsBuilder();
|
||||||
|
LinkedHashMap<String,String> methodParams = new LinkedHashMap<>();
|
||||||
|
for (Parameter parameter : m.getParameters()) {
|
||||||
|
final ToolProperty toolPropertyAnn = parameter.getDeclaredAnnotation(ToolProperty.class);
|
||||||
|
String propType = parameter.getType().getTypeName();
|
||||||
|
if(toolPropertyAnn == null) {
|
||||||
|
methodParams.put(parameter.getName(),null);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
String propName = !toolPropertyAnn.name().isBlank() ? toolPropertyAnn.name() : parameter.getName();
|
||||||
|
methodParams.put(propName,propType);
|
||||||
|
propsBuilder.withProperty(propName,Tools.PromptFuncDefinition.Property.builder()
|
||||||
|
.type(propType)
|
||||||
|
.description(toolPropertyAnn.desc())
|
||||||
|
.required(toolPropertyAnn.required())
|
||||||
|
.build());
|
||||||
|
}
|
||||||
|
final Map<String, Tools.PromptFuncDefinition.Property> params = propsBuilder.build();
|
||||||
|
List<String> reqProps = params.entrySet().stream()
|
||||||
|
.filter(e -> e.getValue().isRequired())
|
||||||
|
.map(Map.Entry::getKey)
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
|
Tools.ToolSpecification toolSpecification = Tools.ToolSpecification.builder()
|
||||||
|
.functionName(operationName)
|
||||||
|
.functionDescription(operationDesc)
|
||||||
|
.toolPrompt(
|
||||||
|
Tools.PromptFuncDefinition.builder().type("function").function(
|
||||||
|
Tools.PromptFuncDefinition.PromptFuncSpec.builder()
|
||||||
|
.name(operationName)
|
||||||
|
.description(operationDesc)
|
||||||
|
.parameters(
|
||||||
|
Tools.PromptFuncDefinition.Parameters.builder()
|
||||||
|
.type("object")
|
||||||
|
.properties(
|
||||||
|
params
|
||||||
|
)
|
||||||
|
.required(reqProps)
|
||||||
|
.build()
|
||||||
|
).build()
|
||||||
|
).build()
|
||||||
|
)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
try {
|
||||||
|
ReflectionalToolFunction reflectionalToolFunction =
|
||||||
|
new ReflectionalToolFunction(provider.getDeclaredConstructor().newInstance()
|
||||||
|
,m
|
||||||
|
,methodParams);
|
||||||
|
|
||||||
|
toolSpecification.setToolFunction(reflectionalToolFunction);
|
||||||
|
toolRegistry.addTool(toolSpecification.getFunctionName(),toolSpecification);
|
||||||
|
} catch (InstantiationException | IllegalAccessException | InvocationTargetException |
|
||||||
|
NoSuchMethodException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Adds a custom role.
|
* Adds a custom role.
|
||||||
*
|
*
|
||||||
|
@ -0,0 +1,54 @@
|
|||||||
|
package io.github.ollama4j.tools;
|
||||||
|
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Getter;
|
||||||
|
import lombok.Setter;
|
||||||
|
|
||||||
|
import java.lang.reflect.Method;
|
||||||
|
import java.math.BigDecimal;
|
||||||
|
import java.util.LinkedHashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Specification of a {@link ToolFunction} that provides the implementation via java reflection calling.
|
||||||
|
*/
|
||||||
|
@Setter
|
||||||
|
@Getter
|
||||||
|
@AllArgsConstructor
|
||||||
|
public class ReflectionalToolFunction implements ToolFunction{
|
||||||
|
|
||||||
|
private Object functionHolder;
|
||||||
|
private Method function;
|
||||||
|
private LinkedHashMap<String,String> propertyDefinition;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Object apply(Map<String, Object> arguments) {
|
||||||
|
LinkedHashMap<String, Object> argumentsCopy = new LinkedHashMap<>(this.propertyDefinition);
|
||||||
|
for (Map.Entry<String,String> param : this.propertyDefinition.entrySet()){
|
||||||
|
argumentsCopy.replace(param.getKey(),typeCast(arguments.get(param.getKey()),param.getValue()));
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
return function.invoke(functionHolder, argumentsCopy.values().toArray());
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new RuntimeException("Failed to invoke tool: " + function.getName(), e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private Object typeCast(Object inputValue, String className) {
|
||||||
|
if(className == null || inputValue == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
String inputValueString = inputValue.toString();
|
||||||
|
switch (className) {
|
||||||
|
case "java.lang.Integer":
|
||||||
|
return Integer.parseInt(inputValueString);
|
||||||
|
case "java.lang.Boolean":
|
||||||
|
return Boolean.valueOf(inputValueString);
|
||||||
|
case "java.math.BigDecimal":
|
||||||
|
return new BigDecimal(inputValueString);
|
||||||
|
default:
|
||||||
|
return inputValueString;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,23 @@
|
|||||||
|
package io.github.ollama4j.tools.annotations;
|
||||||
|
|
||||||
|
import io.github.ollama4j.OllamaAPI;
|
||||||
|
|
||||||
|
import java.lang.annotation.ElementType;
|
||||||
|
import java.lang.annotation.Retention;
|
||||||
|
import java.lang.annotation.RetentionPolicy;
|
||||||
|
import java.lang.annotation.Target;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Annotates a class that calls {@link io.github.ollama4j.OllamaAPI} such that the Method
|
||||||
|
* {@link OllamaAPI#registerAnnotatedTools()} can be used to auto-register all provided classes (resp. all
|
||||||
|
* contained Methods of the provider classes annotated with {@link ToolSpec}).
|
||||||
|
*/
|
||||||
|
@Target(ElementType.TYPE)
|
||||||
|
@Retention(RetentionPolicy.RUNTIME)
|
||||||
|
public @interface OllamaToolService {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return Classes with no-arg constructor that will be used for tool-registration.
|
||||||
|
*/
|
||||||
|
Class<?>[] providers();
|
||||||
|
}
|
@ -0,0 +1,32 @@
|
|||||||
|
package io.github.ollama4j.tools.annotations;
|
||||||
|
|
||||||
|
import java.lang.annotation.ElementType;
|
||||||
|
import java.lang.annotation.Retention;
|
||||||
|
import java.lang.annotation.RetentionPolicy;
|
||||||
|
import java.lang.annotation.Target;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Annotates a Method Parameter in a {@link ToolSpec} annotated Method. A parameter annotated with this annotation will
|
||||||
|
* be part of the tool description that is sent to the llm for tool-calling.
|
||||||
|
*/
|
||||||
|
@Retention(RetentionPolicy.RUNTIME)
|
||||||
|
@Target(ElementType.PARAMETER)
|
||||||
|
public @interface ToolProperty {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return name of the parameter that is used for the tool description. Has to be set as depending on the caller,
|
||||||
|
* method name backtracking is not possible with reflection.
|
||||||
|
*/
|
||||||
|
String name();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return a detailed description of the parameter. This is used by the llm called to specify, which property has
|
||||||
|
* to be set by the llm and how this should be filled.
|
||||||
|
*/
|
||||||
|
String desc();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return tells the llm that it has to set a value for this property.
|
||||||
|
*/
|
||||||
|
boolean required() default true;
|
||||||
|
}
|
@ -0,0 +1,28 @@
|
|||||||
|
package io.github.ollama4j.tools.annotations;
|
||||||
|
|
||||||
|
import io.github.ollama4j.OllamaAPI;
|
||||||
|
|
||||||
|
import java.lang.annotation.ElementType;
|
||||||
|
import java.lang.annotation.Retention;
|
||||||
|
import java.lang.annotation.RetentionPolicy;
|
||||||
|
import java.lang.annotation.Target;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Annotates Methods of classes that should be registered as tools by {@link OllamaAPI#registerAnnotatedTools()}
|
||||||
|
* automatically.
|
||||||
|
*/
|
||||||
|
@Target(ElementType.METHOD)
|
||||||
|
@Retention(RetentionPolicy.RUNTIME)
|
||||||
|
public @interface ToolSpec {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return tool-name that the method should be used as. Defaults to the methods name.
|
||||||
|
*/
|
||||||
|
String name() default "";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return a detailed description of the method that can be interpreted by the llm, whether it should call the tool
|
||||||
|
* or not.
|
||||||
|
*/
|
||||||
|
String desc();
|
||||||
|
}
|
@ -7,8 +7,11 @@ import io.github.ollama4j.models.response.ModelDetail;
|
|||||||
import io.github.ollama4j.models.response.OllamaResult;
|
import io.github.ollama4j.models.response.OllamaResult;
|
||||||
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestBuilder;
|
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestBuilder;
|
||||||
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
|
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
|
||||||
|
import io.github.ollama4j.samples.AnnotatedTool;
|
||||||
|
import io.github.ollama4j.tools.OllamaToolCallsFunction;
|
||||||
import io.github.ollama4j.tools.ToolFunction;
|
import io.github.ollama4j.tools.ToolFunction;
|
||||||
import io.github.ollama4j.tools.Tools;
|
import io.github.ollama4j.tools.Tools;
|
||||||
|
import io.github.ollama4j.tools.annotations.OllamaToolService;
|
||||||
import io.github.ollama4j.utils.OptionsBuilder;
|
import io.github.ollama4j.utils.OptionsBuilder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
@ -27,6 +30,8 @@ import java.util.*;
|
|||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
|
@OllamaToolService(providers = {AnnotatedTool.class}
|
||||||
|
)
|
||||||
class TestRealAPIs {
|
class TestRealAPIs {
|
||||||
|
|
||||||
private static final Logger LOG = LoggerFactory.getLogger(TestRealAPIs.class);
|
private static final Logger LOG = LoggerFactory.getLogger(TestRealAPIs.class);
|
||||||
@ -229,7 +234,7 @@ class TestRealAPIs {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
@Order(3)
|
@Order(3)
|
||||||
void testChatWithTools() {
|
void testChatWithExplicitToolDefinition() {
|
||||||
testEndpointReachability();
|
testEndpointReachability();
|
||||||
try {
|
try {
|
||||||
ollamaAPI.setVerbose(true);
|
ollamaAPI.setVerbose(true);
|
||||||
@ -275,9 +280,10 @@ class TestRealAPIs {
|
|||||||
assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(),chatResult.getResponseModel().getMessage().getRole().getRoleName());
|
assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(),chatResult.getResponseModel().getMessage().getRole().getRoleName());
|
||||||
List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls();
|
List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls();
|
||||||
assertEquals(1, toolCalls.size());
|
assertEquals(1, toolCalls.size());
|
||||||
assertEquals("get-employee-details",toolCalls.get(0).getFunction().getName());
|
OllamaToolCallsFunction function = toolCalls.get(0).getFunction();
|
||||||
assertEquals(1, toolCalls.get(0).getFunction().getArguments().size());
|
assertEquals("get-employee-details", function.getName());
|
||||||
Object employeeName = toolCalls.get(0).getFunction().getArguments().get("employee-name");
|
assertEquals(1, function.getArguments().size());
|
||||||
|
Object employeeName = function.getArguments().get("employee-name");
|
||||||
assertNotNull(employeeName);
|
assertNotNull(employeeName);
|
||||||
assertEquals("Rahul Kumar",employeeName);
|
assertEquals("Rahul Kumar",employeeName);
|
||||||
assertTrue(chatResult.getChatHistory().size()>2);
|
assertTrue(chatResult.getChatHistory().size()>2);
|
||||||
@ -288,6 +294,82 @@ class TestRealAPIs {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@Order(3)
|
||||||
|
void testChatWithAnnotatedToolsAndSingleParam() {
|
||||||
|
testEndpointReachability();
|
||||||
|
try {
|
||||||
|
ollamaAPI.setVerbose(true);
|
||||||
|
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
|
||||||
|
|
||||||
|
ollamaAPI.registerAnnotatedTools();
|
||||||
|
|
||||||
|
OllamaChatRequest requestModel = builder
|
||||||
|
.withMessage(OllamaChatMessageRole.USER,
|
||||||
|
"Compute the most important constant in the world using 5 digits")
|
||||||
|
.build();
|
||||||
|
|
||||||
|
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
|
||||||
|
assertNotNull(chatResult);
|
||||||
|
assertNotNull(chatResult.getResponseModel());
|
||||||
|
assertNotNull(chatResult.getResponseModel().getMessage());
|
||||||
|
assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(),chatResult.getResponseModel().getMessage().getRole().getRoleName());
|
||||||
|
List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls();
|
||||||
|
assertEquals(1, toolCalls.size());
|
||||||
|
OllamaToolCallsFunction function = toolCalls.get(0).getFunction();
|
||||||
|
assertEquals("computeImportantConstant", function.getName());
|
||||||
|
assertEquals(1, function.getArguments().size());
|
||||||
|
Object noOfDigits = function.getArguments().get("noOfDigits");
|
||||||
|
assertNotNull(noOfDigits);
|
||||||
|
assertEquals("5",noOfDigits);
|
||||||
|
assertTrue(chatResult.getChatHistory().size()>2);
|
||||||
|
List<OllamaChatToolCalls> finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls();
|
||||||
|
assertNull(finalToolCalls);
|
||||||
|
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||||
|
fail(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@Order(3)
|
||||||
|
void testChatWithAnnotatedToolsAndMultipleParams() {
|
||||||
|
testEndpointReachability();
|
||||||
|
try {
|
||||||
|
ollamaAPI.setVerbose(true);
|
||||||
|
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
|
||||||
|
|
||||||
|
ollamaAPI.registerAnnotatedTools();
|
||||||
|
|
||||||
|
OllamaChatRequest requestModel = builder
|
||||||
|
.withMessage(OllamaChatMessageRole.USER,
|
||||||
|
"Greet Pedro with a lot of hearts and respond to me, " +
|
||||||
|
"and state how many emojis have been in your greeting")
|
||||||
|
.build();
|
||||||
|
|
||||||
|
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
|
||||||
|
assertNotNull(chatResult);
|
||||||
|
assertNotNull(chatResult.getResponseModel());
|
||||||
|
assertNotNull(chatResult.getResponseModel().getMessage());
|
||||||
|
assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(),chatResult.getResponseModel().getMessage().getRole().getRoleName());
|
||||||
|
List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls();
|
||||||
|
assertEquals(1, toolCalls.size());
|
||||||
|
OllamaToolCallsFunction function = toolCalls.get(0).getFunction();
|
||||||
|
assertEquals("sayHello", function.getName());
|
||||||
|
assertEquals(2, function.getArguments().size());
|
||||||
|
Object name = function.getArguments().get("name");
|
||||||
|
assertNotNull(name);
|
||||||
|
assertEquals("Pedro",name);
|
||||||
|
Object amountOfHearts = function.getArguments().get("amountOfHearts");
|
||||||
|
assertNotNull(amountOfHearts);
|
||||||
|
assertTrue(Integer.parseInt(amountOfHearts.toString()) > 1);
|
||||||
|
assertTrue(chatResult.getChatHistory().size()>2);
|
||||||
|
List<OllamaChatToolCalls> finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls();
|
||||||
|
assertNull(finalToolCalls);
|
||||||
|
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||||
|
fail(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@Order(3)
|
@Order(3)
|
||||||
void testChatWithToolsAndStream() {
|
void testChatWithToolsAndStream() {
|
||||||
|
21
src/test/java/io/github/ollama4j/samples/AnnotatedTool.java
Normal file
21
src/test/java/io/github/ollama4j/samples/AnnotatedTool.java
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
package io.github.ollama4j.samples;
|
||||||
|
|
||||||
|
import io.github.ollama4j.tools.annotations.ToolProperty;
|
||||||
|
import io.github.ollama4j.tools.annotations.ToolSpec;
|
||||||
|
|
||||||
|
import java.math.BigDecimal;
|
||||||
|
|
||||||
|
public class AnnotatedTool {
|
||||||
|
|
||||||
|
@ToolSpec(desc = "Computes the most important constant all around the globe!")
|
||||||
|
public String computeImportantConstant(@ToolProperty(name = "noOfDigits",desc = "Number of digits that shall be returned") Integer noOfDigits ){
|
||||||
|
return BigDecimal.valueOf((long)(Math.random()*1000000L),noOfDigits).toString();
|
||||||
|
}
|
||||||
|
|
||||||
|
@ToolSpec(desc = "Says hello to a friend!")
|
||||||
|
public String sayHello(@ToolProperty(name = "name",desc = "Name of the friend") String name, Integer someRandomProperty, @ToolProperty(name="amountOfHearts",desc = "amount of heart emojis that should be used", required = false) Integer amountOfHearts) {
|
||||||
|
String hearts = amountOfHearts!=null ? "♡".repeat(amountOfHearts) : "";
|
||||||
|
return "Hello " + name +" ("+someRandomProperty+") " + hearts;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user