Extends ToolSpec to have PromptDef for ChatRequests

This commit is contained in:
Markus Klenke 2024-12-04 08:45:00 +01:00 committed by Markus Klenke
parent ff3344616c
commit 903a8176cd
4 changed files with 32 additions and 11 deletions

View File

@ -774,11 +774,15 @@ public class OllamaAPI {
} else { } else {
result = requestCaller.callSync(request); result = requestCaller.callSync(request);
} }
//add registered Tools to Request
return new OllamaChatResult(result.getResponse(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages()); return new OllamaChatResult(result.getResponse(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages());
} }
public void registerTool(Tools.ToolSpecification toolSpecification) { public void registerTool(Tools.ToolSpecification toolSpecification) {
toolRegistry.addFunction(toolSpecification.getFunctionName(), toolSpecification.getToolDefinition()); toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification);
} }
/** /**
@ -871,7 +875,7 @@ public class OllamaAPI {
try { try {
String methodName = toolFunctionCallSpec.getName(); String methodName = toolFunctionCallSpec.getName();
Map<String, Object> arguments = toolFunctionCallSpec.getArguments(); Map<String, Object> arguments = toolFunctionCallSpec.getArguments();
ToolFunction function = toolRegistry.getFunction(methodName); ToolFunction function = toolRegistry.getToolFunction(methodName);
if (verbose) { if (verbose) {
logger.debug("Invoking function {} with arguments {}", methodName, arguments); logger.debug("Invoking function {} with arguments {}", methodName, arguments);
} }

View File

@ -3,6 +3,7 @@ package io.github.ollama4j.models.chat;
import java.util.List; import java.util.List;
import io.github.ollama4j.models.request.OllamaCommonRequest; import io.github.ollama4j.models.request.OllamaCommonRequest;
import io.github.ollama4j.tools.Tools;
import io.github.ollama4j.utils.OllamaRequestBody; import io.github.ollama4j.utils.OllamaRequestBody;
import lombok.Getter; import lombok.Getter;
@ -21,6 +22,8 @@ public class OllamaChatRequest extends OllamaCommonRequest implements OllamaRequ
private List<OllamaChatMessage> messages; private List<OllamaChatMessage> messages;
private List<Tools.PromptFuncDefinition> tools;
public OllamaChatRequest() {} public OllamaChatRequest() {}
public OllamaChatRequest(String model, List<OllamaChatMessage> messages) { public OllamaChatRequest(String model, List<OllamaChatMessage> messages) {

View File

@ -4,13 +4,14 @@ import java.util.HashMap;
import java.util.Map; import java.util.Map;
public class ToolRegistry { public class ToolRegistry {
private final Map<String, ToolFunction> functionMap = new HashMap<>(); private final Map<String, Tools.ToolSpecification> tools = new HashMap<>();
public ToolFunction getFunction(String name) { public ToolFunction getToolFunction(String name) {
return functionMap.get(name); final Tools.ToolSpecification toolSpecification = tools.get(name);
return toolSpecification !=null ? toolSpecification.getToolFunction() : null ;
} }
public void addFunction(String name, ToolFunction function) { public void addTool (String name, Tools.ToolSpecification specification) {
functionMap.put(name, function); tools.put(name, specification);
} }
} }

View File

@ -6,8 +6,10 @@ import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import io.github.ollama4j.utils.Utils; import io.github.ollama4j.utils.Utils;
import lombok.AllArgsConstructor;
import lombok.Builder; import lombok.Builder;
import lombok.Data; import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
@ -20,17 +22,23 @@ public class Tools {
public static class ToolSpecification { public static class ToolSpecification {
private String functionName; private String functionName;
private String functionDescription; private String functionDescription;
private Map<String, PromptFuncDefinition.Property> properties; private PromptFuncDefinition toolPrompt;
private ToolFunction toolDefinition; private ToolFunction toolFunction;
} }
@Data @Data
@JsonIgnoreProperties(ignoreUnknown = true) @JsonIgnoreProperties(ignoreUnknown = true)
@Builder
@NoArgsConstructor
@AllArgsConstructor
public static class PromptFuncDefinition { public static class PromptFuncDefinition {
private String type; private String type;
private PromptFuncSpec function; private PromptFuncSpec function;
@Data @Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public static class PromptFuncSpec { public static class PromptFuncSpec {
private String name; private String name;
private String description; private String description;
@ -38,6 +46,9 @@ public class Tools {
} }
@Data @Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public static class Parameters { public static class Parameters {
private String type; private String type;
private Map<String, Property> properties; private Map<String, Property> properties;
@ -46,6 +57,8 @@ public class Tools {
@Data @Data
@Builder @Builder
@NoArgsConstructor
@AllArgsConstructor
public static class Property { public static class Property {
private String type; private String type;
private String description; private String description;
@ -94,10 +107,10 @@ public class Tools {
PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters(); PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters();
parameters.setType("object"); parameters.setType("object");
parameters.setProperties(spec.getProperties()); parameters.setProperties(spec.getToolPrompt().getFunction().parameters.getProperties());
List<String> requiredValues = new ArrayList<>(); List<String> requiredValues = new ArrayList<>();
for (Map.Entry<String, PromptFuncDefinition.Property> p : spec.getProperties().entrySet()) { for (Map.Entry<String, PromptFuncDefinition.Property> p : spec.getToolPrompt().getFunction().getParameters().getProperties().entrySet()) {
if (p.getValue().isRequired()) { if (p.getValue().isRequired()) {
requiredValues.add(p.getKey()); requiredValues.add(p.getKey());
} }