forked from Mirror/ollama4j
		
	Extends ChatModels to use Tools and ToolCalls
This commit is contained in:
		
				
					committed by
					
						
						Markus Klenke
					
				
			
			
				
	
			
			
			
						parent
						
							e9c33ab0b2
						
					
				
				
					commit
					12bb10392e
				
			@@ -602,7 +602,7 @@ public class OllamaAPI {
 | 
			
		||||
        OllamaResult result = generate(model, prompt, raw, options, null);
 | 
			
		||||
        toolResult.setModelResult(result);
 | 
			
		||||
 | 
			
		||||
        String toolsResponse = result.getResponse();
 | 
			
		||||
        String toolsResponse = result.getContent();
 | 
			
		||||
        if (toolsResponse.contains("[TOOL_CALLS]")) {
 | 
			
		||||
            toolsResponse = toolsResponse.replace("[TOOL_CALLS]", "");
 | 
			
		||||
        }
 | 
			
		||||
@@ -768,6 +768,10 @@ public class OllamaAPI {
 | 
			
		||||
    public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
 | 
			
		||||
        OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
 | 
			
		||||
        OllamaResult result;
 | 
			
		||||
 | 
			
		||||
        // add all registered tools to Request
 | 
			
		||||
        request.setTools(toolRegistry.getRegisteredSpecs().stream().map(Tools.ToolSpecification::getToolPrompt).collect(Collectors.toList()));
 | 
			
		||||
 | 
			
		||||
        if (streamHandler != null) {
 | 
			
		||||
            request.setStream(true);
 | 
			
		||||
            result = requestCaller.call(request, streamHandler);
 | 
			
		||||
@@ -775,10 +779,7 @@ public class OllamaAPI {
 | 
			
		||||
            result = requestCaller.callSync(request);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // add all registered tools to Request
 | 
			
		||||
        request.setTools(toolRegistry.getRegisteredSpecs().stream().map(Tools.ToolSpecification::getToolPrompt).collect(Collectors.toList()));
 | 
			
		||||
 | 
			
		||||
        return new OllamaChatResult(result.getResponse(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages());
 | 
			
		||||
        return new OllamaChatResult(result.getContent(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public void registerTool(Tools.ToolSpecification toolSpecification) {
 | 
			
		||||
 
 | 
			
		||||
@@ -2,6 +2,7 @@ package io.github.ollama4j.models.chat;
 | 
			
		||||
 | 
			
		||||
import static io.github.ollama4j.utils.Utils.getObjectMapper;
 | 
			
		||||
 | 
			
		||||
import com.fasterxml.jackson.annotation.JsonProperty;
 | 
			
		||||
import com.fasterxml.jackson.core.JsonProcessingException;
 | 
			
		||||
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
 | 
			
		||||
 | 
			
		||||
@@ -32,6 +33,8 @@ public class OllamaChatMessage {
 | 
			
		||||
    @NonNull
 | 
			
		||||
    private String content;
 | 
			
		||||
 | 
			
		||||
    private @JsonProperty("tool_calls") List<OllamaChatToolCalls> toolCalls;
 | 
			
		||||
 | 
			
		||||
    @JsonSerialize(using = FileToBase64Serializer.class)
 | 
			
		||||
    private List<byte[]> images;
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -38,7 +38,7 @@ public class OllamaChatRequestBuilder {
 | 
			
		||||
        request = new OllamaChatRequest(request.getModel(), new ArrayList<>());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List<File> images) {
 | 
			
		||||
    public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List<OllamaChatToolCalls> toolCalls,List<File> images) {
 | 
			
		||||
        List<OllamaChatMessage> messages = this.request.getMessages();
 | 
			
		||||
 | 
			
		||||
        List<byte[]> binaryImages = images.stream().map(file -> {
 | 
			
		||||
@@ -50,11 +50,11 @@ public class OllamaChatRequestBuilder {
 | 
			
		||||
            }
 | 
			
		||||
        }).collect(Collectors.toList());
 | 
			
		||||
 | 
			
		||||
        messages.add(new OllamaChatMessage(role, content, binaryImages));
 | 
			
		||||
        messages.add(new OllamaChatMessage(role, content,toolCalls, binaryImages));
 | 
			
		||||
        return this;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, String... imageUrls) {
 | 
			
		||||
    public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content,List<OllamaChatToolCalls> toolCalls, String... imageUrls) {
 | 
			
		||||
        List<OllamaChatMessage> messages = this.request.getMessages();
 | 
			
		||||
        List<byte[]> binaryImages = null;
 | 
			
		||||
        if (imageUrls.length > 0) {
 | 
			
		||||
@@ -70,7 +70,7 @@ public class OllamaChatRequestBuilder {
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        messages.add(new OllamaChatMessage(role, content, binaryImages));
 | 
			
		||||
        messages.add(new OllamaChatMessage(role, content,toolCalls, binaryImages));
 | 
			
		||||
        return this;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -0,0 +1,16 @@
 | 
			
		||||
package io.github.ollama4j.models.chat;
 | 
			
		||||
 | 
			
		||||
import io.github.ollama4j.tools.OllamaToolCallsFunction;
 | 
			
		||||
import lombok.AllArgsConstructor;
 | 
			
		||||
import lombok.Data;
 | 
			
		||||
import lombok.NoArgsConstructor;
 | 
			
		||||
 | 
			
		||||
@Data
 | 
			
		||||
@NoArgsConstructor
 | 
			
		||||
@AllArgsConstructor
 | 
			
		||||
public class OllamaChatToolCalls {
 | 
			
		||||
 | 
			
		||||
    private OllamaToolCallsFunction function;
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
@@ -17,7 +17,7 @@ public class OllamaResult {
 | 
			
		||||
   *
 | 
			
		||||
   * @return String completion/response text
 | 
			
		||||
   */
 | 
			
		||||
  private final String response;
 | 
			
		||||
  private final String content;
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * -- GETTER --
 | 
			
		||||
@@ -35,8 +35,8 @@ public class OllamaResult {
 | 
			
		||||
   */
 | 
			
		||||
  private long responseTime = 0;
 | 
			
		||||
 | 
			
		||||
  public OllamaResult(String response, long responseTime, int httpStatusCode) {
 | 
			
		||||
    this.response = response;
 | 
			
		||||
  public OllamaResult(String content, long responseTime, int httpStatusCode) {
 | 
			
		||||
    this.content = content;
 | 
			
		||||
    this.responseTime = responseTime;
 | 
			
		||||
    this.httpStatusCode = httpStatusCode;
 | 
			
		||||
  }
 | 
			
		||||
 
 | 
			
		||||
@@ -0,0 +1,16 @@
 | 
			
		||||
package io.github.ollama4j.tools;
 | 
			
		||||
 | 
			
		||||
import lombok.AllArgsConstructor;
 | 
			
		||||
import lombok.Data;
 | 
			
		||||
import lombok.NoArgsConstructor;
 | 
			
		||||
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
 | 
			
		||||
@Data
 | 
			
		||||
@NoArgsConstructor
 | 
			
		||||
@AllArgsConstructor
 | 
			
		||||
public class OllamaToolCallsFunction
 | 
			
		||||
{
 | 
			
		||||
    private String name;
 | 
			
		||||
    private Map<String,String> arguments;
 | 
			
		||||
}
 | 
			
		||||
		Reference in New Issue
	
	Block a user