Extends ToolSpec to have PromptDef for ChatRequests

This commit is contained in:
Markus Klenke 2024-12-04 08:45:00 +01:00 committed by Markus Klenke
parent ff3344616c
commit 903a8176cd
4 changed files with 32 additions and 11 deletions

View File

@ -774,11 +774,15 @@ public class OllamaAPI {
} else {
result = requestCaller.callSync(request);
}
//add registered Tools to Request
return new OllamaChatResult(result.getResponse(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages());
}
public void registerTool(Tools.ToolSpecification toolSpecification) {
toolRegistry.addFunction(toolSpecification.getFunctionName(), toolSpecification.getToolDefinition());
toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification);
}
/**
@ -871,7 +875,7 @@ public class OllamaAPI {
try {
String methodName = toolFunctionCallSpec.getName();
Map<String, Object> arguments = toolFunctionCallSpec.getArguments();
ToolFunction function = toolRegistry.getFunction(methodName);
ToolFunction function = toolRegistry.getToolFunction(methodName);
if (verbose) {
logger.debug("Invoking function {} with arguments {}", methodName, arguments);
}

View File

@ -3,6 +3,7 @@ package io.github.ollama4j.models.chat;
import java.util.List;
import io.github.ollama4j.models.request.OllamaCommonRequest;
import io.github.ollama4j.tools.Tools;
import io.github.ollama4j.utils.OllamaRequestBody;
import lombok.Getter;
@ -21,6 +22,8 @@ public class OllamaChatRequest extends OllamaCommonRequest implements OllamaRequ
private List<OllamaChatMessage> messages;
private List<Tools.PromptFuncDefinition> tools;
public OllamaChatRequest() {}
public OllamaChatRequest(String model, List<OllamaChatMessage> messages) {

View File

@ -4,13 +4,14 @@ import java.util.HashMap;
import java.util.Map;
public class ToolRegistry {
private final Map<String, ToolFunction> functionMap = new HashMap<>();
private final Map<String, Tools.ToolSpecification> tools = new HashMap<>();
public ToolFunction getFunction(String name) {
return functionMap.get(name);
public ToolFunction getToolFunction(String name) {
final Tools.ToolSpecification toolSpecification = tools.get(name);
return toolSpecification !=null ? toolSpecification.getToolFunction() : null ;
}
public void addFunction(String name, ToolFunction function) {
functionMap.put(name, function);
public void addTool (String name, Tools.ToolSpecification specification) {
tools.put(name, specification);
}
}

View File

@ -6,8 +6,10 @@ import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException;
import io.github.ollama4j.utils.Utils;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.ArrayList;
import java.util.HashMap;
@ -20,17 +22,23 @@ public class Tools {
public static class ToolSpecification {
private String functionName;
private String functionDescription;
private Map<String, PromptFuncDefinition.Property> properties;
private ToolFunction toolDefinition;
private PromptFuncDefinition toolPrompt;
private ToolFunction toolFunction;
}
@Data
@JsonIgnoreProperties(ignoreUnknown = true)
@Builder
@NoArgsConstructor
@AllArgsConstructor
public static class PromptFuncDefinition {
private String type;
private PromptFuncSpec function;
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public static class PromptFuncSpec {
private String name;
private String description;
@ -38,6 +46,9 @@ public class Tools {
}
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public static class Parameters {
private String type;
private Map<String, Property> properties;
@ -46,6 +57,8 @@ public class Tools {
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public static class Property {
private String type;
private String description;
@ -94,10 +107,10 @@ public class Tools {
PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters();
parameters.setType("object");
parameters.setProperties(spec.getProperties());
parameters.setProperties(spec.getToolPrompt().getFunction().parameters.getProperties());
List<String> requiredValues = new ArrayList<>();
for (Map.Entry<String, PromptFuncDefinition.Property> p : spec.getProperties().entrySet()) {
for (Map.Entry<String, PromptFuncDefinition.Property> p : spec.getToolPrompt().getFunction().getParameters().getProperties().entrySet()) {
if (p.getValue().isRequired()) {
requiredValues.add(p.getKey());
}