mirror of
				https://github.com/amithkoujalgi/ollama4j.git
				synced 2025-11-04 10:30:41 +01:00 
			
		
		
		
	Added support for tools/function calling - specifically for Mistral's latest model.
This commit is contained in:
		@@ -10,6 +10,7 @@ import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingRe
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestModel;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.models.request.*;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.tools.*;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.utils.Options;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
 | 
			
		||||
import org.slf4j.Logger;
 | 
			
		||||
@@ -25,9 +26,7 @@ import java.net.http.HttpResponse;
 | 
			
		||||
import java.nio.charset.StandardCharsets;
 | 
			
		||||
import java.nio.file.Files;
 | 
			
		||||
import java.time.Duration;
 | 
			
		||||
import java.util.ArrayList;
 | 
			
		||||
import java.util.Base64;
 | 
			
		||||
import java.util.List;
 | 
			
		||||
import java.util.*;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * The base Ollama API class.
 | 
			
		||||
@@ -339,6 +338,7 @@ public class OllamaAPI {
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Generate response for a question to a model running on Ollama server. This is a sync/blocking
 | 
			
		||||
     * call.
 | 
			
		||||
@@ -351,9 +351,10 @@ public class OllamaAPI {
 | 
			
		||||
     * @param streamHandler optional callback consumer that will be applied every time a streamed response is received. If not set, the stream parameter of the request is set to false.
 | 
			
		||||
     * @return OllamaResult that includes response text and time taken for response
 | 
			
		||||
     */
 | 
			
		||||
    public OllamaResult generate(String model, String prompt, Options options, OllamaStreamHandler streamHandler)
 | 
			
		||||
    public OllamaResult generate(String model, String prompt, boolean raw, Options options, OllamaStreamHandler streamHandler)
 | 
			
		||||
            throws OllamaBaseException, IOException, InterruptedException {
 | 
			
		||||
        OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt);
 | 
			
		||||
        ollamaRequestModel.setRaw(raw);
 | 
			
		||||
        ollamaRequestModel.setOptions(options.getOptionsMap());
 | 
			
		||||
        return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler);
 | 
			
		||||
    }
 | 
			
		||||
@@ -361,13 +362,37 @@ public class OllamaAPI {
 | 
			
		||||
    /**
 | 
			
		||||
     * Convenience method to call Ollama API without streaming responses.
 | 
			
		||||
     * <p>
 | 
			
		||||
     * Uses {@link #generate(String, String, Options, OllamaStreamHandler)}
 | 
			
		||||
     * Uses {@link #generate(String, String, boolean, Options, OllamaStreamHandler)}
 | 
			
		||||
     *
 | 
			
		||||
     * @param model   Model to use
 | 
			
		||||
     * @param prompt  Prompt text
 | 
			
		||||
     * @param raw     In some cases, you may wish to bypass the templating system and provide a full prompt. In this case, you can use the raw parameter to disable templating. Also note that raw mode will not return a context.
 | 
			
		||||
     * @param options Additional Options
 | 
			
		||||
     * @return OllamaResult
 | 
			
		||||
     */
 | 
			
		||||
    public OllamaResult generate(String model, String prompt, Options options)
 | 
			
		||||
    public OllamaResult generate(String model, String prompt, boolean raw, Options options)
 | 
			
		||||
            throws OllamaBaseException, IOException, InterruptedException {
 | 
			
		||||
        return generate(model, prompt, options, null);
 | 
			
		||||
        return generate(model, prompt, raw, options, null);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    public OllamaToolsResult generateWithTools(String model, String prompt, boolean raw, Options options)
 | 
			
		||||
            throws OllamaBaseException, IOException, InterruptedException {
 | 
			
		||||
        OllamaToolsResult toolResult = new OllamaToolsResult();
 | 
			
		||||
        Map<ToolDef, Object> toolResults = new HashMap<>();
 | 
			
		||||
 | 
			
		||||
        OllamaResult result = generate(model, prompt, raw, options, null);
 | 
			
		||||
        toolResult.setModelResult(result);
 | 
			
		||||
 | 
			
		||||
        List<ToolDef> toolDefs = Utils.getObjectMapper().readValue(result.getResponse(), Utils.getObjectMapper().getTypeFactory().constructCollectionType(List.class, ToolDef.class));
 | 
			
		||||
        for (ToolDef toolDef : toolDefs) {
 | 
			
		||||
            toolResults.put(toolDef, invokeTool(toolDef));
 | 
			
		||||
        }
 | 
			
		||||
        toolResult.setToolResults(toolResults);
 | 
			
		||||
        return toolResult;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Generate response for a question to a model running on Ollama server and get a callback handle
 | 
			
		||||
     * that can be used to check for status and get the response from the model later. This would be
 | 
			
		||||
@@ -377,9 +402,9 @@ public class OllamaAPI {
 | 
			
		||||
     * @param prompt the prompt/question text
 | 
			
		||||
     * @return the ollama async result callback handle
 | 
			
		||||
     */
 | 
			
		||||
    public OllamaAsyncResultCallback generateAsync(String model, String prompt) {
 | 
			
		||||
    public OllamaAsyncResultCallback generateAsync(String model, String prompt, boolean raw) {
 | 
			
		||||
        OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt);
 | 
			
		||||
 | 
			
		||||
        ollamaRequestModel.setRaw(raw);
 | 
			
		||||
        URI uri = URI.create(this.host + "/api/generate");
 | 
			
		||||
        OllamaAsyncResultCallback ollamaAsyncResultCallback =
 | 
			
		||||
                new OllamaAsyncResultCallback(
 | 
			
		||||
@@ -576,4 +601,24 @@ public class OllamaAPI {
 | 
			
		||||
    private boolean isBasicAuthCredentialsSet() {
 | 
			
		||||
        return basicAuth != null;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    public void registerTool(MistralTools.ToolSpecification toolSpecification) {
 | 
			
		||||
        ToolRegistry.addFunction(toolSpecification.getFunctionName(), toolSpecification.getToolDefinition());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private Object invokeTool(ToolDef toolDef) {
 | 
			
		||||
        try {
 | 
			
		||||
            String methodName = toolDef.getName();
 | 
			
		||||
            Map<String, Object> arguments = toolDef.getArguments();
 | 
			
		||||
            DynamicFunction function = ToolRegistry.getFunction(methodName);
 | 
			
		||||
            if (function == null) {
 | 
			
		||||
                throw new IllegalArgumentException("No such tool: " + methodName);
 | 
			
		||||
            }
 | 
			
		||||
            return function.apply(arguments);
 | 
			
		||||
        } catch (Exception e) {
 | 
			
		||||
            e.printStackTrace();
 | 
			
		||||
            return "Error calling tool: " + e.getMessage();
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,9 +1,5 @@
 | 
			
		||||
package io.github.amithkoujalgi.ollama4j.core.models.request;
 | 
			
		||||
 | 
			
		||||
import java.io.IOException;
 | 
			
		||||
import org.slf4j.Logger;
 | 
			
		||||
import org.slf4j.LoggerFactory;
 | 
			
		||||
 | 
			
		||||
import com.fasterxml.jackson.core.JsonProcessingException;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
 | 
			
		||||
@@ -13,15 +9,19 @@ import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRespo
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateStreamObserver;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
 | 
			
		||||
import org.slf4j.Logger;
 | 
			
		||||
import org.slf4j.LoggerFactory;
 | 
			
		||||
 | 
			
		||||
public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{
 | 
			
		||||
import java.io.IOException;
 | 
			
		||||
 | 
			
		||||
public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
 | 
			
		||||
 | 
			
		||||
    private static final Logger LOG = LoggerFactory.getLogger(OllamaGenerateEndpointCaller.class);
 | 
			
		||||
 | 
			
		||||
    private OllamaGenerateStreamObserver streamObserver;
 | 
			
		||||
 | 
			
		||||
    public OllamaGenerateEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) {
 | 
			
		||||
        super(host, basicAuth, requestTimeoutSeconds, verbose);   
 | 
			
		||||
        super(host, basicAuth, requestTimeoutSeconds, verbose);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
@@ -31,24 +31,22 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) {
 | 
			
		||||
                try {
 | 
			
		||||
                    OllamaGenerateResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class);
 | 
			
		||||
                    responseBuffer.append(ollamaResponseModel.getResponse());
 | 
			
		||||
                    if(streamObserver != null) {
 | 
			
		||||
                        streamObserver.notify(ollamaResponseModel);
 | 
			
		||||
                    }
 | 
			
		||||
                    return ollamaResponseModel.isDone();
 | 
			
		||||
                } catch (JsonProcessingException e) {
 | 
			
		||||
                    LOG.error("Error parsing the Ollama chat response!",e);
 | 
			
		||||
                    return true;
 | 
			
		||||
                }         
 | 
			
		||||
        try {
 | 
			
		||||
            OllamaGenerateResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class);
 | 
			
		||||
            responseBuffer.append(ollamaResponseModel.getResponse());
 | 
			
		||||
            if (streamObserver != null) {
 | 
			
		||||
                streamObserver.notify(ollamaResponseModel);
 | 
			
		||||
            }
 | 
			
		||||
            return ollamaResponseModel.isDone();
 | 
			
		||||
        } catch (JsonProcessingException e) {
 | 
			
		||||
            LOG.error("Error parsing the Ollama chat response!", e);
 | 
			
		||||
            return true;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler)
 | 
			
		||||
        throws OllamaBaseException, IOException, InterruptedException {
 | 
			
		||||
    streamObserver = new OllamaGenerateStreamObserver(streamHandler);
 | 
			
		||||
    return super.callSync(body);
 | 
			
		||||
            throws OllamaBaseException, IOException, InterruptedException {
 | 
			
		||||
        streamObserver = new OllamaGenerateStreamObserver(streamHandler);
 | 
			
		||||
        return super.callSync(body);
 | 
			
		||||
    }
 | 
			
		||||
    
 | 
			
		||||
    
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -0,0 +1,8 @@
 | 
			
		||||
package io.github.amithkoujalgi.ollama4j.core.tools;
 | 
			
		||||
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
 | 
			
		||||
@FunctionalInterface
 | 
			
		||||
public interface DynamicFunction {
 | 
			
		||||
    Object apply(Map<String, Object> arguments);
 | 
			
		||||
}
 | 
			
		||||
@@ -0,0 +1,139 @@
 | 
			
		||||
package io.github.amithkoujalgi.ollama4j.core.tools;
 | 
			
		||||
 | 
			
		||||
import com.fasterxml.jackson.annotation.JsonIgnore;
 | 
			
		||||
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
 | 
			
		||||
import com.fasterxml.jackson.annotation.JsonInclude;
 | 
			
		||||
import com.fasterxml.jackson.annotation.JsonProperty;
 | 
			
		||||
import com.fasterxml.jackson.core.JsonProcessingException;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
 | 
			
		||||
import lombok.Builder;
 | 
			
		||||
import lombok.Data;
 | 
			
		||||
 | 
			
		||||
import java.util.ArrayList;
 | 
			
		||||
import java.util.HashMap;
 | 
			
		||||
import java.util.List;
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
 | 
			
		||||
public class MistralTools {
 | 
			
		||||
    @Data
 | 
			
		||||
    @Builder
 | 
			
		||||
    public static class ToolSpecification {
 | 
			
		||||
        private String functionName;
 | 
			
		||||
        private String functionDesc;
 | 
			
		||||
        private Map<String, PromptFuncDefinition.Property> props;
 | 
			
		||||
        private DynamicFunction toolDefinition;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Data
 | 
			
		||||
    @JsonIgnoreProperties(ignoreUnknown = true)
 | 
			
		||||
    public static class PromptFuncDefinition {
 | 
			
		||||
        private String type;
 | 
			
		||||
        private PromptFuncSpec function;
 | 
			
		||||
 | 
			
		||||
        @Data
 | 
			
		||||
        public static class PromptFuncSpec {
 | 
			
		||||
            private String name;
 | 
			
		||||
            private String description;
 | 
			
		||||
            private Parameters parameters;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        @Data
 | 
			
		||||
        public static class Parameters {
 | 
			
		||||
            private String type;
 | 
			
		||||
            private Map<String, Property> properties;
 | 
			
		||||
            private List<String> required;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        @Data
 | 
			
		||||
        @Builder
 | 
			
		||||
        public static class Property {
 | 
			
		||||
            private String type;
 | 
			
		||||
            private String description;
 | 
			
		||||
            @JsonProperty("enum")
 | 
			
		||||
            @JsonInclude(JsonInclude.Include.NON_NULL)
 | 
			
		||||
            private List<String> enumValues;
 | 
			
		||||
            @JsonIgnore
 | 
			
		||||
            private boolean required;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public static class PropsBuilder {
 | 
			
		||||
        private final Map<String, PromptFuncDefinition.Property> props = new HashMap<>();
 | 
			
		||||
 | 
			
		||||
        public PropsBuilder withProperty(String key, PromptFuncDefinition.Property property) {
 | 
			
		||||
            props.put(key, property);
 | 
			
		||||
            return this;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        public Map<String, PromptFuncDefinition.Property> build() {
 | 
			
		||||
            return props;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public static class PromptBuilder {
 | 
			
		||||
        private final List<PromptFuncDefinition> tools = new ArrayList<>();
 | 
			
		||||
 | 
			
		||||
        private String promptText;
 | 
			
		||||
 | 
			
		||||
        public String build() throws JsonProcessingException {
 | 
			
		||||
            return "[AVAILABLE_TOOLS] " + Utils.getObjectMapper().writeValueAsString(tools) + "[/AVAILABLE_TOOLS][INST] " + promptText + " [/INST]";
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        public PromptBuilder withPrompt(String prompt) throws JsonProcessingException {
 | 
			
		||||
            promptText = prompt;
 | 
			
		||||
            return this;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        public PromptBuilder withToolSpecification(ToolSpecification spec) {
 | 
			
		||||
            PromptFuncDefinition def = new PromptFuncDefinition();
 | 
			
		||||
            def.setType("function");
 | 
			
		||||
 | 
			
		||||
            PromptFuncDefinition.PromptFuncSpec functionDetail = new PromptFuncDefinition.PromptFuncSpec();
 | 
			
		||||
            functionDetail.setName(spec.getFunctionName());
 | 
			
		||||
            functionDetail.setDescription(spec.getFunctionDesc());
 | 
			
		||||
 | 
			
		||||
            PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters();
 | 
			
		||||
            parameters.setType("object");
 | 
			
		||||
            parameters.setProperties(spec.getProps());
 | 
			
		||||
 | 
			
		||||
            List<String> requiredValues = new ArrayList<>();
 | 
			
		||||
            for (Map.Entry<String, PromptFuncDefinition.Property> p : spec.getProps().entrySet()) {
 | 
			
		||||
                if (p.getValue().isRequired()) {
 | 
			
		||||
                    requiredValues.add(p.getKey());
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            parameters.setRequired(requiredValues);
 | 
			
		||||
            functionDetail.setParameters(parameters);
 | 
			
		||||
            def.setFunction(functionDetail);
 | 
			
		||||
 | 
			
		||||
            tools.add(def);
 | 
			
		||||
            return this;
 | 
			
		||||
        }
 | 
			
		||||
//
 | 
			
		||||
//        public PromptBuilder withToolSpecification(String functionName, String functionDesc, Map<String, PromptFuncDefinition.Property> props) {
 | 
			
		||||
//            PromptFuncDefinition def = new PromptFuncDefinition();
 | 
			
		||||
//            def.setType("function");
 | 
			
		||||
//
 | 
			
		||||
//            PromptFuncDefinition.PromptFuncSpec functionDetail = new PromptFuncDefinition.PromptFuncSpec();
 | 
			
		||||
//            functionDetail.setName(functionName);
 | 
			
		||||
//            functionDetail.setDescription(functionDesc);
 | 
			
		||||
//
 | 
			
		||||
//            PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters();
 | 
			
		||||
//            parameters.setType("object");
 | 
			
		||||
//            parameters.setProperties(props);
 | 
			
		||||
//
 | 
			
		||||
//            List<String> requiredValues = new ArrayList<>();
 | 
			
		||||
//            for (Map.Entry<String, PromptFuncDefinition.Property> p : props.entrySet()) {
 | 
			
		||||
//                if (p.getValue().isRequired()) {
 | 
			
		||||
//                    requiredValues.add(p.getKey());
 | 
			
		||||
//                }
 | 
			
		||||
//            }
 | 
			
		||||
//            parameters.setRequired(requiredValues);
 | 
			
		||||
//            functionDetail.setParameters(parameters);
 | 
			
		||||
//            def.setFunction(functionDetail);
 | 
			
		||||
//
 | 
			
		||||
//            tools.add(def);
 | 
			
		||||
//            return this;
 | 
			
		||||
//        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@@ -0,0 +1,16 @@
 | 
			
		||||
package io.github.amithkoujalgi.ollama4j.core.tools;
 | 
			
		||||
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
 | 
			
		||||
import lombok.AllArgsConstructor;
 | 
			
		||||
import lombok.Data;
 | 
			
		||||
import lombok.NoArgsConstructor;
 | 
			
		||||
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
 | 
			
		||||
@Data
 | 
			
		||||
@NoArgsConstructor
 | 
			
		||||
@AllArgsConstructor
 | 
			
		||||
public class OllamaToolsResult {
 | 
			
		||||
    private OllamaResult modelResult;
 | 
			
		||||
    private Map<ToolDef, Object> toolResults;
 | 
			
		||||
}
 | 
			
		||||
@@ -0,0 +1,18 @@
 | 
			
		||||
package io.github.amithkoujalgi.ollama4j.core.tools;
 | 
			
		||||
 | 
			
		||||
import lombok.AllArgsConstructor;
 | 
			
		||||
import lombok.Data;
 | 
			
		||||
import lombok.NoArgsConstructor;
 | 
			
		||||
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
 | 
			
		||||
@Data
 | 
			
		||||
@AllArgsConstructor
 | 
			
		||||
@NoArgsConstructor
 | 
			
		||||
public class ToolDef {
 | 
			
		||||
 | 
			
		||||
    private String name;
 | 
			
		||||
    private Map<String, Object> arguments;
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -0,0 +1,17 @@
 | 
			
		||||
package io.github.amithkoujalgi.ollama4j.core.tools;
 | 
			
		||||
 | 
			
		||||
import java.util.HashMap;
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
 | 
			
		||||
public class ToolRegistry {
 | 
			
		||||
    private static final Map<String, DynamicFunction> functionMap = new HashMap<>();
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    public static DynamicFunction getFunction(String name) {
 | 
			
		||||
        return functionMap.get(name);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public static void addFunction(String name, DynamicFunction function) {
 | 
			
		||||
        functionMap.put(name, function);
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
		Reference in New Issue
	
	Block a user