forked from Mirror/ollama4j
		
	Extends ChatModels to use Tools and ToolCalls
This commit is contained in:
		
				
					committed by
					
						
						Markus Klenke
					
				
			
			
				
	
			
			
			
						parent
						
							e9c33ab0b2
						
					
				
				
					commit
					12bb10392e
				
			@@ -602,7 +602,7 @@ public class OllamaAPI {
 | 
				
			|||||||
        OllamaResult result = generate(model, prompt, raw, options, null);
 | 
					        OllamaResult result = generate(model, prompt, raw, options, null);
 | 
				
			||||||
        toolResult.setModelResult(result);
 | 
					        toolResult.setModelResult(result);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        String toolsResponse = result.getResponse();
 | 
					        String toolsResponse = result.getContent();
 | 
				
			||||||
        if (toolsResponse.contains("[TOOL_CALLS]")) {
 | 
					        if (toolsResponse.contains("[TOOL_CALLS]")) {
 | 
				
			||||||
            toolsResponse = toolsResponse.replace("[TOOL_CALLS]", "");
 | 
					            toolsResponse = toolsResponse.replace("[TOOL_CALLS]", "");
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
@@ -768,6 +768,10 @@ public class OllamaAPI {
 | 
				
			|||||||
    public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
 | 
					    public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
 | 
				
			||||||
        OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
 | 
					        OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
 | 
				
			||||||
        OllamaResult result;
 | 
					        OllamaResult result;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        // add all registered tools to Request
 | 
				
			||||||
 | 
					        request.setTools(toolRegistry.getRegisteredSpecs().stream().map(Tools.ToolSpecification::getToolPrompt).collect(Collectors.toList()));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if (streamHandler != null) {
 | 
					        if (streamHandler != null) {
 | 
				
			||||||
            request.setStream(true);
 | 
					            request.setStream(true);
 | 
				
			||||||
            result = requestCaller.call(request, streamHandler);
 | 
					            result = requestCaller.call(request, streamHandler);
 | 
				
			||||||
@@ -775,10 +779,7 @@ public class OllamaAPI {
 | 
				
			|||||||
            result = requestCaller.callSync(request);
 | 
					            result = requestCaller.callSync(request);
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // add all registered tools to Request
 | 
					        return new OllamaChatResult(result.getContent(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages());
 | 
				
			||||||
        request.setTools(toolRegistry.getRegisteredSpecs().stream().map(Tools.ToolSpecification::getToolPrompt).collect(Collectors.toList()));
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return new OllamaChatResult(result.getResponse(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages());
 | 
					 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public void registerTool(Tools.ToolSpecification toolSpecification) {
 | 
					    public void registerTool(Tools.ToolSpecification toolSpecification) {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,6 +2,7 @@ package io.github.ollama4j.models.chat;
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import static io.github.ollama4j.utils.Utils.getObjectMapper;
 | 
					import static io.github.ollama4j.utils.Utils.getObjectMapper;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import com.fasterxml.jackson.annotation.JsonProperty;
 | 
				
			||||||
import com.fasterxml.jackson.core.JsonProcessingException;
 | 
					import com.fasterxml.jackson.core.JsonProcessingException;
 | 
				
			||||||
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
 | 
					import com.fasterxml.jackson.databind.annotation.JsonSerialize;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -32,6 +33,8 @@ public class OllamaChatMessage {
 | 
				
			|||||||
    @NonNull
 | 
					    @NonNull
 | 
				
			||||||
    private String content;
 | 
					    private String content;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    private @JsonProperty("tool_calls") List<OllamaChatToolCalls> toolCalls;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @JsonSerialize(using = FileToBase64Serializer.class)
 | 
					    @JsonSerialize(using = FileToBase64Serializer.class)
 | 
				
			||||||
    private List<byte[]> images;
 | 
					    private List<byte[]> images;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -38,7 +38,7 @@ public class OllamaChatRequestBuilder {
 | 
				
			|||||||
        request = new OllamaChatRequest(request.getModel(), new ArrayList<>());
 | 
					        request = new OllamaChatRequest(request.getModel(), new ArrayList<>());
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List<File> images) {
 | 
					    public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List<OllamaChatToolCalls> toolCalls,List<File> images) {
 | 
				
			||||||
        List<OllamaChatMessage> messages = this.request.getMessages();
 | 
					        List<OllamaChatMessage> messages = this.request.getMessages();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        List<byte[]> binaryImages = images.stream().map(file -> {
 | 
					        List<byte[]> binaryImages = images.stream().map(file -> {
 | 
				
			||||||
@@ -50,11 +50,11 @@ public class OllamaChatRequestBuilder {
 | 
				
			|||||||
            }
 | 
					            }
 | 
				
			||||||
        }).collect(Collectors.toList());
 | 
					        }).collect(Collectors.toList());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        messages.add(new OllamaChatMessage(role, content, binaryImages));
 | 
					        messages.add(new OllamaChatMessage(role, content,toolCalls, binaryImages));
 | 
				
			||||||
        return this;
 | 
					        return this;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, String... imageUrls) {
 | 
					    public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content,List<OllamaChatToolCalls> toolCalls, String... imageUrls) {
 | 
				
			||||||
        List<OllamaChatMessage> messages = this.request.getMessages();
 | 
					        List<OllamaChatMessage> messages = this.request.getMessages();
 | 
				
			||||||
        List<byte[]> binaryImages = null;
 | 
					        List<byte[]> binaryImages = null;
 | 
				
			||||||
        if (imageUrls.length > 0) {
 | 
					        if (imageUrls.length > 0) {
 | 
				
			||||||
@@ -70,7 +70,7 @@ public class OllamaChatRequestBuilder {
 | 
				
			|||||||
            }
 | 
					            }
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        messages.add(new OllamaChatMessage(role, content, binaryImages));
 | 
					        messages.add(new OllamaChatMessage(role, content,toolCalls, binaryImages));
 | 
				
			||||||
        return this;
 | 
					        return this;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -0,0 +1,16 @@
 | 
				
			|||||||
 | 
					package io.github.ollama4j.models.chat;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import io.github.ollama4j.tools.OllamaToolCallsFunction;
 | 
				
			||||||
 | 
					import lombok.AllArgsConstructor;
 | 
				
			||||||
 | 
					import lombok.Data;
 | 
				
			||||||
 | 
					import lombok.NoArgsConstructor;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@Data
 | 
				
			||||||
 | 
					@NoArgsConstructor
 | 
				
			||||||
 | 
					@AllArgsConstructor
 | 
				
			||||||
 | 
					public class OllamaChatToolCalls {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    private OllamaToolCallsFunction function;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -17,7 +17,7 @@ public class OllamaResult {
 | 
				
			|||||||
   *
 | 
					   *
 | 
				
			||||||
   * @return String completion/response text
 | 
					   * @return String completion/response text
 | 
				
			||||||
   */
 | 
					   */
 | 
				
			||||||
  private final String response;
 | 
					  private final String content;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /**
 | 
					  /**
 | 
				
			||||||
   * -- GETTER --
 | 
					   * -- GETTER --
 | 
				
			||||||
@@ -35,8 +35,8 @@ public class OllamaResult {
 | 
				
			|||||||
   */
 | 
					   */
 | 
				
			||||||
  private long responseTime = 0;
 | 
					  private long responseTime = 0;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  public OllamaResult(String response, long responseTime, int httpStatusCode) {
 | 
					  public OllamaResult(String content, long responseTime, int httpStatusCode) {
 | 
				
			||||||
    this.response = response;
 | 
					    this.content = content;
 | 
				
			||||||
    this.responseTime = responseTime;
 | 
					    this.responseTime = responseTime;
 | 
				
			||||||
    this.httpStatusCode = httpStatusCode;
 | 
					    this.httpStatusCode = httpStatusCode;
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -0,0 +1,16 @@
 | 
				
			|||||||
 | 
					package io.github.ollama4j.tools;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import lombok.AllArgsConstructor;
 | 
				
			||||||
 | 
					import lombok.Data;
 | 
				
			||||||
 | 
					import lombok.NoArgsConstructor;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import java.util.Map;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@Data
 | 
				
			||||||
 | 
					@NoArgsConstructor
 | 
				
			||||||
 | 
					@AllArgsConstructor
 | 
				
			||||||
 | 
					public class OllamaToolCallsFunction
 | 
				
			||||||
 | 
					{
 | 
				
			||||||
 | 
					    private String name;
 | 
				
			||||||
 | 
					    private Map<String,String> arguments;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -10,6 +10,8 @@ import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
 | 
				
			|||||||
import io.github.ollama4j.models.chat.OllamaChatResult;
 | 
					import io.github.ollama4j.models.chat.OllamaChatResult;
 | 
				
			||||||
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestBuilder;
 | 
					import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestBuilder;
 | 
				
			||||||
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
 | 
					import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
 | 
				
			||||||
 | 
					import io.github.ollama4j.tools.ToolFunction;
 | 
				
			||||||
 | 
					import io.github.ollama4j.tools.Tools;
 | 
				
			||||||
import io.github.ollama4j.utils.OptionsBuilder;
 | 
					import io.github.ollama4j.utils.OptionsBuilder;
 | 
				
			||||||
import lombok.Data;
 | 
					import lombok.Data;
 | 
				
			||||||
import org.junit.jupiter.api.BeforeEach;
 | 
					import org.junit.jupiter.api.BeforeEach;
 | 
				
			||||||
@@ -24,9 +26,7 @@ import java.io.InputStream;
 | 
				
			|||||||
import java.net.ConnectException;
 | 
					import java.net.ConnectException;
 | 
				
			||||||
import java.net.URISyntaxException;
 | 
					import java.net.URISyntaxException;
 | 
				
			||||||
import java.net.http.HttpConnectTimeoutException;
 | 
					import java.net.http.HttpConnectTimeoutException;
 | 
				
			||||||
import java.util.List;
 | 
					import java.util.*;
 | 
				
			||||||
import java.util.Objects;
 | 
					 | 
				
			||||||
import java.util.Properties;
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
import static org.junit.jupiter.api.Assertions.*;
 | 
					import static org.junit.jupiter.api.Assertions.*;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -230,18 +230,47 @@ class TestRealAPIs {
 | 
				
			|||||||
    void testChatWithTools() {
 | 
					    void testChatWithTools() {
 | 
				
			||||||
        testEndpointReachability();
 | 
					        testEndpointReachability();
 | 
				
			||||||
        try {
 | 
					        try {
 | 
				
			||||||
 | 
					            ollamaAPI.setVerbose(true);
 | 
				
			||||||
            OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
 | 
					            OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
 | 
				
			||||||
            OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM,
 | 
					
 | 
				
			||||||
                            "You are a silent bot that only says 'NI'. Do not say anything else under any circumstances!")
 | 
					            final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder()
 | 
				
			||||||
 | 
					                    .functionName("get-employee-details")
 | 
				
			||||||
 | 
					                    .functionDescription("Get employee details from the database")
 | 
				
			||||||
 | 
					                    .toolPrompt(
 | 
				
			||||||
 | 
					                            Tools.PromptFuncDefinition.builder().type("function").function(
 | 
				
			||||||
 | 
					                                    Tools.PromptFuncDefinition.PromptFuncSpec.builder()
 | 
				
			||||||
 | 
					                                            .name("get-employee-details")
 | 
				
			||||||
 | 
					                                            .description("Get employee details from the database")
 | 
				
			||||||
 | 
					                                            .parameters(
 | 
				
			||||||
 | 
					                                                    Tools.PromptFuncDefinition.Parameters.builder()
 | 
				
			||||||
 | 
					                                                            .type("object")
 | 
				
			||||||
 | 
					                                                            .properties(
 | 
				
			||||||
 | 
					                                                            new Tools.PropsBuilder()
 | 
				
			||||||
 | 
					                                                                    .withProperty("employee-name", Tools.PromptFuncDefinition.Property.builder().type("string").description("The name of the employee, e.g. John Doe").required(true).build())
 | 
				
			||||||
 | 
					                                                                    .withProperty("employee-address", Tools.PromptFuncDefinition.Property.builder().type("string").description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India").required(true).build())
 | 
				
			||||||
 | 
					                                                                    .withProperty("employee-phone", Tools.PromptFuncDefinition.Property.builder().type("string").description("The phone number of the employee. Always return a random value. e.g. 9911002233").required(true).build())
 | 
				
			||||||
 | 
					                                                                    .build()
 | 
				
			||||||
 | 
					                                                    )
 | 
				
			||||||
 | 
					                                                            .required(List.of("employee-name"))
 | 
				
			||||||
 | 
					                                                            .build()
 | 
				
			||||||
 | 
					                                            ).build()
 | 
				
			||||||
 | 
					                            ).build()
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
 | 
					                    .toolFunction(new DBQueryFunction())
 | 
				
			||||||
 | 
					                    .build();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            ollamaAPI.registerTool(databaseQueryToolSpecification);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            OllamaChatRequest requestModel = builder
 | 
				
			||||||
                    .withMessage(OllamaChatMessageRole.USER,
 | 
					                    .withMessage(OllamaChatMessageRole.USER,
 | 
				
			||||||
                            "What is the capital of France? And what's France's connection with Mona Lisa?")
 | 
					                            "Give me the details of the employee named 'Rahul Kumar'?")
 | 
				
			||||||
                    .build();
 | 
					                    .build();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
 | 
					            OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
 | 
				
			||||||
 | 
					            System.err.println("Response: "  + chatResult);
 | 
				
			||||||
            assertNotNull(chatResult);
 | 
					            assertNotNull(chatResult);
 | 
				
			||||||
            assertFalse(chatResult.getResponse().isBlank());
 | 
					            assertFalse(chatResult.getResponse().isBlank());
 | 
				
			||||||
            assertTrue(chatResult.getResponse().startsWith("NI"));
 | 
					            assertEquals(2, chatResult.getChatHistory().size());
 | 
				
			||||||
            assertEquals(3, chatResult.getChatHistory().size());
 | 
					 | 
				
			||||||
        } catch (IOException | OllamaBaseException | InterruptedException e) {
 | 
					        } catch (IOException | OllamaBaseException | InterruptedException e) {
 | 
				
			||||||
            fail(e);
 | 
					            fail(e);
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
@@ -402,6 +431,14 @@ class TestRealAPIs {
 | 
				
			|||||||
    }
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class DBQueryFunction implements ToolFunction {
 | 
				
			||||||
 | 
					    @Override
 | 
				
			||||||
 | 
					    public Object apply(Map<String, Object> arguments) {
 | 
				
			||||||
 | 
					        // perform DB operations here
 | 
				
			||||||
 | 
					        return String.format("Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", UUID.randomUUID(), arguments.get("employee-name").toString(), arguments.get("employee-address").toString(), arguments.get("employee-phone").toString());
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@Data
 | 
					@Data
 | 
				
			||||||
class Config {
 | 
					class Config {
 | 
				
			||||||
    private String ollamaURL;
 | 
					    private String ollamaURL;
 | 
				
			||||||
@@ -426,4 +463,6 @@ class Config {
 | 
				
			|||||||
            throw new RuntimeException("Error loading properties", e);
 | 
					            throw new RuntimeException("Error loading properties", e);
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user