mirror of
				https://github.com/amithkoujalgi/ollama4j.git
				synced 2025-11-04 02:20:50 +01:00 
			
		
		
		
	Refactored tools API
Signed-off-by: koujalgi.amith@gmail.com <koujalgi.amith@gmail.com>
This commit is contained in:
		@@ -1,6 +1,8 @@
 | 
			
		||||
package io.github.amithkoujalgi.ollama4j.core;
 | 
			
		||||
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.exceptions.ToolInvocationException;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.exceptions.ToolNotFoundException;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.models.*;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessage;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder;
 | 
			
		||||
@@ -14,6 +16,7 @@ import io.github.amithkoujalgi.ollama4j.core.models.request.*;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.tools.*;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.utils.Options;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
 | 
			
		||||
import lombok.Setter;
 | 
			
		||||
import org.slf4j.Logger;
 | 
			
		||||
import org.slf4j.LoggerFactory;
 | 
			
		||||
 | 
			
		||||
@@ -37,10 +40,22 @@ public class OllamaAPI {
 | 
			
		||||
 | 
			
		||||
    private static final Logger logger = LoggerFactory.getLogger(OllamaAPI.class);
 | 
			
		||||
    private final String host;
 | 
			
		||||
    /**
 | 
			
		||||
     * -- SETTER --
 | 
			
		||||
     * Set request timeout in seconds. Default is 3 seconds.
 | 
			
		||||
     */
 | 
			
		||||
    @Setter
 | 
			
		||||
    private long requestTimeoutSeconds = 10;
 | 
			
		||||
    /**
 | 
			
		||||
     * -- SETTER --
 | 
			
		||||
     * Set/unset logging of responses
 | 
			
		||||
     */
 | 
			
		||||
    @Setter
 | 
			
		||||
    private boolean verbose = true;
 | 
			
		||||
    private BasicAuth basicAuth;
 | 
			
		||||
 | 
			
		||||
    private final ToolRegistry toolRegistry = new ToolRegistry();
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Instantiates the Ollama API.
 | 
			
		||||
     *
 | 
			
		||||
@@ -54,24 +69,6 @@ public class OllamaAPI {
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Set request timeout in seconds. Default is 3 seconds.
 | 
			
		||||
     *
 | 
			
		||||
     * @param requestTimeoutSeconds the request timeout in seconds
 | 
			
		||||
     */
 | 
			
		||||
    public void setRequestTimeoutSeconds(long requestTimeoutSeconds) {
 | 
			
		||||
        this.requestTimeoutSeconds = requestTimeoutSeconds;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Set/unset logging of responses
 | 
			
		||||
     *
 | 
			
		||||
     * @param verbose true/false
 | 
			
		||||
     */
 | 
			
		||||
    public void setVerbose(boolean verbose) {
 | 
			
		||||
        this.verbose = verbose;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Set basic authentication for accessing Ollama server that's behind a reverse-proxy/gateway.
 | 
			
		||||
     *
 | 
			
		||||
@@ -383,7 +380,6 @@ public class OllamaAPI {
 | 
			
		||||
     *
 | 
			
		||||
     * @param model   The name or identifier of the AI model to use for generating the response.
 | 
			
		||||
     * @param prompt  The input text or prompt to provide to the AI model.
 | 
			
		||||
     * @param raw     In some cases, you may wish to bypass the templating system and provide a full prompt. In this case, you can use the raw parameter to disable templating. Also note that raw mode will not return a context.
 | 
			
		||||
     * @param options Additional options or configurations to use when generating the response.
 | 
			
		||||
     * @return {@link OllamaToolsResult} An OllamaToolsResult object containing the response from the AI model and the results of invoking the tools on that output.
 | 
			
		||||
     * @throws OllamaBaseException  If there is an error related to the Ollama API or service.
 | 
			
		||||
@@ -391,17 +387,23 @@ public class OllamaAPI {
 | 
			
		||||
     * @throws InterruptedException If the method is interrupted while waiting for the AI model
 | 
			
		||||
     *                              to generate the response or for the tools to be invoked.
 | 
			
		||||
     */
 | 
			
		||||
    public OllamaToolsResult generateWithTools(String model, String prompt, boolean raw, Options options)
 | 
			
		||||
            throws OllamaBaseException, IOException, InterruptedException {
 | 
			
		||||
    public OllamaToolsResult generateWithTools(String model, String prompt, Options options)
 | 
			
		||||
            throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
 | 
			
		||||
        boolean raw = true;
 | 
			
		||||
        OllamaToolsResult toolResult = new OllamaToolsResult();
 | 
			
		||||
        Map<ToolDef, Object> toolResults = new HashMap<>();
 | 
			
		||||
        Map<ToolFunctionCallSpec, Object> toolResults = new HashMap<>();
 | 
			
		||||
 | 
			
		||||
        OllamaResult result = generate(model, prompt, raw, options, null);
 | 
			
		||||
        toolResult.setModelResult(result);
 | 
			
		||||
 | 
			
		||||
        List<ToolDef> toolDefs = Utils.getObjectMapper().readValue(result.getResponse(), Utils.getObjectMapper().getTypeFactory().constructCollectionType(List.class, ToolDef.class));
 | 
			
		||||
        for (ToolDef toolDef : toolDefs) {
 | 
			
		||||
            toolResults.put(toolDef, invokeTool(toolDef));
 | 
			
		||||
        String toolsResponse = result.getResponse();
 | 
			
		||||
        if (toolsResponse.contains("[TOOL_CALLS]")) {
 | 
			
		||||
            toolsResponse = toolsResponse.replace("[TOOL_CALLS]", "");
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        List<ToolFunctionCallSpec> toolFunctionCallSpecs = Utils.getObjectMapper().readValue(toolsResponse, Utils.getObjectMapper().getTypeFactory().constructCollectionType(List.class, ToolFunctionCallSpec.class));
 | 
			
		||||
        for (ToolFunctionCallSpec toolFunctionCallSpec : toolFunctionCallSpecs) {
 | 
			
		||||
            toolResults.put(toolFunctionCallSpec, invokeTool(toolFunctionCallSpec));
 | 
			
		||||
        }
 | 
			
		||||
        toolResult.setToolResults(toolResults);
 | 
			
		||||
        return toolResult;
 | 
			
		||||
@@ -556,8 +558,8 @@ public class OllamaAPI {
 | 
			
		||||
        return new OllamaChatResult(result.getResponse(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public void registerTool(MistralTools.ToolSpecification toolSpecification) {
 | 
			
		||||
        ToolRegistry.addFunction(toolSpecification.getFunctionName(), toolSpecification.getToolDefinition());
 | 
			
		||||
    public void registerTool(Tools.ToolSpecification toolSpecification) {
 | 
			
		||||
        toolRegistry.addFunction(toolSpecification.getFunctionName(), toolSpecification.getToolDefinition());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // technical private methods //
 | 
			
		||||
@@ -622,18 +624,20 @@ public class OllamaAPI {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    private Object invokeTool(ToolDef toolDef) {
 | 
			
		||||
    private Object invokeTool(ToolFunctionCallSpec toolFunctionCallSpec) throws ToolInvocationException {
 | 
			
		||||
        try {
 | 
			
		||||
            String methodName = toolDef.getName();
 | 
			
		||||
            Map<String, Object> arguments = toolDef.getArguments();
 | 
			
		||||
            DynamicFunction function = ToolRegistry.getFunction(methodName);
 | 
			
		||||
            String methodName = toolFunctionCallSpec.getName();
 | 
			
		||||
            Map<String, Object> arguments = toolFunctionCallSpec.getArguments();
 | 
			
		||||
            ToolFunction function = toolRegistry.getFunction(methodName);
 | 
			
		||||
            if (verbose) {
 | 
			
		||||
                logger.debug("Invoking function {} with arguments {}", methodName, arguments);
 | 
			
		||||
            }
 | 
			
		||||
            if (function == null) {
 | 
			
		||||
                throw new IllegalArgumentException("No such tool: " + methodName);
 | 
			
		||||
                throw new ToolNotFoundException("No such tool: " + methodName);
 | 
			
		||||
            }
 | 
			
		||||
            return function.apply(arguments);
 | 
			
		||||
        } catch (Exception e) {
 | 
			
		||||
            e.printStackTrace();
 | 
			
		||||
            return "Error calling tool: " + e.getMessage();
 | 
			
		||||
            throw new ToolInvocationException("Failed to invoke tool: " + toolFunctionCallSpec.getName(), e);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -0,0 +1,8 @@
 | 
			
		||||
package io.github.amithkoujalgi.ollama4j.core.exceptions;
 | 
			
		||||
 | 
			
		||||
public class ToolInvocationException extends Exception {
 | 
			
		||||
 | 
			
		||||
    public ToolInvocationException(String s, Exception e) {
 | 
			
		||||
        super(s, e);
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@@ -0,0 +1,8 @@
 | 
			
		||||
package io.github.amithkoujalgi.ollama4j.core.exceptions;
 | 
			
		||||
 | 
			
		||||
public class ToolNotFoundException extends Exception {
 | 
			
		||||
 | 
			
		||||
    public ToolNotFoundException(String s) {
 | 
			
		||||
        super(s);
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@@ -5,6 +5,8 @@ import lombok.AllArgsConstructor;
 | 
			
		||||
import lombok.Data;
 | 
			
		||||
import lombok.NoArgsConstructor;
 | 
			
		||||
 | 
			
		||||
import java.util.ArrayList;
 | 
			
		||||
import java.util.List;
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
 | 
			
		||||
@Data
 | 
			
		||||
@@ -12,5 +14,22 @@ import java.util.Map;
 | 
			
		||||
@AllArgsConstructor
 | 
			
		||||
public class OllamaToolsResult {
 | 
			
		||||
    private OllamaResult modelResult;
 | 
			
		||||
    private Map<ToolDef, Object> toolResults;
 | 
			
		||||
    private Map<ToolFunctionCallSpec, Object> toolResults;
 | 
			
		||||
 | 
			
		||||
    public List<ToolResult> getToolResults() {
 | 
			
		||||
        List<ToolResult> results = new ArrayList<>();
 | 
			
		||||
        for (Map.Entry<ToolFunctionCallSpec, Object> r : this.toolResults.entrySet()) {
 | 
			
		||||
            results.add(new ToolResult(r.getKey().getName(), r.getKey().getArguments(), r.getValue()));
 | 
			
		||||
        }
 | 
			
		||||
        return results;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Data
 | 
			
		||||
    @NoArgsConstructor
 | 
			
		||||
    @AllArgsConstructor
 | 
			
		||||
    public static class ToolResult {
 | 
			
		||||
        private String functionName;
 | 
			
		||||
        private Map<String, Object> functionArguments;
 | 
			
		||||
        private Object result;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -3,6 +3,6 @@ package io.github.amithkoujalgi.ollama4j.core.tools;
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
 | 
			
		||||
@FunctionalInterface
 | 
			
		||||
public interface DynamicFunction {
 | 
			
		||||
public interface ToolFunction {
 | 
			
		||||
    Object apply(Map<String, Object> arguments);
 | 
			
		||||
}
 | 
			
		||||
@@ -9,10 +9,8 @@ import java.util.Map;
 | 
			
		||||
@Data
 | 
			
		||||
@AllArgsConstructor
 | 
			
		||||
@NoArgsConstructor
 | 
			
		||||
public class ToolDef {
 | 
			
		||||
 | 
			
		||||
public class ToolFunctionCallSpec {
 | 
			
		||||
    private String name;
 | 
			
		||||
    private Map<String, Object> arguments;
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -4,14 +4,13 @@ import java.util.HashMap;
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
 | 
			
		||||
public class ToolRegistry {
 | 
			
		||||
    private static final Map<String, DynamicFunction> functionMap = new HashMap<>();
 | 
			
		||||
    private final Map<String, ToolFunction> functionMap = new HashMap<>();
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    public static DynamicFunction getFunction(String name) {
 | 
			
		||||
    public ToolFunction getFunction(String name) {
 | 
			
		||||
        return functionMap.get(name);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public static void addFunction(String name, DynamicFunction function) {
 | 
			
		||||
    public void addFunction(String name, ToolFunction function) {
 | 
			
		||||
        functionMap.put(name, function);
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -14,14 +14,14 @@ import java.util.HashMap;
 | 
			
		||||
import java.util.List;
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
 | 
			
		||||
public class MistralTools {
 | 
			
		||||
public class Tools {
 | 
			
		||||
    @Data
 | 
			
		||||
    @Builder
 | 
			
		||||
    public static class ToolSpecification {
 | 
			
		||||
        private String functionName;
 | 
			
		||||
        private String functionDesc;
 | 
			
		||||
        private Map<String, PromptFuncDefinition.Property> props;
 | 
			
		||||
        private DynamicFunction toolDefinition;
 | 
			
		||||
        private String functionDescription;
 | 
			
		||||
        private Map<String, PromptFuncDefinition.Property> properties;
 | 
			
		||||
        private ToolFunction toolDefinition;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Data
 | 
			
		||||
@@ -90,14 +90,14 @@ public class MistralTools {
 | 
			
		||||
 | 
			
		||||
            PromptFuncDefinition.PromptFuncSpec functionDetail = new PromptFuncDefinition.PromptFuncSpec();
 | 
			
		||||
            functionDetail.setName(spec.getFunctionName());
 | 
			
		||||
            functionDetail.setDescription(spec.getFunctionDesc());
 | 
			
		||||
            functionDetail.setDescription(spec.getFunctionDescription());
 | 
			
		||||
 | 
			
		||||
            PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters();
 | 
			
		||||
            parameters.setType("object");
 | 
			
		||||
            parameters.setProperties(spec.getProps());
 | 
			
		||||
            parameters.setProperties(spec.getProperties());
 | 
			
		||||
 | 
			
		||||
            List<String> requiredValues = new ArrayList<>();
 | 
			
		||||
            for (Map.Entry<String, PromptFuncDefinition.Property> p : spec.getProps().entrySet()) {
 | 
			
		||||
            for (Map.Entry<String, PromptFuncDefinition.Property> p : spec.getProperties().entrySet()) {
 | 
			
		||||
                if (p.getValue().isRequired()) {
 | 
			
		||||
                    requiredValues.add(p.getKey());
 | 
			
		||||
                }
 | 
			
		||||
@@ -109,31 +109,5 @@ public class MistralTools {
 | 
			
		||||
            tools.add(def);
 | 
			
		||||
            return this;
 | 
			
		||||
        }
 | 
			
		||||
//
 | 
			
		||||
//        public PromptBuilder withToolSpecification(String functionName, String functionDesc, Map<String, PromptFuncDefinition.Property> props) {
 | 
			
		||||
//            PromptFuncDefinition def = new PromptFuncDefinition();
 | 
			
		||||
//            def.setType("function");
 | 
			
		||||
//
 | 
			
		||||
//            PromptFuncDefinition.PromptFuncSpec functionDetail = new PromptFuncDefinition.PromptFuncSpec();
 | 
			
		||||
//            functionDetail.setName(functionName);
 | 
			
		||||
//            functionDetail.setDescription(functionDesc);
 | 
			
		||||
//
 | 
			
		||||
//            PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters();
 | 
			
		||||
//            parameters.setType("object");
 | 
			
		||||
//            parameters.setProperties(props);
 | 
			
		||||
//
 | 
			
		||||
//            List<String> requiredValues = new ArrayList<>();
 | 
			
		||||
//            for (Map.Entry<String, PromptFuncDefinition.Property> p : props.entrySet()) {
 | 
			
		||||
//                if (p.getValue().isRequired()) {
 | 
			
		||||
//                    requiredValues.add(p.getKey());
 | 
			
		||||
//                }
 | 
			
		||||
//            }
 | 
			
		||||
//            parameters.setRequired(requiredValues);
 | 
			
		||||
//            functionDetail.setParameters(parameters);
 | 
			
		||||
//            def.setFunction(functionDetail);
 | 
			
		||||
//
 | 
			
		||||
//            tools.add(def);
 | 
			
		||||
//            return this;
 | 
			
		||||
//        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
		Reference in New Issue
	
	Block a user