forked from Mirror/ollama4j
		
	Extends ToolSpec to have PromptDef for ChatRequests
This commit is contained in:
		
				
					committed by
					
						
						Markus Klenke
					
				
			
			
				
	
			
			
			
						parent
						
							ff3344616c
						
					
				
				
					commit
					903a8176cd
				
			@@ -774,11 +774,15 @@ public class OllamaAPI {
 | 
			
		||||
        } else {
 | 
			
		||||
            result = requestCaller.callSync(request);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        //add registered Tools to Request
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        return new OllamaChatResult(result.getResponse(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public void registerTool(Tools.ToolSpecification toolSpecification) {
 | 
			
		||||
        toolRegistry.addFunction(toolSpecification.getFunctionName(), toolSpecification.getToolDefinition());
 | 
			
		||||
        toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
@@ -871,7 +875,7 @@ public class OllamaAPI {
 | 
			
		||||
        try {
 | 
			
		||||
            String methodName = toolFunctionCallSpec.getName();
 | 
			
		||||
            Map<String, Object> arguments = toolFunctionCallSpec.getArguments();
 | 
			
		||||
            ToolFunction function = toolRegistry.getFunction(methodName);
 | 
			
		||||
            ToolFunction function = toolRegistry.getToolFunction(methodName);
 | 
			
		||||
            if (verbose) {
 | 
			
		||||
                logger.debug("Invoking function {} with arguments {}", methodName, arguments);
 | 
			
		||||
            }
 | 
			
		||||
 
 | 
			
		||||
@@ -3,6 +3,7 @@ package io.github.ollama4j.models.chat;
 | 
			
		||||
import java.util.List;
 | 
			
		||||
 | 
			
		||||
import io.github.ollama4j.models.request.OllamaCommonRequest;
 | 
			
		||||
import io.github.ollama4j.tools.Tools;
 | 
			
		||||
import io.github.ollama4j.utils.OllamaRequestBody;
 | 
			
		||||
 | 
			
		||||
import lombok.Getter;
 | 
			
		||||
@@ -21,6 +22,8 @@ public class OllamaChatRequest extends OllamaCommonRequest implements OllamaRequ
 | 
			
		||||
 | 
			
		||||
  private List<OllamaChatMessage> messages;
 | 
			
		||||
 | 
			
		||||
  private List<Tools.PromptFuncDefinition> tools;
 | 
			
		||||
 | 
			
		||||
  public OllamaChatRequest() {}
 | 
			
		||||
 | 
			
		||||
  public OllamaChatRequest(String model, List<OllamaChatMessage> messages) {
 | 
			
		||||
 
 | 
			
		||||
@@ -4,13 +4,14 @@ import java.util.HashMap;
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
 | 
			
		||||
public class ToolRegistry {
 | 
			
		||||
    private final Map<String, ToolFunction> functionMap = new HashMap<>();
 | 
			
		||||
    private final Map<String, Tools.ToolSpecification> tools = new HashMap<>();
 | 
			
		||||
 | 
			
		||||
    public ToolFunction getFunction(String name) {
 | 
			
		||||
        return functionMap.get(name);
 | 
			
		||||
    public ToolFunction getToolFunction(String name) {
 | 
			
		||||
        final Tools.ToolSpecification toolSpecification = tools.get(name);
 | 
			
		||||
        return toolSpecification !=null ? toolSpecification.getToolFunction() : null ;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public void addFunction(String name, ToolFunction function) {
 | 
			
		||||
        functionMap.put(name, function);
 | 
			
		||||
    public void addTool (String name, Tools.ToolSpecification specification) {
 | 
			
		||||
        tools.put(name, specification);
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -6,8 +6,10 @@ import com.fasterxml.jackson.annotation.JsonInclude;
 | 
			
		||||
import com.fasterxml.jackson.annotation.JsonProperty;
 | 
			
		||||
import com.fasterxml.jackson.core.JsonProcessingException;
 | 
			
		||||
import io.github.ollama4j.utils.Utils;
 | 
			
		||||
import lombok.AllArgsConstructor;
 | 
			
		||||
import lombok.Builder;
 | 
			
		||||
import lombok.Data;
 | 
			
		||||
import lombok.NoArgsConstructor;
 | 
			
		||||
 | 
			
		||||
import java.util.ArrayList;
 | 
			
		||||
import java.util.HashMap;
 | 
			
		||||
@@ -20,17 +22,23 @@ public class Tools {
 | 
			
		||||
    public static class ToolSpecification {
 | 
			
		||||
        private String functionName;
 | 
			
		||||
        private String functionDescription;
 | 
			
		||||
        private Map<String, PromptFuncDefinition.Property> properties;
 | 
			
		||||
        private ToolFunction toolDefinition;
 | 
			
		||||
        private PromptFuncDefinition toolPrompt;
 | 
			
		||||
        private ToolFunction toolFunction;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Data
 | 
			
		||||
    @JsonIgnoreProperties(ignoreUnknown = true)
 | 
			
		||||
    @Builder
 | 
			
		||||
    @NoArgsConstructor
 | 
			
		||||
    @AllArgsConstructor
 | 
			
		||||
    public static class PromptFuncDefinition {
 | 
			
		||||
        private String type;
 | 
			
		||||
        private PromptFuncSpec function;
 | 
			
		||||
 | 
			
		||||
        @Data
 | 
			
		||||
        @Builder
 | 
			
		||||
        @NoArgsConstructor
 | 
			
		||||
        @AllArgsConstructor
 | 
			
		||||
        public static class PromptFuncSpec {
 | 
			
		||||
            private String name;
 | 
			
		||||
            private String description;
 | 
			
		||||
@@ -38,6 +46,9 @@ public class Tools {
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        @Data
 | 
			
		||||
        @Builder
 | 
			
		||||
        @NoArgsConstructor
 | 
			
		||||
        @AllArgsConstructor
 | 
			
		||||
        public static class Parameters {
 | 
			
		||||
            private String type;
 | 
			
		||||
            private Map<String, Property> properties;
 | 
			
		||||
@@ -46,6 +57,8 @@ public class Tools {
 | 
			
		||||
 | 
			
		||||
        @Data
 | 
			
		||||
        @Builder
 | 
			
		||||
        @NoArgsConstructor
 | 
			
		||||
        @AllArgsConstructor
 | 
			
		||||
        public static class Property {
 | 
			
		||||
            private String type;
 | 
			
		||||
            private String description;
 | 
			
		||||
@@ -94,10 +107,10 @@ public class Tools {
 | 
			
		||||
 | 
			
		||||
            PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters();
 | 
			
		||||
            parameters.setType("object");
 | 
			
		||||
            parameters.setProperties(spec.getProperties());
 | 
			
		||||
            parameters.setProperties(spec.getToolPrompt().getFunction().parameters.getProperties());
 | 
			
		||||
 | 
			
		||||
            List<String> requiredValues = new ArrayList<>();
 | 
			
		||||
            for (Map.Entry<String, PromptFuncDefinition.Property> p : spec.getProperties().entrySet()) {
 | 
			
		||||
            for (Map.Entry<String, PromptFuncDefinition.Property> p : spec.getToolPrompt().getFunction().getParameters().getProperties().entrySet()) {
 | 
			
		||||
                if (p.getValue().isRequired()) {
 | 
			
		||||
                    requiredValues.add(p.getKey());
 | 
			
		||||
                }
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user