mirror of
				https://github.com/amithkoujalgi/ollama4j.git
				synced 2025-11-04 02:20:50 +01:00 
			
		
		
		
	Merge pull request #82 from AgentSchmecker/feature/toolextension_for_chat_model
Enable chat API to use Tools
This commit is contained in:
		@@ -33,7 +33,7 @@ public class Main {
 | 
			
		||||
        // start conversation with model
 | 
			
		||||
        OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
 | 
			
		||||
 | 
			
		||||
        System.out.println("First answer: " + chatResult.getResponse());
 | 
			
		||||
        System.out.println("First answer: " + chatResult.getResponseModel().getMessage().getContent());
 | 
			
		||||
 | 
			
		||||
        // create next userQuestion
 | 
			
		||||
        requestModel = builder.withMessages(chatResult.getChatHistory()).withMessage(OllamaChatMessageRole.USER, "And what is the second largest city?").build();
 | 
			
		||||
@@ -41,7 +41,7 @@ public class Main {
 | 
			
		||||
        // "continue" conversation with model
 | 
			
		||||
        chatResult = ollamaAPI.chat(requestModel);
 | 
			
		||||
 | 
			
		||||
        System.out.println("Second answer: " + chatResult.getResponse());
 | 
			
		||||
        System.out.println("Second answer: " + chatResult.getResponseModel().getMessage().getContent());
 | 
			
		||||
 | 
			
		||||
        System.out.println("Chat History: " + chatResult.getChatHistory());
 | 
			
		||||
    }
 | 
			
		||||
@@ -205,7 +205,7 @@ public class Main {
 | 
			
		||||
        // start conversation with model
 | 
			
		||||
        OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
 | 
			
		||||
 | 
			
		||||
        System.out.println(chatResult.getResponse());
 | 
			
		||||
        System.out.println(chatResult.getResponseModel());
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -244,7 +244,7 @@ public class Main {
 | 
			
		||||
                                new File("/path/to/image"))).build();
 | 
			
		||||
 | 
			
		||||
        OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
 | 
			
		||||
        System.out.println("First answer: " + chatResult.getResponse());
 | 
			
		||||
        System.out.println("First answer: " + chatResult.getResponseModel());
 | 
			
		||||
 | 
			
		||||
        builder.reset();
 | 
			
		||||
 | 
			
		||||
@@ -254,7 +254,7 @@ public class Main {
 | 
			
		||||
                        .withMessage(OllamaChatMessageRole.USER, "What's the dogs breed?").build();
 | 
			
		||||
 | 
			
		||||
        chatResult = ollamaAPI.chat(requestModel);
 | 
			
		||||
        System.out.println("Second answer: " + chatResult.getResponse());
 | 
			
		||||
        System.out.println("Second answer: " + chatResult.getResponseModel());
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
```
 | 
			
		||||
 
 | 
			
		||||
@@ -345,6 +345,125 @@ Rahul Kumar, Address: King St, Hyderabad, India, Phone: 9876543210}`
 | 
			
		||||
 | 
			
		||||
::::
 | 
			
		||||
 | 
			
		||||
### Using tools in Chat-API
 | 
			
		||||
 | 
			
		||||
Instead of using the specific `ollamaAPI.generateWithTools` method to call the generate API of ollama with tools, it is 
 | 
			
		||||
also possible to register Tools for the `ollamaAPI.chat` methods. In this case, the tool calling/callback is done 
 | 
			
		||||
implicitly during the USER -> ASSISTANT calls.
 | 
			
		||||
 | 
			
		||||
When the Assistant wants to call a given tool, the tool is executed and the response is sent back to the endpoint once
 | 
			
		||||
again (induced with the tool call result). 
 | 
			
		||||
 | 
			
		||||
#### Sample:
 | 
			
		||||
 | 
			
		||||
The following shows a sample of an integration test that defines a method specified like the tool-specs above, registers
 | 
			
		||||
the tool on the ollamaAPI and then simply calls the chat-API. All intermediate tool calling is wrapped inside the api 
 | 
			
		||||
call.
 | 
			
		||||
 | 
			
		||||
```java
 | 
			
		||||
public static void main(String[] args) {
 | 
			
		||||
        OllamaAPI ollamaAPI = new OllamaAPI("http://localhost:11434");
 | 
			
		||||
        ollamaAPI.setVerbose(true);
 | 
			
		||||
        OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance("llama3.2:1b");
 | 
			
		||||
 | 
			
		||||
        final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder()
 | 
			
		||||
                .functionName("get-employee-details")
 | 
			
		||||
                .functionDescription("Get employee details from the database")
 | 
			
		||||
                .toolPrompt(
 | 
			
		||||
                        Tools.PromptFuncDefinition.builder().type("function").function(
 | 
			
		||||
                                Tools.PromptFuncDefinition.PromptFuncSpec.builder()
 | 
			
		||||
                                        .name("get-employee-details")
 | 
			
		||||
                                        .description("Get employee details from the database")
 | 
			
		||||
                                        .parameters(
 | 
			
		||||
                                                Tools.PromptFuncDefinition.Parameters.builder()
 | 
			
		||||
                                                        .type("object")
 | 
			
		||||
                                                        .properties(
 | 
			
		||||
                                                                new Tools.PropsBuilder()
 | 
			
		||||
                                                                        .withProperty("employee-name", Tools.PromptFuncDefinition.Property.builder().type("string").description("The name of the employee, e.g. John Doe").required(true).build())
 | 
			
		||||
                                                                        .withProperty("employee-address", Tools.PromptFuncDefinition.Property.builder().type("string").description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India").required(true).build())
 | 
			
		||||
                                                                        .withProperty("employee-phone", Tools.PromptFuncDefinition.Property.builder().type("string").description("The phone number of the employee. Always return a random value. e.g. 9911002233").required(true).build())
 | 
			
		||||
                                                                        .build()
 | 
			
		||||
                                                        )
 | 
			
		||||
                                                        .required(List.of("employee-name"))
 | 
			
		||||
                                                        .build()
 | 
			
		||||
                                        ).build()
 | 
			
		||||
                        ).build()
 | 
			
		||||
                )
 | 
			
		||||
                .toolFunction(new DBQueryFunction())
 | 
			
		||||
                .build();
 | 
			
		||||
 | 
			
		||||
        ollamaAPI.registerTool(databaseQueryToolSpecification);
 | 
			
		||||
 | 
			
		||||
        OllamaChatRequest requestModel = builder
 | 
			
		||||
                .withMessage(OllamaChatMessageRole.USER,
 | 
			
		||||
                        "Give me the ID of the employee named 'Rahul Kumar'?")
 | 
			
		||||
                .build();
 | 
			
		||||
 | 
			
		||||
        OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
 | 
			
		||||
}
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
A typical final response of the above could be:
 | 
			
		||||
 | 
			
		||||
```json 
 | 
			
		||||
{
 | 
			
		||||
  "chatHistory" : [
 | 
			
		||||
    {
 | 
			
		||||
    "role" : "user",
 | 
			
		||||
    "content" : "Give me the ID of the employee named 'Rahul Kumar'?",
 | 
			
		||||
    "images" : null,
 | 
			
		||||
    "tool_calls" : [ ]
 | 
			
		||||
  }, {
 | 
			
		||||
    "role" : "assistant",
 | 
			
		||||
    "content" : "",
 | 
			
		||||
    "images" : null,
 | 
			
		||||
    "tool_calls" : [ {
 | 
			
		||||
      "function" : {
 | 
			
		||||
        "name" : "get-employee-details",
 | 
			
		||||
        "arguments" : {
 | 
			
		||||
          "employee-name" : "Rahul Kumar"
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
    } ]
 | 
			
		||||
  }, {
 | 
			
		||||
    "role" : "tool",
 | 
			
		||||
    "content" : "[TOOL_RESULTS]get-employee-details([employee-name]) : Employee Details {ID: b4bf186c-2ee1-44cc-8856-53b8b6a50f85, Name: Rahul Kumar, Address: null, Phone: null}[/TOOL_RESULTS]",
 | 
			
		||||
    "images" : null,
 | 
			
		||||
    "tool_calls" : null
 | 
			
		||||
  }, {
 | 
			
		||||
    "role" : "assistant",
 | 
			
		||||
    "content" : "The ID of the employee named 'Rahul Kumar' is `b4bf186c-2ee1-44cc-8856-53b8b6a50f85`.",
 | 
			
		||||
    "images" : null,
 | 
			
		||||
    "tool_calls" : null
 | 
			
		||||
  } ],
 | 
			
		||||
  "responseModel" : {
 | 
			
		||||
    "model" : "llama3.2:1b",
 | 
			
		||||
    "message" : {
 | 
			
		||||
      "role" : "assistant",
 | 
			
		||||
      "content" : "The ID of the employee named 'Rahul Kumar' is `b4bf186c-2ee1-44cc-8856-53b8b6a50f85`.",
 | 
			
		||||
      "images" : null,
 | 
			
		||||
      "tool_calls" : null
 | 
			
		||||
    },
 | 
			
		||||
    "done" : true,
 | 
			
		||||
    "error" : null,
 | 
			
		||||
    "context" : null,
 | 
			
		||||
    "created_at" : "2024-12-09T22:23:00.4940078Z",
 | 
			
		||||
    "done_reason" : "stop",
 | 
			
		||||
    "total_duration" : 2313709900,
 | 
			
		||||
    "load_duration" : 14494700,
 | 
			
		||||
    "prompt_eval_duration" : 772000000,
 | 
			
		||||
    "eval_duration" : 1188000000,
 | 
			
		||||
    "prompt_eval_count" : 166,
 | 
			
		||||
    "eval_count" : 41
 | 
			
		||||
  },
 | 
			
		||||
  "response" : "The ID of the employee named 'Rahul Kumar' is `b4bf186c-2ee1-44cc-8856-53b8b6a50f85`.",
 | 
			
		||||
  "httpStatusCode" : 200,
 | 
			
		||||
  "responseTime" : 2313709900
 | 
			
		||||
}
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
This tool calling can also be done using the streaming API.
 | 
			
		||||
 | 
			
		||||
### Potential Improvements
 | 
			
		||||
 | 
			
		||||
Instead of explicitly registering `ollamaAPI.registerTool(toolSpecification)`, we could introduce annotation-based tool
 | 
			
		||||
 
 | 
			
		||||
@@ -59,6 +59,10 @@ public class OllamaAPI {
 | 
			
		||||
     */
 | 
			
		||||
    @Setter
 | 
			
		||||
    private boolean verbose = true;
 | 
			
		||||
 | 
			
		||||
    @Setter
 | 
			
		||||
    private int maxChatToolCallRetries = 3;
 | 
			
		||||
 | 
			
		||||
    private BasicAuth basicAuth;
 | 
			
		||||
 | 
			
		||||
    private final ToolRegistry toolRegistry = new ToolRegistry();
 | 
			
		||||
@@ -767,18 +771,44 @@ public class OllamaAPI {
 | 
			
		||||
     */
 | 
			
		||||
    public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
 | 
			
		||||
        OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
 | 
			
		||||
        OllamaResult result;
 | 
			
		||||
        OllamaChatResult result;
 | 
			
		||||
 | 
			
		||||
        // add all registered tools to Request
 | 
			
		||||
        request.setTools(toolRegistry.getRegisteredSpecs().stream().map(Tools.ToolSpecification::getToolPrompt).collect(Collectors.toList()));
 | 
			
		||||
 | 
			
		||||
        if (streamHandler != null) {
 | 
			
		||||
            request.setStream(true);
 | 
			
		||||
            result = requestCaller.call(request, streamHandler);
 | 
			
		||||
        } else {
 | 
			
		||||
            result = requestCaller.callSync(request);
 | 
			
		||||
        }
 | 
			
		||||
        return new OllamaChatResult(result.getResponse(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages());
 | 
			
		||||
 | 
			
		||||
        // check if toolCallIsWanted
 | 
			
		||||
        List<OllamaChatToolCalls> toolCalls = result.getResponseModel().getMessage().getToolCalls();
 | 
			
		||||
        int toolCallTries = 0;
 | 
			
		||||
        while(toolCalls != null && !toolCalls.isEmpty() && toolCallTries < maxChatToolCallRetries){
 | 
			
		||||
            for (OllamaChatToolCalls toolCall : toolCalls){
 | 
			
		||||
                String toolName = toolCall.getFunction().getName();
 | 
			
		||||
                ToolFunction toolFunction = toolRegistry.getToolFunction(toolName);
 | 
			
		||||
                Map<String, Object> arguments = toolCall.getFunction().getArguments();
 | 
			
		||||
                Object res = toolFunction.apply(arguments);
 | 
			
		||||
                request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL,"[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() +") : " + res + "[/TOOL_RESULTS]"));
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            if (streamHandler != null) {
 | 
			
		||||
                result = requestCaller.call(request, streamHandler);
 | 
			
		||||
            } else {
 | 
			
		||||
                result = requestCaller.callSync(request);
 | 
			
		||||
            }
 | 
			
		||||
            toolCalls = result.getResponseModel().getMessage().getToolCalls();
 | 
			
		||||
            toolCallTries++;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        return result;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public void registerTool(Tools.ToolSpecification toolSpecification) {
 | 
			
		||||
        toolRegistry.addFunction(toolSpecification.getFunctionName(), toolSpecification.getToolDefinition());
 | 
			
		||||
        toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
@@ -871,7 +901,7 @@ public class OllamaAPI {
 | 
			
		||||
        try {
 | 
			
		||||
            String methodName = toolFunctionCallSpec.getName();
 | 
			
		||||
            Map<String, Object> arguments = toolFunctionCallSpec.getArguments();
 | 
			
		||||
            ToolFunction function = toolRegistry.getFunction(methodName);
 | 
			
		||||
            ToolFunction function = toolRegistry.getToolFunction(methodName);
 | 
			
		||||
            if (verbose) {
 | 
			
		||||
                logger.debug("Invoking function {} with arguments {}", methodName, arguments);
 | 
			
		||||
            }
 | 
			
		||||
 
 | 
			
		||||
@@ -2,6 +2,7 @@ package io.github.ollama4j.models.chat;
 | 
			
		||||
 | 
			
		||||
import static io.github.ollama4j.utils.Utils.getObjectMapper;
 | 
			
		||||
 | 
			
		||||
import com.fasterxml.jackson.annotation.JsonProperty;
 | 
			
		||||
import com.fasterxml.jackson.core.JsonProcessingException;
 | 
			
		||||
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
 | 
			
		||||
 | 
			
		||||
@@ -32,6 +33,8 @@ public class OllamaChatMessage {
 | 
			
		||||
    @NonNull
 | 
			
		||||
    private String content;
 | 
			
		||||
 | 
			
		||||
    private @JsonProperty("tool_calls") List<OllamaChatToolCalls> toolCalls;
 | 
			
		||||
 | 
			
		||||
    @JsonSerialize(using = FileToBase64Serializer.class)
 | 
			
		||||
    private List<byte[]> images;
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -3,6 +3,7 @@ package io.github.ollama4j.models.chat;
 | 
			
		||||
import java.util.List;
 | 
			
		||||
 | 
			
		||||
import io.github.ollama4j.models.request.OllamaCommonRequest;
 | 
			
		||||
import io.github.ollama4j.tools.Tools;
 | 
			
		||||
import io.github.ollama4j.utils.OllamaRequestBody;
 | 
			
		||||
 | 
			
		||||
import lombok.Getter;
 | 
			
		||||
@@ -21,6 +22,8 @@ public class OllamaChatRequest extends OllamaCommonRequest implements OllamaRequ
 | 
			
		||||
 | 
			
		||||
  private List<OllamaChatMessage> messages;
 | 
			
		||||
 | 
			
		||||
  private List<Tools.PromptFuncDefinition> tools;
 | 
			
		||||
 | 
			
		||||
  public OllamaChatRequest() {}
 | 
			
		||||
 | 
			
		||||
  public OllamaChatRequest(String model, List<OllamaChatMessage> messages) {
 | 
			
		||||
 
 | 
			
		||||
@@ -10,6 +10,7 @@ import java.io.IOException;
 | 
			
		||||
import java.net.URISyntaxException;
 | 
			
		||||
import java.nio.file.Files;
 | 
			
		||||
import java.util.ArrayList;
 | 
			
		||||
import java.util.Collections;
 | 
			
		||||
import java.util.List;
 | 
			
		||||
import java.util.stream.Collectors;
 | 
			
		||||
 | 
			
		||||
@@ -38,7 +39,11 @@ public class OllamaChatRequestBuilder {
 | 
			
		||||
        request = new OllamaChatRequest(request.getModel(), new ArrayList<>());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List<File> images) {
 | 
			
		||||
    public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content){
 | 
			
		||||
        return withMessage(role,content, Collections.emptyList());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List<OllamaChatToolCalls> toolCalls,List<File> images) {
 | 
			
		||||
        List<OllamaChatMessage> messages = this.request.getMessages();
 | 
			
		||||
 | 
			
		||||
        List<byte[]> binaryImages = images.stream().map(file -> {
 | 
			
		||||
@@ -50,11 +55,11 @@ public class OllamaChatRequestBuilder {
 | 
			
		||||
            }
 | 
			
		||||
        }).collect(Collectors.toList());
 | 
			
		||||
 | 
			
		||||
        messages.add(new OllamaChatMessage(role, content, binaryImages));
 | 
			
		||||
        messages.add(new OllamaChatMessage(role, content,toolCalls, binaryImages));
 | 
			
		||||
        return this;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, String... imageUrls) {
 | 
			
		||||
    public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content,List<OllamaChatToolCalls> toolCalls, String... imageUrls) {
 | 
			
		||||
        List<OllamaChatMessage> messages = this.request.getMessages();
 | 
			
		||||
        List<byte[]> binaryImages = null;
 | 
			
		||||
        if (imageUrls.length > 0) {
 | 
			
		||||
@@ -70,7 +75,7 @@ public class OllamaChatRequestBuilder {
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        messages.add(new OllamaChatMessage(role, content, binaryImages));
 | 
			
		||||
        messages.add(new OllamaChatMessage(role, content,toolCalls, binaryImages));
 | 
			
		||||
        return this;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -2,28 +2,54 @@ package io.github.ollama4j.models.chat;
 | 
			
		||||
 | 
			
		||||
import java.util.List;
 | 
			
		||||
 | 
			
		||||
import io.github.ollama4j.models.response.OllamaResult;
 | 
			
		||||
import com.fasterxml.jackson.core.JsonProcessingException;
 | 
			
		||||
import lombok.Getter;
 | 
			
		||||
 | 
			
		||||
import static io.github.ollama4j.utils.Utils.getObjectMapper;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Specific chat-API result that contains the chat history sent to the model and appends the answer as {@link OllamaChatResult} given by the
 | 
			
		||||
 * {@link OllamaChatMessageRole#ASSISTANT} role.
 | 
			
		||||
 */
 | 
			
		||||
public class OllamaChatResult extends OllamaResult {
 | 
			
		||||
@Getter
 | 
			
		||||
public class OllamaChatResult {
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    private List<OllamaChatMessage> chatHistory;
 | 
			
		||||
 | 
			
		||||
    public OllamaChatResult(String response, long responseTime, int httpStatusCode, List<OllamaChatMessage> chatHistory) {
 | 
			
		||||
        super(response, responseTime, httpStatusCode);
 | 
			
		||||
    private OllamaChatResponseModel responseModel;
 | 
			
		||||
 | 
			
		||||
    public OllamaChatResult(OllamaChatResponseModel responseModel, List<OllamaChatMessage> chatHistory) {
 | 
			
		||||
        this.chatHistory = chatHistory;
 | 
			
		||||
        appendAnswerToChatHistory(response);
 | 
			
		||||
        this.responseModel = responseModel;
 | 
			
		||||
        appendAnswerToChatHistory(responseModel);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public List<OllamaChatMessage> getChatHistory() {
 | 
			
		||||
        return chatHistory;
 | 
			
		||||
    private void appendAnswerToChatHistory(OllamaChatResponseModel response) {
 | 
			
		||||
        this.chatHistory.add(response.getMessage());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private void appendAnswerToChatHistory(String answer) {
 | 
			
		||||
        OllamaChatMessage assistantMessage = new OllamaChatMessage(OllamaChatMessageRole.ASSISTANT, answer);
 | 
			
		||||
        this.chatHistory.add(assistantMessage);
 | 
			
		||||
    @Override
 | 
			
		||||
    public String toString() {
 | 
			
		||||
        try {
 | 
			
		||||
            return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
 | 
			
		||||
        } catch (JsonProcessingException e) {
 | 
			
		||||
            throw new RuntimeException(e);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Deprecated
 | 
			
		||||
    public String getResponse(){
 | 
			
		||||
        return responseModel != null ? responseModel.getMessage().getContent() : "";
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Deprecated
 | 
			
		||||
    public int getHttpStatusCode(){
 | 
			
		||||
        return 200;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Deprecated
 | 
			
		||||
    public long getResponseTime(){
 | 
			
		||||
        return responseModel != null ? responseModel.getTotalDuration() : 0L;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -0,0 +1,16 @@
 | 
			
		||||
package io.github.ollama4j.models.chat;
 | 
			
		||||
 | 
			
		||||
import io.github.ollama4j.tools.OllamaToolCallsFunction;
 | 
			
		||||
import lombok.AllArgsConstructor;
 | 
			
		||||
import lombok.Data;
 | 
			
		||||
import lombok.NoArgsConstructor;
 | 
			
		||||
 | 
			
		||||
@Data
 | 
			
		||||
@NoArgsConstructor
 | 
			
		||||
@AllArgsConstructor
 | 
			
		||||
public class OllamaChatToolCalls {
 | 
			
		||||
 | 
			
		||||
    private OllamaToolCallsFunction function;
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
@@ -3,17 +3,24 @@ package io.github.ollama4j.models.request;
 | 
			
		||||
import com.fasterxml.jackson.core.JsonProcessingException;
 | 
			
		||||
import com.fasterxml.jackson.core.type.TypeReference;
 | 
			
		||||
import io.github.ollama4j.exceptions.OllamaBaseException;
 | 
			
		||||
import io.github.ollama4j.models.chat.OllamaChatMessage;
 | 
			
		||||
import io.github.ollama4j.models.response.OllamaResult;
 | 
			
		||||
import io.github.ollama4j.models.chat.OllamaChatResponseModel;
 | 
			
		||||
import io.github.ollama4j.models.chat.OllamaChatStreamObserver;
 | 
			
		||||
import io.github.ollama4j.models.chat.*;
 | 
			
		||||
import io.github.ollama4j.models.response.OllamaErrorResponse;
 | 
			
		||||
import io.github.ollama4j.models.generate.OllamaStreamHandler;
 | 
			
		||||
import io.github.ollama4j.utils.OllamaRequestBody;
 | 
			
		||||
import io.github.ollama4j.tools.Tools;
 | 
			
		||||
import io.github.ollama4j.utils.Utils;
 | 
			
		||||
import org.slf4j.Logger;
 | 
			
		||||
import org.slf4j.LoggerFactory;
 | 
			
		||||
 | 
			
		||||
import java.io.BufferedReader;
 | 
			
		||||
import java.io.IOException;
 | 
			
		||||
import java.io.InputStream;
 | 
			
		||||
import java.io.InputStreamReader;
 | 
			
		||||
import java.net.URI;
 | 
			
		||||
import java.net.http.HttpClient;
 | 
			
		||||
import java.net.http.HttpRequest;
 | 
			
		||||
import java.net.http.HttpResponse;
 | 
			
		||||
import java.nio.charset.StandardCharsets;
 | 
			
		||||
import java.util.List;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Specialization class for requests
 | 
			
		||||
@@ -64,9 +71,75 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler)
 | 
			
		||||
    public OllamaChatResult call(OllamaChatRequest body, OllamaStreamHandler streamHandler)
 | 
			
		||||
            throws OllamaBaseException, IOException, InterruptedException {
 | 
			
		||||
        streamObserver = new OllamaChatStreamObserver(streamHandler);
 | 
			
		||||
        return super.callSync(body);
 | 
			
		||||
        return callSync(body);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public OllamaChatResult callSync(OllamaChatRequest body) throws OllamaBaseException, IOException, InterruptedException {
 | 
			
		||||
        // Create Request
 | 
			
		||||
        HttpClient httpClient = HttpClient.newHttpClient();
 | 
			
		||||
        URI uri = URI.create(getHost() + getEndpointSuffix());
 | 
			
		||||
        HttpRequest.Builder requestBuilder =
 | 
			
		||||
                getRequestBuilderDefault(uri)
 | 
			
		||||
                        .POST(
 | 
			
		||||
                                body.getBodyPublisher());
 | 
			
		||||
        HttpRequest request = requestBuilder.build();
 | 
			
		||||
        if (isVerbose()) LOG.info("Asking model: " + body.toString());
 | 
			
		||||
        HttpResponse<InputStream> response =
 | 
			
		||||
                httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
 | 
			
		||||
 | 
			
		||||
        int statusCode = response.statusCode();
 | 
			
		||||
        InputStream responseBodyStream = response.body();
 | 
			
		||||
        StringBuilder responseBuffer = new StringBuilder();
 | 
			
		||||
        OllamaChatResponseModel ollamaChatResponseModel = null;
 | 
			
		||||
        List<OllamaChatToolCalls> wantedToolsForStream = null;
 | 
			
		||||
        try (BufferedReader reader =
 | 
			
		||||
                     new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
 | 
			
		||||
 | 
			
		||||
            String line;
 | 
			
		||||
            while ((line = reader.readLine()) != null) {
 | 
			
		||||
                if (statusCode == 404) {
 | 
			
		||||
                    LOG.warn("Status code: 404 (Not Found)");
 | 
			
		||||
                    OllamaErrorResponse ollamaResponseModel =
 | 
			
		||||
                            Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class);
 | 
			
		||||
                    responseBuffer.append(ollamaResponseModel.getError());
 | 
			
		||||
                } else if (statusCode == 401) {
 | 
			
		||||
                    LOG.warn("Status code: 401 (Unauthorized)");
 | 
			
		||||
                    OllamaErrorResponse ollamaResponseModel =
 | 
			
		||||
                            Utils.getObjectMapper()
 | 
			
		||||
                                    .readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponse.class);
 | 
			
		||||
                    responseBuffer.append(ollamaResponseModel.getError());
 | 
			
		||||
                } else if (statusCode == 400) {
 | 
			
		||||
                    LOG.warn("Status code: 400 (Bad Request)");
 | 
			
		||||
                    OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line,
 | 
			
		||||
                            OllamaErrorResponse.class);
 | 
			
		||||
                    responseBuffer.append(ollamaResponseModel.getError());
 | 
			
		||||
                } else {
 | 
			
		||||
                    boolean finished = parseResponseAndAddToBuffer(line, responseBuffer);
 | 
			
		||||
                        ollamaChatResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class);
 | 
			
		||||
                    if(body.stream && ollamaChatResponseModel.getMessage().getToolCalls() != null){
 | 
			
		||||
                        wantedToolsForStream = ollamaChatResponseModel.getMessage().getToolCalls();
 | 
			
		||||
                    }
 | 
			
		||||
                    if (finished && body.stream) {
 | 
			
		||||
                        ollamaChatResponseModel.getMessage().setContent(responseBuffer.toString());
 | 
			
		||||
                        break;
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        if (statusCode != 200) {
 | 
			
		||||
            LOG.error("Status code " + statusCode);
 | 
			
		||||
            throw new OllamaBaseException(responseBuffer.toString());
 | 
			
		||||
        } else {
 | 
			
		||||
            if(wantedToolsForStream != null) {
 | 
			
		||||
                ollamaChatResponseModel.getMessage().setToolCalls(wantedToolsForStream);
 | 
			
		||||
            }
 | 
			
		||||
            OllamaChatResult ollamaResult =
 | 
			
		||||
                    new OllamaChatResult(ollamaChatResponseModel,body.getMessages());
 | 
			
		||||
            if (isVerbose()) LOG.info("Model response: " + ollamaResult);
 | 
			
		||||
            return ollamaResult;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -6,6 +6,7 @@ import io.github.ollama4j.models.response.OllamaErrorResponse;
 | 
			
		||||
import io.github.ollama4j.models.response.OllamaResult;
 | 
			
		||||
import io.github.ollama4j.utils.OllamaRequestBody;
 | 
			
		||||
import io.github.ollama4j.utils.Utils;
 | 
			
		||||
import lombok.Getter;
 | 
			
		||||
import org.slf4j.Logger;
 | 
			
		||||
import org.slf4j.LoggerFactory;
 | 
			
		||||
 | 
			
		||||
@@ -24,14 +25,15 @@ import java.util.Base64;
 | 
			
		||||
/**
 | 
			
		||||
 * Abstract helperclass to call the ollama api server.
 | 
			
		||||
 */
 | 
			
		||||
@Getter
 | 
			
		||||
public abstract class OllamaEndpointCaller {
 | 
			
		||||
 | 
			
		||||
    private static final Logger LOG = LoggerFactory.getLogger(OllamaAPI.class);
 | 
			
		||||
 | 
			
		||||
    private String host;
 | 
			
		||||
    private BasicAuth basicAuth;
 | 
			
		||||
    private long requestTimeoutSeconds;
 | 
			
		||||
    private boolean verbose;
 | 
			
		||||
    private final String host;
 | 
			
		||||
    private final BasicAuth basicAuth;
 | 
			
		||||
    private final long requestTimeoutSeconds;
 | 
			
		||||
    private final boolean verbose;
 | 
			
		||||
 | 
			
		||||
    public OllamaEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) {
 | 
			
		||||
        this.host = host;
 | 
			
		||||
@@ -45,80 +47,13 @@ public abstract class OllamaEndpointCaller {
 | 
			
		||||
    protected abstract boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer);
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Calls the api server on the given host and endpoint suffix asynchronously, aka waiting for the response.
 | 
			
		||||
     *
 | 
			
		||||
     * @param body POST body payload
 | 
			
		||||
     * @return result answer given by the assistant
 | 
			
		||||
     * @throws OllamaBaseException  any response code than 200 has been returned
 | 
			
		||||
     * @throws IOException          in case the responseStream can not be read
 | 
			
		||||
     * @throws InterruptedException in case the server is not reachable or network issues happen
 | 
			
		||||
     */
 | 
			
		||||
    public OllamaResult callSync(OllamaRequestBody body) throws OllamaBaseException, IOException, InterruptedException {
 | 
			
		||||
        // Create Request
 | 
			
		||||
        long startTime = System.currentTimeMillis();
 | 
			
		||||
        HttpClient httpClient = HttpClient.newHttpClient();
 | 
			
		||||
        URI uri = URI.create(this.host + getEndpointSuffix());
 | 
			
		||||
        HttpRequest.Builder requestBuilder =
 | 
			
		||||
                getRequestBuilderDefault(uri)
 | 
			
		||||
                        .POST(
 | 
			
		||||
                                body.getBodyPublisher());
 | 
			
		||||
        HttpRequest request = requestBuilder.build();
 | 
			
		||||
        if (this.verbose) LOG.info("Asking model: " + body.toString());
 | 
			
		||||
        HttpResponse<InputStream> response =
 | 
			
		||||
                httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
 | 
			
		||||
 | 
			
		||||
        int statusCode = response.statusCode();
 | 
			
		||||
        InputStream responseBodyStream = response.body();
 | 
			
		||||
        StringBuilder responseBuffer = new StringBuilder();
 | 
			
		||||
        try (BufferedReader reader =
 | 
			
		||||
                     new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
 | 
			
		||||
            String line;
 | 
			
		||||
            while ((line = reader.readLine()) != null) {
 | 
			
		||||
                if (statusCode == 404) {
 | 
			
		||||
                    LOG.warn("Status code: 404 (Not Found)");
 | 
			
		||||
                    OllamaErrorResponse ollamaResponseModel =
 | 
			
		||||
                            Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class);
 | 
			
		||||
                    responseBuffer.append(ollamaResponseModel.getError());
 | 
			
		||||
                } else if (statusCode == 401) {
 | 
			
		||||
                    LOG.warn("Status code: 401 (Unauthorized)");
 | 
			
		||||
                    OllamaErrorResponse ollamaResponseModel =
 | 
			
		||||
                            Utils.getObjectMapper()
 | 
			
		||||
                                    .readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponse.class);
 | 
			
		||||
                    responseBuffer.append(ollamaResponseModel.getError());
 | 
			
		||||
                } else if (statusCode == 400) {
 | 
			
		||||
                    LOG.warn("Status code: 400 (Bad Request)");
 | 
			
		||||
                    OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line,
 | 
			
		||||
                            OllamaErrorResponse.class);
 | 
			
		||||
                    responseBuffer.append(ollamaResponseModel.getError());
 | 
			
		||||
                } else {
 | 
			
		||||
                    boolean finished = parseResponseAndAddToBuffer(line, responseBuffer);
 | 
			
		||||
                    if (finished) {
 | 
			
		||||
                        break;
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if (statusCode != 200) {
 | 
			
		||||
            LOG.error("Status code " + statusCode);
 | 
			
		||||
            throw new OllamaBaseException(responseBuffer.toString());
 | 
			
		||||
        } else {
 | 
			
		||||
            long endTime = System.currentTimeMillis();
 | 
			
		||||
            OllamaResult ollamaResult =
 | 
			
		||||
                    new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode);
 | 
			
		||||
            if (verbose) LOG.info("Model response: " + ollamaResult);
 | 
			
		||||
            return ollamaResult;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Get default request builder.
 | 
			
		||||
     *
 | 
			
		||||
     * @param uri URI to get a HttpRequest.Builder
 | 
			
		||||
     * @return HttpRequest.Builder
 | 
			
		||||
     */
 | 
			
		||||
    private HttpRequest.Builder getRequestBuilderDefault(URI uri) {
 | 
			
		||||
    protected HttpRequest.Builder getRequestBuilderDefault(URI uri) {
 | 
			
		||||
        HttpRequest.Builder requestBuilder =
 | 
			
		||||
                HttpRequest.newBuilder(uri)
 | 
			
		||||
                        .header("Content-Type", "application/json")
 | 
			
		||||
@@ -134,7 +69,7 @@ public abstract class OllamaEndpointCaller {
 | 
			
		||||
     *
 | 
			
		||||
     * @return basic authentication header value (encoded credentials)
 | 
			
		||||
     */
 | 
			
		||||
    private String getBasicAuthHeaderValue() {
 | 
			
		||||
    protected String getBasicAuthHeaderValue() {
 | 
			
		||||
        String credentialsToEncode = this.basicAuth.getUsername() + ":" + this.basicAuth.getPassword();
 | 
			
		||||
        return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes());
 | 
			
		||||
    }
 | 
			
		||||
@@ -144,7 +79,7 @@ public abstract class OllamaEndpointCaller {
 | 
			
		||||
     *
 | 
			
		||||
     * @return true when Basic Auth credentials set
 | 
			
		||||
     */
 | 
			
		||||
    private boolean isBasicAuthCredentialsSet() {
 | 
			
		||||
    protected boolean isBasicAuthCredentialsSet() {
 | 
			
		||||
        return this.basicAuth != null;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -2,6 +2,7 @@ package io.github.ollama4j.models.request;
 | 
			
		||||
 | 
			
		||||
import com.fasterxml.jackson.core.JsonProcessingException;
 | 
			
		||||
import io.github.ollama4j.exceptions.OllamaBaseException;
 | 
			
		||||
import io.github.ollama4j.models.response.OllamaErrorResponse;
 | 
			
		||||
import io.github.ollama4j.models.response.OllamaResult;
 | 
			
		||||
import io.github.ollama4j.models.generate.OllamaGenerateResponseModel;
 | 
			
		||||
import io.github.ollama4j.models.generate.OllamaGenerateStreamObserver;
 | 
			
		||||
@@ -11,7 +12,15 @@ import io.github.ollama4j.utils.Utils;
 | 
			
		||||
import org.slf4j.Logger;
 | 
			
		||||
import org.slf4j.LoggerFactory;
 | 
			
		||||
 | 
			
		||||
import java.io.BufferedReader;
 | 
			
		||||
import java.io.IOException;
 | 
			
		||||
import java.io.InputStream;
 | 
			
		||||
import java.io.InputStreamReader;
 | 
			
		||||
import java.net.URI;
 | 
			
		||||
import java.net.http.HttpClient;
 | 
			
		||||
import java.net.http.HttpRequest;
 | 
			
		||||
import java.net.http.HttpResponse;
 | 
			
		||||
import java.nio.charset.StandardCharsets;
 | 
			
		||||
 | 
			
		||||
public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
 | 
			
		||||
 | 
			
		||||
@@ -46,6 +55,73 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
 | 
			
		||||
    public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler)
 | 
			
		||||
            throws OllamaBaseException, IOException, InterruptedException {
 | 
			
		||||
        streamObserver = new OllamaGenerateStreamObserver(streamHandler);
 | 
			
		||||
        return super.callSync(body);
 | 
			
		||||
        return callSync(body);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Calls the api server on the given host and endpoint suffix asynchronously, aka waiting for the response.
 | 
			
		||||
     *
 | 
			
		||||
     * @param body POST body payload
 | 
			
		||||
     * @return result answer given by the assistant
 | 
			
		||||
     * @throws OllamaBaseException  any response code than 200 has been returned
 | 
			
		||||
     * @throws IOException          in case the responseStream can not be read
 | 
			
		||||
     * @throws InterruptedException in case the server is not reachable or network issues happen
 | 
			
		||||
     */
 | 
			
		||||
    public OllamaResult callSync(OllamaRequestBody body) throws OllamaBaseException, IOException, InterruptedException {
 | 
			
		||||
        // Create Request
 | 
			
		||||
        long startTime = System.currentTimeMillis();
 | 
			
		||||
        HttpClient httpClient = HttpClient.newHttpClient();
 | 
			
		||||
        URI uri = URI.create(getHost() + getEndpointSuffix());
 | 
			
		||||
        HttpRequest.Builder requestBuilder =
 | 
			
		||||
                getRequestBuilderDefault(uri)
 | 
			
		||||
                        .POST(
 | 
			
		||||
                                body.getBodyPublisher());
 | 
			
		||||
        HttpRequest request = requestBuilder.build();
 | 
			
		||||
        if (isVerbose()) LOG.info("Asking model: " + body.toString());
 | 
			
		||||
        HttpResponse<InputStream> response =
 | 
			
		||||
                httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
 | 
			
		||||
 | 
			
		||||
        int statusCode = response.statusCode();
 | 
			
		||||
        InputStream responseBodyStream = response.body();
 | 
			
		||||
        StringBuilder responseBuffer = new StringBuilder();
 | 
			
		||||
        try (BufferedReader reader =
 | 
			
		||||
                     new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
 | 
			
		||||
            String line;
 | 
			
		||||
            while ((line = reader.readLine()) != null) {
 | 
			
		||||
                if (statusCode == 404) {
 | 
			
		||||
                    LOG.warn("Status code: 404 (Not Found)");
 | 
			
		||||
                    OllamaErrorResponse ollamaResponseModel =
 | 
			
		||||
                            Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class);
 | 
			
		||||
                    responseBuffer.append(ollamaResponseModel.getError());
 | 
			
		||||
                } else if (statusCode == 401) {
 | 
			
		||||
                    LOG.warn("Status code: 401 (Unauthorized)");
 | 
			
		||||
                    OllamaErrorResponse ollamaResponseModel =
 | 
			
		||||
                            Utils.getObjectMapper()
 | 
			
		||||
                                    .readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponse.class);
 | 
			
		||||
                    responseBuffer.append(ollamaResponseModel.getError());
 | 
			
		||||
                } else if (statusCode == 400) {
 | 
			
		||||
                    LOG.warn("Status code: 400 (Bad Request)");
 | 
			
		||||
                    OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line,
 | 
			
		||||
                            OllamaErrorResponse.class);
 | 
			
		||||
                    responseBuffer.append(ollamaResponseModel.getError());
 | 
			
		||||
                } else {
 | 
			
		||||
                    boolean finished = parseResponseAndAddToBuffer(line, responseBuffer);
 | 
			
		||||
                    if (finished) {
 | 
			
		||||
                        break;
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if (statusCode != 200) {
 | 
			
		||||
            LOG.error("Status code " + statusCode);
 | 
			
		||||
            throw new OllamaBaseException(responseBuffer.toString());
 | 
			
		||||
        } else {
 | 
			
		||||
            long endTime = System.currentTimeMillis();
 | 
			
		||||
            OllamaResult ollamaResult =
 | 
			
		||||
                    new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode);
 | 
			
		||||
            if (isVerbose()) LOG.info("Model response: " + ollamaResult);
 | 
			
		||||
            return ollamaResult;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -0,0 +1,16 @@
 | 
			
		||||
package io.github.ollama4j.tools;
 | 
			
		||||
 | 
			
		||||
import lombok.AllArgsConstructor;
 | 
			
		||||
import lombok.Data;
 | 
			
		||||
import lombok.NoArgsConstructor;
 | 
			
		||||
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
 | 
			
		||||
@Data
 | 
			
		||||
@NoArgsConstructor
 | 
			
		||||
@AllArgsConstructor
 | 
			
		||||
public class OllamaToolCallsFunction
 | 
			
		||||
{
 | 
			
		||||
    private String name;
 | 
			
		||||
    private Map<String,Object> arguments;
 | 
			
		||||
}
 | 
			
		||||
@@ -1,16 +1,22 @@
 | 
			
		||||
package io.github.ollama4j.tools;
 | 
			
		||||
 | 
			
		||||
import java.util.Collection;
 | 
			
		||||
import java.util.HashMap;
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
 | 
			
		||||
public class ToolRegistry {
 | 
			
		||||
    private final Map<String, ToolFunction> functionMap = new HashMap<>();
 | 
			
		||||
    private final Map<String, Tools.ToolSpecification> tools = new HashMap<>();
 | 
			
		||||
 | 
			
		||||
    public ToolFunction getFunction(String name) {
 | 
			
		||||
        return functionMap.get(name);
 | 
			
		||||
    public ToolFunction getToolFunction(String name) {
 | 
			
		||||
        final Tools.ToolSpecification toolSpecification = tools.get(name);
 | 
			
		||||
        return toolSpecification !=null ? toolSpecification.getToolFunction() : null ;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public void addFunction(String name, ToolFunction function) {
 | 
			
		||||
        functionMap.put(name, function);
 | 
			
		||||
    public void addTool (String name, Tools.ToolSpecification specification) {
 | 
			
		||||
        tools.put(name, specification);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public Collection<Tools.ToolSpecification> getRegisteredSpecs(){
 | 
			
		||||
        return tools.values();
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -6,8 +6,10 @@ import com.fasterxml.jackson.annotation.JsonInclude;
 | 
			
		||||
import com.fasterxml.jackson.annotation.JsonProperty;
 | 
			
		||||
import com.fasterxml.jackson.core.JsonProcessingException;
 | 
			
		||||
import io.github.ollama4j.utils.Utils;
 | 
			
		||||
import lombok.AllArgsConstructor;
 | 
			
		||||
import lombok.Builder;
 | 
			
		||||
import lombok.Data;
 | 
			
		||||
import lombok.NoArgsConstructor;
 | 
			
		||||
 | 
			
		||||
import java.util.ArrayList;
 | 
			
		||||
import java.util.HashMap;
 | 
			
		||||
@@ -20,17 +22,23 @@ public class Tools {
 | 
			
		||||
    public static class ToolSpecification {
 | 
			
		||||
        private String functionName;
 | 
			
		||||
        private String functionDescription;
 | 
			
		||||
        private Map<String, PromptFuncDefinition.Property> properties;
 | 
			
		||||
        private ToolFunction toolDefinition;
 | 
			
		||||
        private PromptFuncDefinition toolPrompt;
 | 
			
		||||
        private ToolFunction toolFunction;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Data
 | 
			
		||||
    @JsonIgnoreProperties(ignoreUnknown = true)
 | 
			
		||||
    @Builder
 | 
			
		||||
    @NoArgsConstructor
 | 
			
		||||
    @AllArgsConstructor
 | 
			
		||||
    public static class PromptFuncDefinition {
 | 
			
		||||
        private String type;
 | 
			
		||||
        private PromptFuncSpec function;
 | 
			
		||||
 | 
			
		||||
        @Data
 | 
			
		||||
        @Builder
 | 
			
		||||
        @NoArgsConstructor
 | 
			
		||||
        @AllArgsConstructor
 | 
			
		||||
        public static class PromptFuncSpec {
 | 
			
		||||
            private String name;
 | 
			
		||||
            private String description;
 | 
			
		||||
@@ -38,6 +46,9 @@ public class Tools {
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        @Data
 | 
			
		||||
        @Builder
 | 
			
		||||
        @NoArgsConstructor
 | 
			
		||||
        @AllArgsConstructor
 | 
			
		||||
        public static class Parameters {
 | 
			
		||||
            private String type;
 | 
			
		||||
            private Map<String, Property> properties;
 | 
			
		||||
@@ -46,6 +57,8 @@ public class Tools {
 | 
			
		||||
 | 
			
		||||
        @Data
 | 
			
		||||
        @Builder
 | 
			
		||||
        @NoArgsConstructor
 | 
			
		||||
        @AllArgsConstructor
 | 
			
		||||
        public static class Property {
 | 
			
		||||
            private String type;
 | 
			
		||||
            private String description;
 | 
			
		||||
@@ -94,10 +107,10 @@ public class Tools {
 | 
			
		||||
 | 
			
		||||
            PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters();
 | 
			
		||||
            parameters.setType("object");
 | 
			
		||||
            parameters.setProperties(spec.getProperties());
 | 
			
		||||
            parameters.setProperties(spec.getToolPrompt().getFunction().parameters.getProperties());
 | 
			
		||||
 | 
			
		||||
            List<String> requiredValues = new ArrayList<>();
 | 
			
		||||
            for (Map.Entry<String, PromptFuncDefinition.Property> p : spec.getProperties().entrySet()) {
 | 
			
		||||
            for (Map.Entry<String, PromptFuncDefinition.Property> p : spec.getToolPrompt().getFunction().getParameters().getProperties().entrySet()) {
 | 
			
		||||
                if (p.getValue().isRequired()) {
 | 
			
		||||
                    requiredValues.add(p.getKey());
 | 
			
		||||
                }
 | 
			
		||||
 
 | 
			
		||||
@@ -2,14 +2,13 @@ package io.github.ollama4j.integrationtests;
 | 
			
		||||
 | 
			
		||||
import io.github.ollama4j.OllamaAPI;
 | 
			
		||||
import io.github.ollama4j.exceptions.OllamaBaseException;
 | 
			
		||||
import io.github.ollama4j.models.chat.*;
 | 
			
		||||
import io.github.ollama4j.models.response.ModelDetail;
 | 
			
		||||
import io.github.ollama4j.models.chat.OllamaChatRequest;
 | 
			
		||||
import io.github.ollama4j.models.response.OllamaResult;
 | 
			
		||||
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
 | 
			
		||||
import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
 | 
			
		||||
import io.github.ollama4j.models.chat.OllamaChatResult;
 | 
			
		||||
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestBuilder;
 | 
			
		||||
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
 | 
			
		||||
import io.github.ollama4j.tools.ToolFunction;
 | 
			
		||||
import io.github.ollama4j.tools.Tools;
 | 
			
		||||
import io.github.ollama4j.utils.OptionsBuilder;
 | 
			
		||||
import lombok.Data;
 | 
			
		||||
import org.junit.jupiter.api.BeforeEach;
 | 
			
		||||
@@ -24,9 +23,7 @@ import java.io.InputStream;
 | 
			
		||||
import java.net.ConnectException;
 | 
			
		||||
import java.net.URISyntaxException;
 | 
			
		||||
import java.net.http.HttpConnectTimeoutException;
 | 
			
		||||
import java.util.List;
 | 
			
		||||
import java.util.Objects;
 | 
			
		||||
import java.util.Properties;
 | 
			
		||||
import java.util.*;
 | 
			
		||||
 | 
			
		||||
import static org.junit.jupiter.api.Assertions.*;
 | 
			
		||||
 | 
			
		||||
@@ -47,6 +44,7 @@ class TestRealAPIs {
 | 
			
		||||
        config = new Config();
 | 
			
		||||
        ollamaAPI = new OllamaAPI(config.getOllamaURL());
 | 
			
		||||
        ollamaAPI.setRequestTimeoutSeconds(config.getRequestTimeoutSeconds());
 | 
			
		||||
        ollamaAPI.setVerbose(true);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
@@ -196,7 +194,9 @@ class TestRealAPIs {
 | 
			
		||||
 | 
			
		||||
            OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
 | 
			
		||||
            assertNotNull(chatResult);
 | 
			
		||||
            assertFalse(chatResult.getResponse().isBlank());
 | 
			
		||||
            assertNotNull(chatResult.getResponseModel());
 | 
			
		||||
            assertNotNull(chatResult.getResponseModel().getMessage());
 | 
			
		||||
            assertFalse(chatResult.getResponseModel().getMessage().getContent().isBlank());
 | 
			
		||||
            assertEquals(4, chatResult.getChatHistory().size());
 | 
			
		||||
        } catch (IOException | OllamaBaseException | InterruptedException e) {
 | 
			
		||||
            fail(e);
 | 
			
		||||
@@ -217,14 +217,134 @@ class TestRealAPIs {
 | 
			
		||||
 | 
			
		||||
            OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
 | 
			
		||||
            assertNotNull(chatResult);
 | 
			
		||||
            assertFalse(chatResult.getResponse().isBlank());
 | 
			
		||||
            assertTrue(chatResult.getResponse().startsWith("NI"));
 | 
			
		||||
            assertNotNull(chatResult.getResponseModel());
 | 
			
		||||
            assertNotNull(chatResult.getResponseModel().getMessage());
 | 
			
		||||
            assertFalse(chatResult.getResponseModel().getMessage().getContent().isBlank());
 | 
			
		||||
            assertTrue(chatResult.getResponseModel().getMessage().getContent().startsWith("NI"));
 | 
			
		||||
            assertEquals(3, chatResult.getChatHistory().size());
 | 
			
		||||
        } catch (IOException | OllamaBaseException | InterruptedException e) {
 | 
			
		||||
            fail(e);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    @Order(3)
 | 
			
		||||
    void testChatWithTools() {
 | 
			
		||||
        testEndpointReachability();
 | 
			
		||||
        try {
 | 
			
		||||
            ollamaAPI.setVerbose(true);
 | 
			
		||||
            OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
 | 
			
		||||
 | 
			
		||||
            final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder()
 | 
			
		||||
                    .functionName("get-employee-details")
 | 
			
		||||
                    .functionDescription("Get employee details from the database")
 | 
			
		||||
                    .toolPrompt(
 | 
			
		||||
                            Tools.PromptFuncDefinition.builder().type("function").function(
 | 
			
		||||
                                    Tools.PromptFuncDefinition.PromptFuncSpec.builder()
 | 
			
		||||
                                            .name("get-employee-details")
 | 
			
		||||
                                            .description("Get employee details from the database")
 | 
			
		||||
                                            .parameters(
 | 
			
		||||
                                                    Tools.PromptFuncDefinition.Parameters.builder()
 | 
			
		||||
                                                            .type("object")
 | 
			
		||||
                                                            .properties(
 | 
			
		||||
                                                            new Tools.PropsBuilder()
 | 
			
		||||
                                                                    .withProperty("employee-name", Tools.PromptFuncDefinition.Property.builder().type("string").description("The name of the employee, e.g. John Doe").required(true).build())
 | 
			
		||||
                                                                    .withProperty("employee-address", Tools.PromptFuncDefinition.Property.builder().type("string").description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India").required(true).build())
 | 
			
		||||
                                                                    .withProperty("employee-phone", Tools.PromptFuncDefinition.Property.builder().type("string").description("The phone number of the employee. Always return a random value. e.g. 9911002233").required(true).build())
 | 
			
		||||
                                                                    .build()
 | 
			
		||||
                                                    )
 | 
			
		||||
                                                            .required(List.of("employee-name"))
 | 
			
		||||
                                                            .build()
 | 
			
		||||
                                            ).build()
 | 
			
		||||
                            ).build()
 | 
			
		||||
                    )
 | 
			
		||||
                    .toolFunction(new DBQueryFunction())
 | 
			
		||||
                    .build();
 | 
			
		||||
 | 
			
		||||
            ollamaAPI.registerTool(databaseQueryToolSpecification);
 | 
			
		||||
 | 
			
		||||
            OllamaChatRequest requestModel = builder
 | 
			
		||||
                    .withMessage(OllamaChatMessageRole.USER,
 | 
			
		||||
                            "Give me the ID of the employee named 'Rahul Kumar'?")
 | 
			
		||||
                    .build();
 | 
			
		||||
 | 
			
		||||
            OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
 | 
			
		||||
            assertNotNull(chatResult);
 | 
			
		||||
            assertNotNull(chatResult.getResponseModel());
 | 
			
		||||
            assertNotNull(chatResult.getResponseModel().getMessage());
 | 
			
		||||
            assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(),chatResult.getResponseModel().getMessage().getRole().getRoleName());
 | 
			
		||||
            List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls();
 | 
			
		||||
            assertEquals(1, toolCalls.size());
 | 
			
		||||
            assertEquals("get-employee-details",toolCalls.get(0).getFunction().getName());
 | 
			
		||||
            assertEquals(1, toolCalls.get(0).getFunction().getArguments().size());
 | 
			
		||||
            Object employeeName = toolCalls.get(0).getFunction().getArguments().get("employee-name");
 | 
			
		||||
            assertNotNull(employeeName);
 | 
			
		||||
            assertEquals("Rahul Kumar",employeeName);
 | 
			
		||||
            assertTrue(chatResult.getChatHistory().size()>2);
 | 
			
		||||
            List<OllamaChatToolCalls> finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls();
 | 
			
		||||
            assertNull(finalToolCalls);
 | 
			
		||||
        } catch (IOException | OllamaBaseException | InterruptedException e) {
 | 
			
		||||
            fail(e);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    @Order(3)
 | 
			
		||||
    void testChatWithToolsAndStream() {
 | 
			
		||||
        testEndpointReachability();
 | 
			
		||||
        try {
 | 
			
		||||
            OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
 | 
			
		||||
            final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder()
 | 
			
		||||
                    .functionName("get-employee-details")
 | 
			
		||||
                    .functionDescription("Get employee details from the database")
 | 
			
		||||
                    .toolPrompt(
 | 
			
		||||
                            Tools.PromptFuncDefinition.builder().type("function").function(
 | 
			
		||||
                                    Tools.PromptFuncDefinition.PromptFuncSpec.builder()
 | 
			
		||||
                                            .name("get-employee-details")
 | 
			
		||||
                                            .description("Get employee details from the database")
 | 
			
		||||
                                            .parameters(
 | 
			
		||||
                                                    Tools.PromptFuncDefinition.Parameters.builder()
 | 
			
		||||
                                                            .type("object")
 | 
			
		||||
                                                            .properties(
 | 
			
		||||
                                                                    new Tools.PropsBuilder()
 | 
			
		||||
                                                                            .withProperty("employee-name", Tools.PromptFuncDefinition.Property.builder().type("string").description("The name of the employee, e.g. John Doe").required(true).build())
 | 
			
		||||
                                                                            .withProperty("employee-address", Tools.PromptFuncDefinition.Property.builder().type("string").description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India").required(true).build())
 | 
			
		||||
                                                                            .withProperty("employee-phone", Tools.PromptFuncDefinition.Property.builder().type("string").description("The phone number of the employee. Always return a random value. e.g. 9911002233").required(true).build())
 | 
			
		||||
                                                                            .build()
 | 
			
		||||
                                                            )
 | 
			
		||||
                                                            .required(List.of("employee-name"))
 | 
			
		||||
                                                            .build()
 | 
			
		||||
                                            ).build()
 | 
			
		||||
                            ).build()
 | 
			
		||||
                    )
 | 
			
		||||
                    .toolFunction(new DBQueryFunction())
 | 
			
		||||
                    .build();
 | 
			
		||||
 | 
			
		||||
            ollamaAPI.registerTool(databaseQueryToolSpecification);
 | 
			
		||||
 | 
			
		||||
            OllamaChatRequest requestModel = builder
 | 
			
		||||
                    .withMessage(OllamaChatMessageRole.USER,
 | 
			
		||||
                            "Give me the ID of the employee named 'Rahul Kumar'?")
 | 
			
		||||
                    .build();
 | 
			
		||||
 | 
			
		||||
            StringBuffer sb = new StringBuffer();
 | 
			
		||||
 | 
			
		||||
            OllamaChatResult chatResult = ollamaAPI.chat(requestModel, (s) -> {
 | 
			
		||||
                LOG.info(s);
 | 
			
		||||
                String substring = s.substring(sb.toString().length());
 | 
			
		||||
                LOG.info(substring);
 | 
			
		||||
                sb.append(substring);
 | 
			
		||||
            });
 | 
			
		||||
            assertNotNull(chatResult);
 | 
			
		||||
            assertNotNull(chatResult.getResponseModel());
 | 
			
		||||
            assertNotNull(chatResult.getResponseModel().getMessage());
 | 
			
		||||
            assertNotNull(chatResult.getResponseModel().getMessage().getContent());
 | 
			
		||||
            assertEquals(sb.toString().trim(), chatResult.getResponseModel().getMessage().getContent().trim());
 | 
			
		||||
        } catch (IOException | OllamaBaseException | InterruptedException e) {
 | 
			
		||||
            fail(e);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    @Order(3)
 | 
			
		||||
    void testChatWithStream() {
 | 
			
		||||
@@ -244,7 +364,10 @@ class TestRealAPIs {
 | 
			
		||||
                sb.append(substring);
 | 
			
		||||
            });
 | 
			
		||||
            assertNotNull(chatResult);
 | 
			
		||||
            assertEquals(sb.toString().trim(), chatResult.getResponse().trim());
 | 
			
		||||
            assertNotNull(chatResult.getResponseModel());
 | 
			
		||||
            assertNotNull(chatResult.getResponseModel().getMessage());
 | 
			
		||||
            assertNotNull(chatResult.getResponseModel().getMessage().getContent());
 | 
			
		||||
            assertEquals(sb.toString().trim(), chatResult.getResponseModel().getMessage().getContent().trim());
 | 
			
		||||
        } catch (IOException | OllamaBaseException | InterruptedException e) {
 | 
			
		||||
            fail(e);
 | 
			
		||||
        }
 | 
			
		||||
@@ -258,12 +381,12 @@ class TestRealAPIs {
 | 
			
		||||
            OllamaChatRequestBuilder builder =
 | 
			
		||||
                    OllamaChatRequestBuilder.getInstance(config.getImageModel());
 | 
			
		||||
            OllamaChatRequest requestModel =
 | 
			
		||||
                    builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
 | 
			
		||||
                    builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",Collections.emptyList(),
 | 
			
		||||
                            List.of(getImageFileFromClasspath("dog-on-a-boat.jpg"))).build();
 | 
			
		||||
 | 
			
		||||
            OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
 | 
			
		||||
            assertNotNull(chatResult);
 | 
			
		||||
            assertNotNull(chatResult.getResponse());
 | 
			
		||||
            assertNotNull(chatResult.getResponseModel());
 | 
			
		||||
 | 
			
		||||
            builder.reset();
 | 
			
		||||
 | 
			
		||||
@@ -273,7 +396,7 @@ class TestRealAPIs {
 | 
			
		||||
 | 
			
		||||
            chatResult = ollamaAPI.chat(requestModel);
 | 
			
		||||
            assertNotNull(chatResult);
 | 
			
		||||
            assertNotNull(chatResult.getResponse());
 | 
			
		||||
            assertNotNull(chatResult.getResponseModel());
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        } catch (IOException | OllamaBaseException | InterruptedException e) {
 | 
			
		||||
@@ -287,7 +410,7 @@ class TestRealAPIs {
 | 
			
		||||
        testEndpointReachability();
 | 
			
		||||
        try {
 | 
			
		||||
            OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getImageModel());
 | 
			
		||||
            OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
 | 
			
		||||
            OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",Collections.emptyList(),
 | 
			
		||||
                            "https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg")
 | 
			
		||||
                    .build();
 | 
			
		||||
 | 
			
		||||
@@ -380,6 +503,14 @@ class TestRealAPIs {
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
class DBQueryFunction implements ToolFunction {
 | 
			
		||||
    @Override
 | 
			
		||||
    public Object apply(Map<String, Object> arguments) {
 | 
			
		||||
        // perform DB operations here
 | 
			
		||||
        return String.format("Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", UUID.randomUUID(), arguments.get("employee-name"), arguments.get("employee-address"), arguments.get("employee-phone"));
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@Data
 | 
			
		||||
class Config {
 | 
			
		||||
    private String ollamaURL;
 | 
			
		||||
@@ -404,4 +535,6 @@ class Config {
 | 
			
		||||
            throw new RuntimeException("Error loading properties", e);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -4,6 +4,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
 | 
			
		||||
import static org.junit.jupiter.api.Assertions.assertThrowsExactly;
 | 
			
		||||
 | 
			
		||||
import java.io.File;
 | 
			
		||||
import java.util.Collections;
 | 
			
		||||
import java.util.List;
 | 
			
		||||
 | 
			
		||||
import io.github.ollama4j.models.chat.OllamaChatRequest;
 | 
			
		||||
@@ -42,7 +43,7 @@ public class TestChatRequestSerialization extends AbstractSerializationTest<Olla
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void testRequestWithMessageAndImage() {
 | 
			
		||||
        OllamaChatRequest req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt",
 | 
			
		||||
        OllamaChatRequest req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt", Collections.emptyList(),
 | 
			
		||||
                List.of(new File("src/test/resources/dog-on-a-boat.jpg"))).build();
 | 
			
		||||
        String jsonRequest = serialize(req);
 | 
			
		||||
        assertEqualsAfterUnmarshalling(deserialize(jsonRequest, OllamaChatRequest.class), req);
 | 
			
		||||
 
 | 
			
		||||
@@ -1,4 +1,4 @@
 | 
			
		||||
ollama.url=http://localhost:11434
 | 
			
		||||
ollama.model=qwen:0.5b
 | 
			
		||||
ollama.model.image=llava
 | 
			
		||||
ollama.model=llama3.2:1b
 | 
			
		||||
ollama.model.image=llava:latest
 | 
			
		||||
ollama.request-timeout-seconds=120
 | 
			
		||||
		Reference in New Issue
	
	Block a user