forked from Mirror/ollama4j
Added support for tools/function calling - specifically for Mistral's latest model.
This commit is contained in:
@@ -10,6 +10,7 @@ import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingRe
|
||||
import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestModel;
|
||||
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel;
|
||||
import io.github.amithkoujalgi.ollama4j.core.models.request.*;
|
||||
import io.github.amithkoujalgi.ollama4j.core.tools.*;
|
||||
import io.github.amithkoujalgi.ollama4j.core.utils.Options;
|
||||
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
|
||||
import org.slf4j.Logger;
|
||||
@@ -25,9 +26,7 @@ import java.net.http.HttpResponse;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.nio.file.Files;
|
||||
import java.time.Duration;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Base64;
|
||||
import java.util.List;
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
* The base Ollama API class.
|
||||
@@ -339,6 +338,7 @@ public class OllamaAPI {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Generate response for a question to a model running on Ollama server. This is a sync/blocking
|
||||
* call.
|
||||
@@ -351,9 +351,10 @@ public class OllamaAPI {
|
||||
* @param streamHandler optional callback consumer that will be applied every time a streamed response is received. If not set, the stream parameter of the request is set to false.
|
||||
* @return OllamaResult that includes response text and time taken for response
|
||||
*/
|
||||
public OllamaResult generate(String model, String prompt, Options options, OllamaStreamHandler streamHandler)
|
||||
public OllamaResult generate(String model, String prompt, boolean raw, Options options, OllamaStreamHandler streamHandler)
|
||||
throws OllamaBaseException, IOException, InterruptedException {
|
||||
OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt);
|
||||
ollamaRequestModel.setRaw(raw);
|
||||
ollamaRequestModel.setOptions(options.getOptionsMap());
|
||||
return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler);
|
||||
}
|
||||
@@ -361,13 +362,37 @@ public class OllamaAPI {
|
||||
/**
|
||||
* Convenience method to call Ollama API without streaming responses.
|
||||
* <p>
|
||||
* Uses {@link #generate(String, String, Options, OllamaStreamHandler)}
|
||||
* Uses {@link #generate(String, String, boolean, Options, OllamaStreamHandler)}
|
||||
*
|
||||
* @param model Model to use
|
||||
* @param prompt Prompt text
|
||||
* @param raw In some cases, you may wish to bypass the templating system and provide a full prompt. In this case, you can use the raw parameter to disable templating. Also note that raw mode will not return a context.
|
||||
* @param options Additional Options
|
||||
* @return OllamaResult
|
||||
*/
|
||||
public OllamaResult generate(String model, String prompt, Options options)
|
||||
public OllamaResult generate(String model, String prompt, boolean raw, Options options)
|
||||
throws OllamaBaseException, IOException, InterruptedException {
|
||||
return generate(model, prompt, options, null);
|
||||
return generate(model, prompt, raw, options, null);
|
||||
}
|
||||
|
||||
|
||||
public OllamaToolsResult generateWithTools(String model, String prompt, boolean raw, Options options)
|
||||
throws OllamaBaseException, IOException, InterruptedException {
|
||||
OllamaToolsResult toolResult = new OllamaToolsResult();
|
||||
Map<ToolDef, Object> toolResults = new HashMap<>();
|
||||
|
||||
OllamaResult result = generate(model, prompt, raw, options, null);
|
||||
toolResult.setModelResult(result);
|
||||
|
||||
List<ToolDef> toolDefs = Utils.getObjectMapper().readValue(result.getResponse(), Utils.getObjectMapper().getTypeFactory().constructCollectionType(List.class, ToolDef.class));
|
||||
for (ToolDef toolDef : toolDefs) {
|
||||
toolResults.put(toolDef, invokeTool(toolDef));
|
||||
}
|
||||
toolResult.setToolResults(toolResults);
|
||||
return toolResult;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Generate response for a question to a model running on Ollama server and get a callback handle
|
||||
* that can be used to check for status and get the response from the model later. This would be
|
||||
@@ -377,9 +402,9 @@ public class OllamaAPI {
|
||||
* @param prompt the prompt/question text
|
||||
* @return the ollama async result callback handle
|
||||
*/
|
||||
public OllamaAsyncResultCallback generateAsync(String model, String prompt) {
|
||||
public OllamaAsyncResultCallback generateAsync(String model, String prompt, boolean raw) {
|
||||
OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt);
|
||||
|
||||
ollamaRequestModel.setRaw(raw);
|
||||
URI uri = URI.create(this.host + "/api/generate");
|
||||
OllamaAsyncResultCallback ollamaAsyncResultCallback =
|
||||
new OllamaAsyncResultCallback(
|
||||
@@ -576,4 +601,24 @@ public class OllamaAPI {
|
||||
private boolean isBasicAuthCredentialsSet() {
|
||||
return basicAuth != null;
|
||||
}
|
||||
|
||||
|
||||
public void registerTool(MistralTools.ToolSpecification toolSpecification) {
|
||||
ToolRegistry.addFunction(toolSpecification.getFunctionName(), toolSpecification.getToolDefinition());
|
||||
}
|
||||
|
||||
private Object invokeTool(ToolDef toolDef) {
|
||||
try {
|
||||
String methodName = toolDef.getName();
|
||||
Map<String, Object> arguments = toolDef.getArguments();
|
||||
DynamicFunction function = ToolRegistry.getFunction(methodName);
|
||||
if (function == null) {
|
||||
throw new IllegalArgumentException("No such tool: " + methodName);
|
||||
}
|
||||
return function.apply(arguments);
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
return "Error calling tool: " + e.getMessage();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,5 @@
|
||||
package io.github.amithkoujalgi.ollama4j.core.models.request;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler;
|
||||
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
|
||||
@@ -13,15 +9,19 @@ import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRespo
|
||||
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateStreamObserver;
|
||||
import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
|
||||
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{
|
||||
import java.io.IOException;
|
||||
|
||||
public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
|
||||
|
||||
private static final Logger LOG = LoggerFactory.getLogger(OllamaGenerateEndpointCaller.class);
|
||||
|
||||
private OllamaGenerateStreamObserver streamObserver;
|
||||
|
||||
public OllamaGenerateEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) {
|
||||
super(host, basicAuth, requestTimeoutSeconds, verbose);
|
||||
super(host, basicAuth, requestTimeoutSeconds, verbose);
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -31,24 +31,22 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{
|
||||
|
||||
@Override
|
||||
protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) {
|
||||
try {
|
||||
OllamaGenerateResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class);
|
||||
responseBuffer.append(ollamaResponseModel.getResponse());
|
||||
if(streamObserver != null) {
|
||||
streamObserver.notify(ollamaResponseModel);
|
||||
}
|
||||
return ollamaResponseModel.isDone();
|
||||
} catch (JsonProcessingException e) {
|
||||
LOG.error("Error parsing the Ollama chat response!",e);
|
||||
return true;
|
||||
}
|
||||
try {
|
||||
OllamaGenerateResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class);
|
||||
responseBuffer.append(ollamaResponseModel.getResponse());
|
||||
if (streamObserver != null) {
|
||||
streamObserver.notify(ollamaResponseModel);
|
||||
}
|
||||
return ollamaResponseModel.isDone();
|
||||
} catch (JsonProcessingException e) {
|
||||
LOG.error("Error parsing the Ollama chat response!", e);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler)
|
||||
throws OllamaBaseException, IOException, InterruptedException {
|
||||
streamObserver = new OllamaGenerateStreamObserver(streamHandler);
|
||||
return super.callSync(body);
|
||||
throws OllamaBaseException, IOException, InterruptedException {
|
||||
streamObserver = new OllamaGenerateStreamObserver(streamHandler);
|
||||
return super.callSync(body);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
package io.github.amithkoujalgi.ollama4j.core.tools;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
@FunctionalInterface
|
||||
public interface DynamicFunction {
|
||||
Object apply(Map<String, Object> arguments);
|
||||
}
|
||||
@@ -0,0 +1,139 @@
|
||||
package io.github.amithkoujalgi.ollama4j.core.tools;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class MistralTools {
|
||||
@Data
|
||||
@Builder
|
||||
public static class ToolSpecification {
|
||||
private String functionName;
|
||||
private String functionDesc;
|
||||
private Map<String, PromptFuncDefinition.Property> props;
|
||||
private DynamicFunction toolDefinition;
|
||||
}
|
||||
|
||||
@Data
|
||||
@JsonIgnoreProperties(ignoreUnknown = true)
|
||||
public static class PromptFuncDefinition {
|
||||
private String type;
|
||||
private PromptFuncSpec function;
|
||||
|
||||
@Data
|
||||
public static class PromptFuncSpec {
|
||||
private String name;
|
||||
private String description;
|
||||
private Parameters parameters;
|
||||
}
|
||||
|
||||
@Data
|
||||
public static class Parameters {
|
||||
private String type;
|
||||
private Map<String, Property> properties;
|
||||
private List<String> required;
|
||||
}
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
public static class Property {
|
||||
private String type;
|
||||
private String description;
|
||||
@JsonProperty("enum")
|
||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||
private List<String> enumValues;
|
||||
@JsonIgnore
|
||||
private boolean required;
|
||||
}
|
||||
}
|
||||
|
||||
public static class PropsBuilder {
|
||||
private final Map<String, PromptFuncDefinition.Property> props = new HashMap<>();
|
||||
|
||||
public PropsBuilder withProperty(String key, PromptFuncDefinition.Property property) {
|
||||
props.put(key, property);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Map<String, PromptFuncDefinition.Property> build() {
|
||||
return props;
|
||||
}
|
||||
}
|
||||
|
||||
public static class PromptBuilder {
|
||||
private final List<PromptFuncDefinition> tools = new ArrayList<>();
|
||||
|
||||
private String promptText;
|
||||
|
||||
public String build() throws JsonProcessingException {
|
||||
return "[AVAILABLE_TOOLS] " + Utils.getObjectMapper().writeValueAsString(tools) + "[/AVAILABLE_TOOLS][INST] " + promptText + " [/INST]";
|
||||
}
|
||||
|
||||
public PromptBuilder withPrompt(String prompt) throws JsonProcessingException {
|
||||
promptText = prompt;
|
||||
return this;
|
||||
}
|
||||
|
||||
public PromptBuilder withToolSpecification(ToolSpecification spec) {
|
||||
PromptFuncDefinition def = new PromptFuncDefinition();
|
||||
def.setType("function");
|
||||
|
||||
PromptFuncDefinition.PromptFuncSpec functionDetail = new PromptFuncDefinition.PromptFuncSpec();
|
||||
functionDetail.setName(spec.getFunctionName());
|
||||
functionDetail.setDescription(spec.getFunctionDesc());
|
||||
|
||||
PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters();
|
||||
parameters.setType("object");
|
||||
parameters.setProperties(spec.getProps());
|
||||
|
||||
List<String> requiredValues = new ArrayList<>();
|
||||
for (Map.Entry<String, PromptFuncDefinition.Property> p : spec.getProps().entrySet()) {
|
||||
if (p.getValue().isRequired()) {
|
||||
requiredValues.add(p.getKey());
|
||||
}
|
||||
}
|
||||
parameters.setRequired(requiredValues);
|
||||
functionDetail.setParameters(parameters);
|
||||
def.setFunction(functionDetail);
|
||||
|
||||
tools.add(def);
|
||||
return this;
|
||||
}
|
||||
//
|
||||
// public PromptBuilder withToolSpecification(String functionName, String functionDesc, Map<String, PromptFuncDefinition.Property> props) {
|
||||
// PromptFuncDefinition def = new PromptFuncDefinition();
|
||||
// def.setType("function");
|
||||
//
|
||||
// PromptFuncDefinition.PromptFuncSpec functionDetail = new PromptFuncDefinition.PromptFuncSpec();
|
||||
// functionDetail.setName(functionName);
|
||||
// functionDetail.setDescription(functionDesc);
|
||||
//
|
||||
// PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters();
|
||||
// parameters.setType("object");
|
||||
// parameters.setProperties(props);
|
||||
//
|
||||
// List<String> requiredValues = new ArrayList<>();
|
||||
// for (Map.Entry<String, PromptFuncDefinition.Property> p : props.entrySet()) {
|
||||
// if (p.getValue().isRequired()) {
|
||||
// requiredValues.add(p.getKey());
|
||||
// }
|
||||
// }
|
||||
// parameters.setRequired(requiredValues);
|
||||
// functionDetail.setParameters(parameters);
|
||||
// def.setFunction(functionDetail);
|
||||
//
|
||||
// tools.add(def);
|
||||
// return this;
|
||||
// }
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
package io.github.amithkoujalgi.ollama4j.core.tools;
|
||||
|
||||
import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public class OllamaToolsResult {
|
||||
private OllamaResult modelResult;
|
||||
private Map<ToolDef, Object> toolResults;
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
package io.github.amithkoujalgi.ollama4j.core.tools;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class ToolDef {
|
||||
|
||||
private String name;
|
||||
private Map<String, Object> arguments;
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
package io.github.amithkoujalgi.ollama4j.core.tools;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
public class ToolRegistry {
|
||||
private static final Map<String, DynamicFunction> functionMap = new HashMap<>();
|
||||
|
||||
|
||||
public static DynamicFunction getFunction(String name) {
|
||||
return functionMap.get(name);
|
||||
}
|
||||
|
||||
public static void addFunction(String name, DynamicFunction function) {
|
||||
functionMap.put(name, function);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user