forked from Mirror/ollama4j
		
	Adds first approach to annotation based tool callings using basic java reflection
This commit is contained in:
		@@ -15,11 +15,17 @@ import io.github.ollama4j.models.ps.ModelsProcessResponse;
 | 
				
			|||||||
import io.github.ollama4j.models.request.*;
 | 
					import io.github.ollama4j.models.request.*;
 | 
				
			||||||
import io.github.ollama4j.models.response.*;
 | 
					import io.github.ollama4j.models.response.*;
 | 
				
			||||||
import io.github.ollama4j.tools.*;
 | 
					import io.github.ollama4j.tools.*;
 | 
				
			||||||
 | 
					import io.github.ollama4j.tools.annotations.OllamaToolService;
 | 
				
			||||||
 | 
					import io.github.ollama4j.tools.annotations.ToolProperty;
 | 
				
			||||||
 | 
					import io.github.ollama4j.tools.annotations.ToolSpec;
 | 
				
			||||||
import io.github.ollama4j.utils.Options;
 | 
					import io.github.ollama4j.utils.Options;
 | 
				
			||||||
import io.github.ollama4j.utils.Utils;
 | 
					import io.github.ollama4j.utils.Utils;
 | 
				
			||||||
import lombok.Setter;
 | 
					import lombok.Setter;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import java.io.*;
 | 
					import java.io.*;
 | 
				
			||||||
 | 
					import java.lang.reflect.InvocationTargetException;
 | 
				
			||||||
 | 
					import java.lang.reflect.Method;
 | 
				
			||||||
 | 
					import java.lang.reflect.Parameter;
 | 
				
			||||||
import java.net.URI;
 | 
					import java.net.URI;
 | 
				
			||||||
import java.net.URISyntaxException;
 | 
					import java.net.URISyntaxException;
 | 
				
			||||||
import java.net.http.HttpClient;
 | 
					import java.net.http.HttpClient;
 | 
				
			||||||
@@ -603,6 +609,15 @@ public class OllamaAPI {
 | 
				
			|||||||
        OllamaToolsResult toolResult = new OllamaToolsResult();
 | 
					        OllamaToolsResult toolResult = new OllamaToolsResult();
 | 
				
			||||||
        Map<ToolFunctionCallSpec, Object> toolResults = new HashMap<>();
 | 
					        Map<ToolFunctionCallSpec, Object> toolResults = new HashMap<>();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if(!prompt.startsWith("[AVAILABLE_TOOLS]")){
 | 
				
			||||||
 | 
					            final Tools.PromptBuilder promptBuilder = new Tools.PromptBuilder();
 | 
				
			||||||
 | 
					            for(Tools.ToolSpecification spec :  toolRegistry.getRegisteredSpecs()) {
 | 
				
			||||||
 | 
					                promptBuilder.withToolSpecification(spec);
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					            promptBuilder.withPrompt(prompt);
 | 
				
			||||||
 | 
					            prompt = promptBuilder.build();
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        OllamaResult result = generate(model, prompt, raw, options, null);
 | 
					        OllamaResult result = generate(model, prompt, raw, options, null);
 | 
				
			||||||
        toolResult.setModelResult(result);
 | 
					        toolResult.setModelResult(result);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -811,6 +826,94 @@ public class OllamaAPI {
 | 
				
			|||||||
        toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification);
 | 
					        toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    public void registerAnnotatedTools()  {
 | 
				
			||||||
 | 
					        Class<?> callerClass = null;
 | 
				
			||||||
 | 
					        try {
 | 
				
			||||||
 | 
					            callerClass = Class.forName(Thread.currentThread().getStackTrace()[2].getClassName());
 | 
				
			||||||
 | 
					        } catch (ClassNotFoundException e) {
 | 
				
			||||||
 | 
					            throw new RuntimeException(e);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        OllamaToolService ollamaToolServiceAnnotation = callerClass.getDeclaredAnnotation(OllamaToolService.class);
 | 
				
			||||||
 | 
					        if(ollamaToolServiceAnnotation == null) {
 | 
				
			||||||
 | 
					            throw new IllegalStateException(callerClass + " is not annotated as " + OllamaToolService.class);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Class<?>[] providers = ollamaToolServiceAnnotation.providers();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for(Class<?> provider : providers){
 | 
				
			||||||
 | 
					            System.err.println("Provider: " + provider.getName());
 | 
				
			||||||
 | 
					            Method[] methods = provider.getMethods();
 | 
				
			||||||
 | 
					            for(Method m : methods) {
 | 
				
			||||||
 | 
					                ToolSpec toolSpec = m.getDeclaredAnnotation(ToolSpec.class);
 | 
				
			||||||
 | 
					                if(toolSpec == null){
 | 
				
			||||||
 | 
					                    continue;
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					                String operationName = !toolSpec.name().isBlank() ? toolSpec.name() : m.getName();
 | 
				
			||||||
 | 
					                String operationDesc = !toolSpec.desc().isBlank() ? toolSpec.desc() : operationName;
 | 
				
			||||||
 | 
					                System.err.println("Method: " + operationName);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                final Tools.PropsBuilder propsBuilder = new Tools.PropsBuilder();
 | 
				
			||||||
 | 
					                LinkedHashMap<String,String> methodParams = new LinkedHashMap<>();
 | 
				
			||||||
 | 
					                for (Parameter parameter : m.getParameters()) {
 | 
				
			||||||
 | 
					                    final ToolProperty toolPropertyAnn = parameter.getDeclaredAnnotation(ToolProperty.class);
 | 
				
			||||||
 | 
					                    String propType = parameter.getType().getTypeName();
 | 
				
			||||||
 | 
					                    if(toolPropertyAnn == null) {
 | 
				
			||||||
 | 
					                        methodParams.put(parameter.getName(),null);
 | 
				
			||||||
 | 
					                        continue;
 | 
				
			||||||
 | 
					                    }
 | 
				
			||||||
 | 
					                    String propName = !toolPropertyAnn.name().isBlank() ? toolPropertyAnn.name() : parameter.getName();
 | 
				
			||||||
 | 
					                    methodParams.put(propName,propType);
 | 
				
			||||||
 | 
					                    propsBuilder.withProperty(propName,Tools.PromptFuncDefinition.Property.builder()
 | 
				
			||||||
 | 
					                            .type(propType)
 | 
				
			||||||
 | 
					                            .description(toolPropertyAnn.desc())
 | 
				
			||||||
 | 
					                            .required(toolPropertyAnn.required())
 | 
				
			||||||
 | 
					                            .build());
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					                final Map<String, Tools.PromptFuncDefinition.Property> params = propsBuilder.build();
 | 
				
			||||||
 | 
					                List<String> reqProps = params.entrySet().stream()
 | 
				
			||||||
 | 
					                        .filter(e -> e.getValue().isRequired())
 | 
				
			||||||
 | 
					                        .map(Map.Entry::getKey)
 | 
				
			||||||
 | 
					                        .collect(Collectors.toList());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                Tools.ToolSpecification toolSpecification = Tools.ToolSpecification.builder()
 | 
				
			||||||
 | 
					                        .functionName(operationName)
 | 
				
			||||||
 | 
					                        .functionDescription(operationDesc)
 | 
				
			||||||
 | 
					                        .toolPrompt(
 | 
				
			||||||
 | 
					                                Tools.PromptFuncDefinition.builder().type("function").function(
 | 
				
			||||||
 | 
					                                        Tools.PromptFuncDefinition.PromptFuncSpec.builder()
 | 
				
			||||||
 | 
					                                                .name(operationName)
 | 
				
			||||||
 | 
					                                                .description(operationDesc)
 | 
				
			||||||
 | 
					                                                .parameters(
 | 
				
			||||||
 | 
					                                                        Tools.PromptFuncDefinition.Parameters.builder()
 | 
				
			||||||
 | 
					                                                                .type("object")
 | 
				
			||||||
 | 
					                                                                .properties(
 | 
				
			||||||
 | 
					                                                                        params
 | 
				
			||||||
 | 
					                                                                )
 | 
				
			||||||
 | 
					                                                                .required(reqProps)
 | 
				
			||||||
 | 
					                                                                .build()
 | 
				
			||||||
 | 
					                                                ).build()
 | 
				
			||||||
 | 
					                                ).build()
 | 
				
			||||||
 | 
					                        )
 | 
				
			||||||
 | 
					                        .build();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                try {
 | 
				
			||||||
 | 
					                    ReflectionalToolFunction reflectionalToolFunction =
 | 
				
			||||||
 | 
					                            new ReflectionalToolFunction(provider.getDeclaredConstructor().newInstance()
 | 
				
			||||||
 | 
					                                    ,m
 | 
				
			||||||
 | 
					                                    ,methodParams);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    toolSpecification.setToolFunction(reflectionalToolFunction);
 | 
				
			||||||
 | 
					                    toolRegistry.addTool(toolSpecification.getFunctionName(),toolSpecification);
 | 
				
			||||||
 | 
					                } catch (InstantiationException | IllegalAccessException | InvocationTargetException |
 | 
				
			||||||
 | 
					                         NoSuchMethodException e) {
 | 
				
			||||||
 | 
					                    throw new RuntimeException(e);
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    /**
 | 
					    /**
 | 
				
			||||||
     * Adds a custom role.
 | 
					     * Adds a custom role.
 | 
				
			||||||
     *
 | 
					     *
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -0,0 +1,49 @@
 | 
				
			|||||||
 | 
					package io.github.ollama4j.tools;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import lombok.AllArgsConstructor;
 | 
				
			||||||
 | 
					import lombok.Getter;
 | 
				
			||||||
 | 
					import lombok.Setter;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import java.lang.reflect.Method;
 | 
				
			||||||
 | 
					import java.util.LinkedHashMap;
 | 
				
			||||||
 | 
					import java.util.Map;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@Setter
 | 
				
			||||||
 | 
					@Getter
 | 
				
			||||||
 | 
					@AllArgsConstructor
 | 
				
			||||||
 | 
					public class ReflectionalToolFunction implements ToolFunction{
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    private Object functionHolder;
 | 
				
			||||||
 | 
					    private Method function;
 | 
				
			||||||
 | 
					    private LinkedHashMap<String,String> propertyDefinition;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @Override
 | 
				
			||||||
 | 
					    public Object apply(Map<String, Object> arguments) {
 | 
				
			||||||
 | 
					        LinkedHashMap<String, Object> argumentsCopy = new LinkedHashMap<>(this.propertyDefinition);
 | 
				
			||||||
 | 
					        for (Map.Entry<String,String> param : this.propertyDefinition.entrySet()){
 | 
				
			||||||
 | 
					            argumentsCopy.replace(param.getKey(),typeCast(arguments.get(param.getKey()),param.getValue()));
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        try {
 | 
				
			||||||
 | 
					            return function.invoke(functionHolder, argumentsCopy.values().toArray());
 | 
				
			||||||
 | 
					        } catch (Exception e) {
 | 
				
			||||||
 | 
					            throw new RuntimeException("Failed to invoke tool: " + function.getName(), e);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    private Object typeCast(Object inputValue, String className) {
 | 
				
			||||||
 | 
					        if(className == null || inputValue == null) {
 | 
				
			||||||
 | 
					            return null;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        String inputValueString = inputValue.toString();
 | 
				
			||||||
 | 
					        if("java.lang.Integer".equals(className)){
 | 
				
			||||||
 | 
					            return Integer.parseInt(inputValueString);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        if("java.lang.Boolean".equals(className)){
 | 
				
			||||||
 | 
					            return Boolean.valueOf(inputValueString);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        else {
 | 
				
			||||||
 | 
					            return inputValueString;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -0,0 +1,13 @@
 | 
				
			|||||||
 | 
					package io.github.ollama4j.tools.annotations;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import java.lang.annotation.ElementType;
 | 
				
			||||||
 | 
					import java.lang.annotation.Retention;
 | 
				
			||||||
 | 
					import java.lang.annotation.RetentionPolicy;
 | 
				
			||||||
 | 
					import java.lang.annotation.Target;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@Target(ElementType.TYPE)
 | 
				
			||||||
 | 
					@Retention(RetentionPolicy.RUNTIME)
 | 
				
			||||||
 | 
					public @interface OllamaToolService {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Class<?>[] providers();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -0,0 +1,17 @@
 | 
				
			|||||||
 | 
					package io.github.ollama4j.tools.annotations;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import java.lang.annotation.ElementType;
 | 
				
			||||||
 | 
					import java.lang.annotation.Retention;
 | 
				
			||||||
 | 
					import java.lang.annotation.RetentionPolicy;
 | 
				
			||||||
 | 
					import java.lang.annotation.Target;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@Retention(RetentionPolicy.RUNTIME)
 | 
				
			||||||
 | 
					@Target(ElementType.PARAMETER)
 | 
				
			||||||
 | 
					public @interface ToolProperty {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    String name();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    String desc();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    boolean required() default true;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -0,0 +1,15 @@
 | 
				
			|||||||
 | 
					package io.github.ollama4j.tools.annotations;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import java.lang.annotation.ElementType;
 | 
				
			||||||
 | 
					import java.lang.annotation.Retention;
 | 
				
			||||||
 | 
					import java.lang.annotation.RetentionPolicy;
 | 
				
			||||||
 | 
					import java.lang.annotation.Target;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@Target(ElementType.METHOD)
 | 
				
			||||||
 | 
					@Retention(RetentionPolicy.RUNTIME)
 | 
				
			||||||
 | 
					public @interface ToolSpec {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    String name() default "";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    String desc();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -7,8 +7,11 @@ import io.github.ollama4j.models.response.ModelDetail;
 | 
				
			|||||||
import io.github.ollama4j.models.response.OllamaResult;
 | 
					import io.github.ollama4j.models.response.OllamaResult;
 | 
				
			||||||
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestBuilder;
 | 
					import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestBuilder;
 | 
				
			||||||
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
 | 
					import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
 | 
				
			||||||
 | 
					import io.github.ollama4j.samples.AnnotatedTool;
 | 
				
			||||||
 | 
					import io.github.ollama4j.tools.OllamaToolCallsFunction;
 | 
				
			||||||
import io.github.ollama4j.tools.ToolFunction;
 | 
					import io.github.ollama4j.tools.ToolFunction;
 | 
				
			||||||
import io.github.ollama4j.tools.Tools;
 | 
					import io.github.ollama4j.tools.Tools;
 | 
				
			||||||
 | 
					import io.github.ollama4j.tools.annotations.OllamaToolService;
 | 
				
			||||||
import io.github.ollama4j.utils.OptionsBuilder;
 | 
					import io.github.ollama4j.utils.OptionsBuilder;
 | 
				
			||||||
import lombok.Data;
 | 
					import lombok.Data;
 | 
				
			||||||
import org.junit.jupiter.api.BeforeEach;
 | 
					import org.junit.jupiter.api.BeforeEach;
 | 
				
			||||||
@@ -27,6 +30,8 @@ import java.util.*;
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import static org.junit.jupiter.api.Assertions.*;
 | 
					import static org.junit.jupiter.api.Assertions.*;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@OllamaToolService(providers = {AnnotatedTool.class}
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
class TestRealAPIs {
 | 
					class TestRealAPIs {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    private static final Logger LOG = LoggerFactory.getLogger(TestRealAPIs.class);
 | 
					    private static final Logger LOG = LoggerFactory.getLogger(TestRealAPIs.class);
 | 
				
			||||||
@@ -229,7 +234,7 @@ class TestRealAPIs {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    @Test
 | 
					    @Test
 | 
				
			||||||
    @Order(3)
 | 
					    @Order(3)
 | 
				
			||||||
    void testChatWithTools() {
 | 
					    void testChatWithExplicitToolDefinition() {
 | 
				
			||||||
        testEndpointReachability();
 | 
					        testEndpointReachability();
 | 
				
			||||||
        try {
 | 
					        try {
 | 
				
			||||||
            ollamaAPI.setVerbose(true);
 | 
					            ollamaAPI.setVerbose(true);
 | 
				
			||||||
@@ -275,9 +280,10 @@ class TestRealAPIs {
 | 
				
			|||||||
            assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(),chatResult.getResponseModel().getMessage().getRole().getRoleName());
 | 
					            assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(),chatResult.getResponseModel().getMessage().getRole().getRoleName());
 | 
				
			||||||
            List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls();
 | 
					            List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls();
 | 
				
			||||||
            assertEquals(1, toolCalls.size());
 | 
					            assertEquals(1, toolCalls.size());
 | 
				
			||||||
            assertEquals("get-employee-details",toolCalls.get(0).getFunction().getName());
 | 
					            OllamaToolCallsFunction function = toolCalls.get(0).getFunction();
 | 
				
			||||||
            assertEquals(1, toolCalls.get(0).getFunction().getArguments().size());
 | 
					            assertEquals("get-employee-details", function.getName());
 | 
				
			||||||
            Object employeeName = toolCalls.get(0).getFunction().getArguments().get("employee-name");
 | 
					            assertEquals(1, function.getArguments().size());
 | 
				
			||||||
 | 
					            Object employeeName = function.getArguments().get("employee-name");
 | 
				
			||||||
            assertNotNull(employeeName);
 | 
					            assertNotNull(employeeName);
 | 
				
			||||||
            assertEquals("Rahul Kumar",employeeName);
 | 
					            assertEquals("Rahul Kumar",employeeName);
 | 
				
			||||||
            assertTrue(chatResult.getChatHistory().size()>2);
 | 
					            assertTrue(chatResult.getChatHistory().size()>2);
 | 
				
			||||||
@@ -288,6 +294,82 @@ class TestRealAPIs {
 | 
				
			|||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @Test
 | 
				
			||||||
 | 
					    @Order(3)
 | 
				
			||||||
 | 
					    void testChatWithAnnotatedToolsAndSingleParam() {
 | 
				
			||||||
 | 
					        testEndpointReachability();
 | 
				
			||||||
 | 
					        try {
 | 
				
			||||||
 | 
					            ollamaAPI.setVerbose(true);
 | 
				
			||||||
 | 
					            OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            ollamaAPI.registerAnnotatedTools();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            OllamaChatRequest requestModel = builder
 | 
				
			||||||
 | 
					                    .withMessage(OllamaChatMessageRole.USER,
 | 
				
			||||||
 | 
					                            "Compute the most important constant in the world using 5 digits")
 | 
				
			||||||
 | 
					                    .build();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
 | 
				
			||||||
 | 
					            assertNotNull(chatResult);
 | 
				
			||||||
 | 
					            assertNotNull(chatResult.getResponseModel());
 | 
				
			||||||
 | 
					            assertNotNull(chatResult.getResponseModel().getMessage());
 | 
				
			||||||
 | 
					            assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(),chatResult.getResponseModel().getMessage().getRole().getRoleName());
 | 
				
			||||||
 | 
					            List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls();
 | 
				
			||||||
 | 
					            assertEquals(1, toolCalls.size());
 | 
				
			||||||
 | 
					            OllamaToolCallsFunction function = toolCalls.get(0).getFunction();
 | 
				
			||||||
 | 
					            assertEquals("computeMkeConstant", function.getName());
 | 
				
			||||||
 | 
					            assertEquals(1, function.getArguments().size());
 | 
				
			||||||
 | 
					            Object noOfDigits = function.getArguments().get("noOfDigits");
 | 
				
			||||||
 | 
					            assertNotNull(noOfDigits);
 | 
				
			||||||
 | 
					            assertEquals("5",noOfDigits);
 | 
				
			||||||
 | 
					            assertTrue(chatResult.getChatHistory().size()>2);
 | 
				
			||||||
 | 
					            List<OllamaChatToolCalls> finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls();
 | 
				
			||||||
 | 
					            assertNull(finalToolCalls);
 | 
				
			||||||
 | 
					        } catch (IOException | OllamaBaseException | InterruptedException e) {
 | 
				
			||||||
 | 
					            fail(e);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @Test
 | 
				
			||||||
 | 
					    @Order(3)
 | 
				
			||||||
 | 
					    void testChatWithAnnotatedToolsAndMultipleParams() {
 | 
				
			||||||
 | 
					        testEndpointReachability();
 | 
				
			||||||
 | 
					        try {
 | 
				
			||||||
 | 
					            ollamaAPI.setVerbose(true);
 | 
				
			||||||
 | 
					            OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            ollamaAPI.registerAnnotatedTools();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            OllamaChatRequest requestModel = builder
 | 
				
			||||||
 | 
					                    .withMessage(OllamaChatMessageRole.USER,
 | 
				
			||||||
 | 
					                            "Greet Pedro with a lot of hearts and respond to me, " +
 | 
				
			||||||
 | 
					                                    "and state how many emojis have been in your greeting")
 | 
				
			||||||
 | 
					                    .build();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
 | 
				
			||||||
 | 
					            assertNotNull(chatResult);
 | 
				
			||||||
 | 
					            assertNotNull(chatResult.getResponseModel());
 | 
				
			||||||
 | 
					            assertNotNull(chatResult.getResponseModel().getMessage());
 | 
				
			||||||
 | 
					            assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(),chatResult.getResponseModel().getMessage().getRole().getRoleName());
 | 
				
			||||||
 | 
					            List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls();
 | 
				
			||||||
 | 
					            assertEquals(1, toolCalls.size());
 | 
				
			||||||
 | 
					            OllamaToolCallsFunction function = toolCalls.get(0).getFunction();
 | 
				
			||||||
 | 
					            assertEquals("sayHello", function.getName());
 | 
				
			||||||
 | 
					            assertEquals(2, function.getArguments().size());
 | 
				
			||||||
 | 
					            Object name = function.getArguments().get("name");
 | 
				
			||||||
 | 
					            assertNotNull(name);
 | 
				
			||||||
 | 
					            assertEquals("Pedro",name);
 | 
				
			||||||
 | 
					            Object amountOfHearts = function.getArguments().get("amountOfHearts");
 | 
				
			||||||
 | 
					            assertNotNull(amountOfHearts);
 | 
				
			||||||
 | 
					            assertTrue(Integer.parseInt(amountOfHearts.toString()) > 1);
 | 
				
			||||||
 | 
					            assertTrue(chatResult.getChatHistory().size()>2);
 | 
				
			||||||
 | 
					            List<OllamaChatToolCalls> finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls();
 | 
				
			||||||
 | 
					            assertNull(finalToolCalls);
 | 
				
			||||||
 | 
					        } catch (IOException | OllamaBaseException | InterruptedException e) {
 | 
				
			||||||
 | 
					            fail(e);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @Test
 | 
					    @Test
 | 
				
			||||||
    @Order(3)
 | 
					    @Order(3)
 | 
				
			||||||
    void testChatWithToolsAndStream() {
 | 
					    void testChatWithToolsAndStream() {
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										21
									
								
								src/test/java/io/github/ollama4j/samples/AnnotatedTool.java
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								src/test/java/io/github/ollama4j/samples/AnnotatedTool.java
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,21 @@
 | 
				
			|||||||
 | 
					package io.github.ollama4j.samples;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import io.github.ollama4j.tools.annotations.ToolProperty;
 | 
				
			||||||
 | 
					import io.github.ollama4j.tools.annotations.ToolSpec;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import java.math.BigDecimal;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					public class AnnotatedTool {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @ToolSpec(desc = "Computes the most important constant all around the globe!")
 | 
				
			||||||
 | 
					    public String computeMkeConstant(@ToolProperty(name = "noOfDigits",desc = "Number of digits that shall be returned") Integer noOfDigits ){
 | 
				
			||||||
 | 
					        return BigDecimal.valueOf((long)(Math.random()*1000000L),noOfDigits).toString();
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @ToolSpec(desc = "Says hello to a friend!")
 | 
				
			||||||
 | 
					    public String sayHello(@ToolProperty(name = "name",desc = "Name of the friend") String name, Integer someRandomProperty, @ToolProperty(name="amountOfHearts",desc = "amount of heart emojis that should be used",  required = false) Integer amountOfHearts) {
 | 
				
			||||||
 | 
					        String hearts = amountOfHearts!=null ? "♡".repeat(amountOfHearts) : "";
 | 
				
			||||||
 | 
					        return "Hello " + name +" ("+someRandomProperty+") " + hearts;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
		Reference in New Issue
	
	Block a user