forked from Mirror/ollama4j
		
	Extends ToolSpec to have PromptDef for ChatRequests
This commit is contained in:
		
				
					committed by
					
						
						Markus Klenke
					
				
			
			
				
	
			
			
			
						parent
						
							ff3344616c
						
					
				
				
					commit
					903a8176cd
				
			@@ -774,11 +774,15 @@ public class OllamaAPI {
 | 
				
			|||||||
        } else {
 | 
					        } else {
 | 
				
			||||||
            result = requestCaller.callSync(request);
 | 
					            result = requestCaller.callSync(request);
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        //add registered Tools to Request
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return new OllamaChatResult(result.getResponse(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages());
 | 
					        return new OllamaChatResult(result.getResponse(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages());
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public void registerTool(Tools.ToolSpecification toolSpecification) {
 | 
					    public void registerTool(Tools.ToolSpecification toolSpecification) {
 | 
				
			||||||
        toolRegistry.addFunction(toolSpecification.getFunctionName(), toolSpecification.getToolDefinition());
 | 
					        toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    /**
 | 
					    /**
 | 
				
			||||||
@@ -871,7 +875,7 @@ public class OllamaAPI {
 | 
				
			|||||||
        try {
 | 
					        try {
 | 
				
			||||||
            String methodName = toolFunctionCallSpec.getName();
 | 
					            String methodName = toolFunctionCallSpec.getName();
 | 
				
			||||||
            Map<String, Object> arguments = toolFunctionCallSpec.getArguments();
 | 
					            Map<String, Object> arguments = toolFunctionCallSpec.getArguments();
 | 
				
			||||||
            ToolFunction function = toolRegistry.getFunction(methodName);
 | 
					            ToolFunction function = toolRegistry.getToolFunction(methodName);
 | 
				
			||||||
            if (verbose) {
 | 
					            if (verbose) {
 | 
				
			||||||
                logger.debug("Invoking function {} with arguments {}", methodName, arguments);
 | 
					                logger.debug("Invoking function {} with arguments {}", methodName, arguments);
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -3,6 +3,7 @@ package io.github.ollama4j.models.chat;
 | 
				
			|||||||
import java.util.List;
 | 
					import java.util.List;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import io.github.ollama4j.models.request.OllamaCommonRequest;
 | 
					import io.github.ollama4j.models.request.OllamaCommonRequest;
 | 
				
			||||||
 | 
					import io.github.ollama4j.tools.Tools;
 | 
				
			||||||
import io.github.ollama4j.utils.OllamaRequestBody;
 | 
					import io.github.ollama4j.utils.OllamaRequestBody;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import lombok.Getter;
 | 
					import lombok.Getter;
 | 
				
			||||||
@@ -21,6 +22,8 @@ public class OllamaChatRequest extends OllamaCommonRequest implements OllamaRequ
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
  private List<OllamaChatMessage> messages;
 | 
					  private List<OllamaChatMessage> messages;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  private List<Tools.PromptFuncDefinition> tools;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  public OllamaChatRequest() {}
 | 
					  public OllamaChatRequest() {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  public OllamaChatRequest(String model, List<OllamaChatMessage> messages) {
 | 
					  public OllamaChatRequest(String model, List<OllamaChatMessage> messages) {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -4,13 +4,14 @@ import java.util.HashMap;
 | 
				
			|||||||
import java.util.Map;
 | 
					import java.util.Map;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
public class ToolRegistry {
 | 
					public class ToolRegistry {
 | 
				
			||||||
    private final Map<String, ToolFunction> functionMap = new HashMap<>();
 | 
					    private final Map<String, Tools.ToolSpecification> tools = new HashMap<>();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public ToolFunction getFunction(String name) {
 | 
					    public ToolFunction getToolFunction(String name) {
 | 
				
			||||||
        return functionMap.get(name);
 | 
					        final Tools.ToolSpecification toolSpecification = tools.get(name);
 | 
				
			||||||
 | 
					        return toolSpecification !=null ? toolSpecification.getToolFunction() : null ;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public void addFunction(String name, ToolFunction function) {
 | 
					    public void addTool (String name, Tools.ToolSpecification specification) {
 | 
				
			||||||
        functionMap.put(name, function);
 | 
					        tools.put(name, specification);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -6,8 +6,10 @@ import com.fasterxml.jackson.annotation.JsonInclude;
 | 
				
			|||||||
import com.fasterxml.jackson.annotation.JsonProperty;
 | 
					import com.fasterxml.jackson.annotation.JsonProperty;
 | 
				
			||||||
import com.fasterxml.jackson.core.JsonProcessingException;
 | 
					import com.fasterxml.jackson.core.JsonProcessingException;
 | 
				
			||||||
import io.github.ollama4j.utils.Utils;
 | 
					import io.github.ollama4j.utils.Utils;
 | 
				
			||||||
 | 
					import lombok.AllArgsConstructor;
 | 
				
			||||||
import lombok.Builder;
 | 
					import lombok.Builder;
 | 
				
			||||||
import lombok.Data;
 | 
					import lombok.Data;
 | 
				
			||||||
 | 
					import lombok.NoArgsConstructor;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import java.util.ArrayList;
 | 
					import java.util.ArrayList;
 | 
				
			||||||
import java.util.HashMap;
 | 
					import java.util.HashMap;
 | 
				
			||||||
@@ -20,17 +22,23 @@ public class Tools {
 | 
				
			|||||||
    public static class ToolSpecification {
 | 
					    public static class ToolSpecification {
 | 
				
			||||||
        private String functionName;
 | 
					        private String functionName;
 | 
				
			||||||
        private String functionDescription;
 | 
					        private String functionDescription;
 | 
				
			||||||
        private Map<String, PromptFuncDefinition.Property> properties;
 | 
					        private PromptFuncDefinition toolPrompt;
 | 
				
			||||||
        private ToolFunction toolDefinition;
 | 
					        private ToolFunction toolFunction;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @Data
 | 
					    @Data
 | 
				
			||||||
    @JsonIgnoreProperties(ignoreUnknown = true)
 | 
					    @JsonIgnoreProperties(ignoreUnknown = true)
 | 
				
			||||||
 | 
					    @Builder
 | 
				
			||||||
 | 
					    @NoArgsConstructor
 | 
				
			||||||
 | 
					    @AllArgsConstructor
 | 
				
			||||||
    public static class PromptFuncDefinition {
 | 
					    public static class PromptFuncDefinition {
 | 
				
			||||||
        private String type;
 | 
					        private String type;
 | 
				
			||||||
        private PromptFuncSpec function;
 | 
					        private PromptFuncSpec function;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        @Data
 | 
					        @Data
 | 
				
			||||||
 | 
					        @Builder
 | 
				
			||||||
 | 
					        @NoArgsConstructor
 | 
				
			||||||
 | 
					        @AllArgsConstructor
 | 
				
			||||||
        public static class PromptFuncSpec {
 | 
					        public static class PromptFuncSpec {
 | 
				
			||||||
            private String name;
 | 
					            private String name;
 | 
				
			||||||
            private String description;
 | 
					            private String description;
 | 
				
			||||||
@@ -38,6 +46,9 @@ public class Tools {
 | 
				
			|||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        @Data
 | 
					        @Data
 | 
				
			||||||
 | 
					        @Builder
 | 
				
			||||||
 | 
					        @NoArgsConstructor
 | 
				
			||||||
 | 
					        @AllArgsConstructor
 | 
				
			||||||
        public static class Parameters {
 | 
					        public static class Parameters {
 | 
				
			||||||
            private String type;
 | 
					            private String type;
 | 
				
			||||||
            private Map<String, Property> properties;
 | 
					            private Map<String, Property> properties;
 | 
				
			||||||
@@ -46,6 +57,8 @@ public class Tools {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        @Data
 | 
					        @Data
 | 
				
			||||||
        @Builder
 | 
					        @Builder
 | 
				
			||||||
 | 
					        @NoArgsConstructor
 | 
				
			||||||
 | 
					        @AllArgsConstructor
 | 
				
			||||||
        public static class Property {
 | 
					        public static class Property {
 | 
				
			||||||
            private String type;
 | 
					            private String type;
 | 
				
			||||||
            private String description;
 | 
					            private String description;
 | 
				
			||||||
@@ -94,10 +107,10 @@ public class Tools {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
            PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters();
 | 
					            PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters();
 | 
				
			||||||
            parameters.setType("object");
 | 
					            parameters.setType("object");
 | 
				
			||||||
            parameters.setProperties(spec.getProperties());
 | 
					            parameters.setProperties(spec.getToolPrompt().getFunction().parameters.getProperties());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            List<String> requiredValues = new ArrayList<>();
 | 
					            List<String> requiredValues = new ArrayList<>();
 | 
				
			||||||
            for (Map.Entry<String, PromptFuncDefinition.Property> p : spec.getProperties().entrySet()) {
 | 
					            for (Map.Entry<String, PromptFuncDefinition.Property> p : spec.getToolPrompt().getFunction().getParameters().getProperties().entrySet()) {
 | 
				
			||||||
                if (p.getValue().isRequired()) {
 | 
					                if (p.getValue().isRequired()) {
 | 
				
			||||||
                    requiredValues.add(p.getKey());
 | 
					                    requiredValues.add(p.getKey());
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user