mirror of
https://github.com/amithkoujalgi/ollama4j.git
synced 2025-05-15 11:57:12 +02:00
Extends ToolSpec to have PromptDef for ChatRequests
This commit is contained in:
parent
ff3344616c
commit
903a8176cd
@ -774,11 +774,15 @@ public class OllamaAPI {
|
|||||||
} else {
|
} else {
|
||||||
result = requestCaller.callSync(request);
|
result = requestCaller.callSync(request);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//add registered Tools to Request
|
||||||
|
|
||||||
|
|
||||||
return new OllamaChatResult(result.getResponse(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages());
|
return new OllamaChatResult(result.getResponse(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages());
|
||||||
}
|
}
|
||||||
|
|
||||||
public void registerTool(Tools.ToolSpecification toolSpecification) {
|
public void registerTool(Tools.ToolSpecification toolSpecification) {
|
||||||
toolRegistry.addFunction(toolSpecification.getFunctionName(), toolSpecification.getToolDefinition());
|
toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -871,7 +875,7 @@ public class OllamaAPI {
|
|||||||
try {
|
try {
|
||||||
String methodName = toolFunctionCallSpec.getName();
|
String methodName = toolFunctionCallSpec.getName();
|
||||||
Map<String, Object> arguments = toolFunctionCallSpec.getArguments();
|
Map<String, Object> arguments = toolFunctionCallSpec.getArguments();
|
||||||
ToolFunction function = toolRegistry.getFunction(methodName);
|
ToolFunction function = toolRegistry.getToolFunction(methodName);
|
||||||
if (verbose) {
|
if (verbose) {
|
||||||
logger.debug("Invoking function {} with arguments {}", methodName, arguments);
|
logger.debug("Invoking function {} with arguments {}", methodName, arguments);
|
||||||
}
|
}
|
||||||
|
@ -3,6 +3,7 @@ package io.github.ollama4j.models.chat;
|
|||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import io.github.ollama4j.models.request.OllamaCommonRequest;
|
import io.github.ollama4j.models.request.OllamaCommonRequest;
|
||||||
|
import io.github.ollama4j.tools.Tools;
|
||||||
import io.github.ollama4j.utils.OllamaRequestBody;
|
import io.github.ollama4j.utils.OllamaRequestBody;
|
||||||
|
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
@ -21,6 +22,8 @@ public class OllamaChatRequest extends OllamaCommonRequest implements OllamaRequ
|
|||||||
|
|
||||||
private List<OllamaChatMessage> messages;
|
private List<OllamaChatMessage> messages;
|
||||||
|
|
||||||
|
private List<Tools.PromptFuncDefinition> tools;
|
||||||
|
|
||||||
public OllamaChatRequest() {}
|
public OllamaChatRequest() {}
|
||||||
|
|
||||||
public OllamaChatRequest(String model, List<OllamaChatMessage> messages) {
|
public OllamaChatRequest(String model, List<OllamaChatMessage> messages) {
|
||||||
|
@ -4,13 +4,14 @@ import java.util.HashMap;
|
|||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
public class ToolRegistry {
|
public class ToolRegistry {
|
||||||
private final Map<String, ToolFunction> functionMap = new HashMap<>();
|
private final Map<String, Tools.ToolSpecification> tools = new HashMap<>();
|
||||||
|
|
||||||
public ToolFunction getFunction(String name) {
|
public ToolFunction getToolFunction(String name) {
|
||||||
return functionMap.get(name);
|
final Tools.ToolSpecification toolSpecification = tools.get(name);
|
||||||
|
return toolSpecification !=null ? toolSpecification.getToolFunction() : null ;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void addFunction(String name, ToolFunction function) {
|
public void addTool (String name, Tools.ToolSpecification specification) {
|
||||||
functionMap.put(name, function);
|
tools.put(name, specification);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -6,8 +6,10 @@ import com.fasterxml.jackson.annotation.JsonInclude;
|
|||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
import io.github.ollama4j.utils.Utils;
|
import io.github.ollama4j.utils.Utils;
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
@ -20,17 +22,23 @@ public class Tools {
|
|||||||
public static class ToolSpecification {
|
public static class ToolSpecification {
|
||||||
private String functionName;
|
private String functionName;
|
||||||
private String functionDescription;
|
private String functionDescription;
|
||||||
private Map<String, PromptFuncDefinition.Property> properties;
|
private PromptFuncDefinition toolPrompt;
|
||||||
private ToolFunction toolDefinition;
|
private ToolFunction toolFunction;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@JsonIgnoreProperties(ignoreUnknown = true)
|
@JsonIgnoreProperties(ignoreUnknown = true)
|
||||||
|
@Builder
|
||||||
|
@NoArgsConstructor
|
||||||
|
@AllArgsConstructor
|
||||||
public static class PromptFuncDefinition {
|
public static class PromptFuncDefinition {
|
||||||
private String type;
|
private String type;
|
||||||
private PromptFuncSpec function;
|
private PromptFuncSpec function;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
@Builder
|
||||||
|
@NoArgsConstructor
|
||||||
|
@AllArgsConstructor
|
||||||
public static class PromptFuncSpec {
|
public static class PromptFuncSpec {
|
||||||
private String name;
|
private String name;
|
||||||
private String description;
|
private String description;
|
||||||
@ -38,6 +46,9 @@ public class Tools {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
@Builder
|
||||||
|
@NoArgsConstructor
|
||||||
|
@AllArgsConstructor
|
||||||
public static class Parameters {
|
public static class Parameters {
|
||||||
private String type;
|
private String type;
|
||||||
private Map<String, Property> properties;
|
private Map<String, Property> properties;
|
||||||
@ -46,6 +57,8 @@ public class Tools {
|
|||||||
|
|
||||||
@Data
|
@Data
|
||||||
@Builder
|
@Builder
|
||||||
|
@NoArgsConstructor
|
||||||
|
@AllArgsConstructor
|
||||||
public static class Property {
|
public static class Property {
|
||||||
private String type;
|
private String type;
|
||||||
private String description;
|
private String description;
|
||||||
@ -94,10 +107,10 @@ public class Tools {
|
|||||||
|
|
||||||
PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters();
|
PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters();
|
||||||
parameters.setType("object");
|
parameters.setType("object");
|
||||||
parameters.setProperties(spec.getProperties());
|
parameters.setProperties(spec.getToolPrompt().getFunction().parameters.getProperties());
|
||||||
|
|
||||||
List<String> requiredValues = new ArrayList<>();
|
List<String> requiredValues = new ArrayList<>();
|
||||||
for (Map.Entry<String, PromptFuncDefinition.Property> p : spec.getProperties().entrySet()) {
|
for (Map.Entry<String, PromptFuncDefinition.Property> p : spec.getToolPrompt().getFunction().getParameters().getProperties().entrySet()) {
|
||||||
if (p.getValue().isRequired()) {
|
if (p.getValue().isRequired()) {
|
||||||
requiredValues.add(p.getKey());
|
requiredValues.add(p.getKey());
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user