mirror of
				https://github.com/amithkoujalgi/ollama4j.git
				synced 2025-11-03 18:10:42 +01:00 
			
		
		
		
	Merge pull request #82 from AgentSchmecker/feature/toolextension_for_chat_model
Enable chat API to use Tools
This commit is contained in:
		@@ -33,7 +33,7 @@ public class Main {
 | 
				
			|||||||
        // start conversation with model
 | 
					        // start conversation with model
 | 
				
			||||||
        OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
 | 
					        OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        System.out.println("First answer: " + chatResult.getResponse());
 | 
					        System.out.println("First answer: " + chatResult.getResponseModel().getMessage().getContent());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // create next userQuestion
 | 
					        // create next userQuestion
 | 
				
			||||||
        requestModel = builder.withMessages(chatResult.getChatHistory()).withMessage(OllamaChatMessageRole.USER, "And what is the second largest city?").build();
 | 
					        requestModel = builder.withMessages(chatResult.getChatHistory()).withMessage(OllamaChatMessageRole.USER, "And what is the second largest city?").build();
 | 
				
			||||||
@@ -41,7 +41,7 @@ public class Main {
 | 
				
			|||||||
        // "continue" conversation with model
 | 
					        // "continue" conversation with model
 | 
				
			||||||
        chatResult = ollamaAPI.chat(requestModel);
 | 
					        chatResult = ollamaAPI.chat(requestModel);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        System.out.println("Second answer: " + chatResult.getResponse());
 | 
					        System.out.println("Second answer: " + chatResult.getResponseModel().getMessage().getContent());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        System.out.println("Chat History: " + chatResult.getChatHistory());
 | 
					        System.out.println("Chat History: " + chatResult.getChatHistory());
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
@@ -205,7 +205,7 @@ public class Main {
 | 
				
			|||||||
        // start conversation with model
 | 
					        // start conversation with model
 | 
				
			||||||
        OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
 | 
					        OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        System.out.println(chatResult.getResponse());
 | 
					        System.out.println(chatResult.getResponseModel());
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -244,7 +244,7 @@ public class Main {
 | 
				
			|||||||
                                new File("/path/to/image"))).build();
 | 
					                                new File("/path/to/image"))).build();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
 | 
					        OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
 | 
				
			||||||
        System.out.println("First answer: " + chatResult.getResponse());
 | 
					        System.out.println("First answer: " + chatResult.getResponseModel());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        builder.reset();
 | 
					        builder.reset();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -254,7 +254,7 @@ public class Main {
 | 
				
			|||||||
                        .withMessage(OllamaChatMessageRole.USER, "What's the dogs breed?").build();
 | 
					                        .withMessage(OllamaChatMessageRole.USER, "What's the dogs breed?").build();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        chatResult = ollamaAPI.chat(requestModel);
 | 
					        chatResult = ollamaAPI.chat(requestModel);
 | 
				
			||||||
        System.out.println("Second answer: " + chatResult.getResponse());
 | 
					        System.out.println("Second answer: " + chatResult.getResponseModel());
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
```
 | 
					```
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -345,6 +345,125 @@ Rahul Kumar, Address: King St, Hyderabad, India, Phone: 9876543210}`
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
::::
 | 
					::::
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### Using tools in Chat-API
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Instead of using the specific `ollamaAPI.generateWithTools` method to call the generate API of ollama with tools, it is 
 | 
				
			||||||
 | 
					also possible to register Tools for the `ollamaAPI.chat` methods. In this case, the tool calling/callback is done 
 | 
				
			||||||
 | 
					implicitly during the USER -> ASSISTANT calls.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					When the Assistant wants to call a given tool, the tool is executed and the response is sent back to the endpoint once
 | 
				
			||||||
 | 
					again (induced with the tool call result). 
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#### Sample:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					The following shows a sample of an integration test that defines a method specified like the tool-specs above, registers
 | 
				
			||||||
 | 
					the tool on the ollamaAPI and then simply calls the chat-API. All intermediate tool calling is wrapped inside the api 
 | 
				
			||||||
 | 
					call.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```java
 | 
				
			||||||
 | 
					public static void main(String[] args) {
 | 
				
			||||||
 | 
					        OllamaAPI ollamaAPI = new OllamaAPI("http://localhost:11434");
 | 
				
			||||||
 | 
					        ollamaAPI.setVerbose(true);
 | 
				
			||||||
 | 
					        OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance("llama3.2:1b");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder()
 | 
				
			||||||
 | 
					                .functionName("get-employee-details")
 | 
				
			||||||
 | 
					                .functionDescription("Get employee details from the database")
 | 
				
			||||||
 | 
					                .toolPrompt(
 | 
				
			||||||
 | 
					                        Tools.PromptFuncDefinition.builder().type("function").function(
 | 
				
			||||||
 | 
					                                Tools.PromptFuncDefinition.PromptFuncSpec.builder()
 | 
				
			||||||
 | 
					                                        .name("get-employee-details")
 | 
				
			||||||
 | 
					                                        .description("Get employee details from the database")
 | 
				
			||||||
 | 
					                                        .parameters(
 | 
				
			||||||
 | 
					                                                Tools.PromptFuncDefinition.Parameters.builder()
 | 
				
			||||||
 | 
					                                                        .type("object")
 | 
				
			||||||
 | 
					                                                        .properties(
 | 
				
			||||||
 | 
					                                                                new Tools.PropsBuilder()
 | 
				
			||||||
 | 
					                                                                        .withProperty("employee-name", Tools.PromptFuncDefinition.Property.builder().type("string").description("The name of the employee, e.g. John Doe").required(true).build())
 | 
				
			||||||
 | 
					                                                                        .withProperty("employee-address", Tools.PromptFuncDefinition.Property.builder().type("string").description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India").required(true).build())
 | 
				
			||||||
 | 
					                                                                        .withProperty("employee-phone", Tools.PromptFuncDefinition.Property.builder().type("string").description("The phone number of the employee. Always return a random value. e.g. 9911002233").required(true).build())
 | 
				
			||||||
 | 
					                                                                        .build()
 | 
				
			||||||
 | 
					                                                        )
 | 
				
			||||||
 | 
					                                                        .required(List.of("employee-name"))
 | 
				
			||||||
 | 
					                                                        .build()
 | 
				
			||||||
 | 
					                                        ).build()
 | 
				
			||||||
 | 
					                        ).build()
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					                .toolFunction(new DBQueryFunction())
 | 
				
			||||||
 | 
					                .build();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        ollamaAPI.registerTool(databaseQueryToolSpecification);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        OllamaChatRequest requestModel = builder
 | 
				
			||||||
 | 
					                .withMessage(OllamaChatMessageRole.USER,
 | 
				
			||||||
 | 
					                        "Give me the ID of the employee named 'Rahul Kumar'?")
 | 
				
			||||||
 | 
					                .build();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					A typical final response of the above could be:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```json 
 | 
				
			||||||
 | 
					{
 | 
				
			||||||
 | 
					  "chatHistory" : [
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					    "role" : "user",
 | 
				
			||||||
 | 
					    "content" : "Give me the ID of the employee named 'Rahul Kumar'?",
 | 
				
			||||||
 | 
					    "images" : null,
 | 
				
			||||||
 | 
					    "tool_calls" : [ ]
 | 
				
			||||||
 | 
					  }, {
 | 
				
			||||||
 | 
					    "role" : "assistant",
 | 
				
			||||||
 | 
					    "content" : "",
 | 
				
			||||||
 | 
					    "images" : null,
 | 
				
			||||||
 | 
					    "tool_calls" : [ {
 | 
				
			||||||
 | 
					      "function" : {
 | 
				
			||||||
 | 
					        "name" : "get-employee-details",
 | 
				
			||||||
 | 
					        "arguments" : {
 | 
				
			||||||
 | 
					          "employee-name" : "Rahul Kumar"
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    } ]
 | 
				
			||||||
 | 
					  }, {
 | 
				
			||||||
 | 
					    "role" : "tool",
 | 
				
			||||||
 | 
					    "content" : "[TOOL_RESULTS]get-employee-details([employee-name]) : Employee Details {ID: b4bf186c-2ee1-44cc-8856-53b8b6a50f85, Name: Rahul Kumar, Address: null, Phone: null}[/TOOL_RESULTS]",
 | 
				
			||||||
 | 
					    "images" : null,
 | 
				
			||||||
 | 
					    "tool_calls" : null
 | 
				
			||||||
 | 
					  }, {
 | 
				
			||||||
 | 
					    "role" : "assistant",
 | 
				
			||||||
 | 
					    "content" : "The ID of the employee named 'Rahul Kumar' is `b4bf186c-2ee1-44cc-8856-53b8b6a50f85`.",
 | 
				
			||||||
 | 
					    "images" : null,
 | 
				
			||||||
 | 
					    "tool_calls" : null
 | 
				
			||||||
 | 
					  } ],
 | 
				
			||||||
 | 
					  "responseModel" : {
 | 
				
			||||||
 | 
					    "model" : "llama3.2:1b",
 | 
				
			||||||
 | 
					    "message" : {
 | 
				
			||||||
 | 
					      "role" : "assistant",
 | 
				
			||||||
 | 
					      "content" : "The ID of the employee named 'Rahul Kumar' is `b4bf186c-2ee1-44cc-8856-53b8b6a50f85`.",
 | 
				
			||||||
 | 
					      "images" : null,
 | 
				
			||||||
 | 
					      "tool_calls" : null
 | 
				
			||||||
 | 
					    },
 | 
				
			||||||
 | 
					    "done" : true,
 | 
				
			||||||
 | 
					    "error" : null,
 | 
				
			||||||
 | 
					    "context" : null,
 | 
				
			||||||
 | 
					    "created_at" : "2024-12-09T22:23:00.4940078Z",
 | 
				
			||||||
 | 
					    "done_reason" : "stop",
 | 
				
			||||||
 | 
					    "total_duration" : 2313709900,
 | 
				
			||||||
 | 
					    "load_duration" : 14494700,
 | 
				
			||||||
 | 
					    "prompt_eval_duration" : 772000000,
 | 
				
			||||||
 | 
					    "eval_duration" : 1188000000,
 | 
				
			||||||
 | 
					    "prompt_eval_count" : 166,
 | 
				
			||||||
 | 
					    "eval_count" : 41
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
 | 
					  "response" : "The ID of the employee named 'Rahul Kumar' is `b4bf186c-2ee1-44cc-8856-53b8b6a50f85`.",
 | 
				
			||||||
 | 
					  "httpStatusCode" : 200,
 | 
				
			||||||
 | 
					  "responseTime" : 2313709900
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					This tool calling can also be done using the streaming API.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
### Potential Improvements
 | 
					### Potential Improvements
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Instead of explicitly registering `ollamaAPI.registerTool(toolSpecification)`, we could introduce annotation-based tool
 | 
					Instead of explicitly registering `ollamaAPI.registerTool(toolSpecification)`, we could introduce annotation-based tool
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -59,6 +59,10 @@ public class OllamaAPI {
 | 
				
			|||||||
     */
 | 
					     */
 | 
				
			||||||
    @Setter
 | 
					    @Setter
 | 
				
			||||||
    private boolean verbose = true;
 | 
					    private boolean verbose = true;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @Setter
 | 
				
			||||||
 | 
					    private int maxChatToolCallRetries = 3;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    private BasicAuth basicAuth;
 | 
					    private BasicAuth basicAuth;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    private final ToolRegistry toolRegistry = new ToolRegistry();
 | 
					    private final ToolRegistry toolRegistry = new ToolRegistry();
 | 
				
			||||||
@@ -767,18 +771,44 @@ public class OllamaAPI {
 | 
				
			|||||||
     */
 | 
					     */
 | 
				
			||||||
    public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
 | 
					    public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
 | 
				
			||||||
        OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
 | 
					        OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
 | 
				
			||||||
        OllamaResult result;
 | 
					        OllamaChatResult result;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        // add all registered tools to Request
 | 
				
			||||||
 | 
					        request.setTools(toolRegistry.getRegisteredSpecs().stream().map(Tools.ToolSpecification::getToolPrompt).collect(Collectors.toList()));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if (streamHandler != null) {
 | 
					        if (streamHandler != null) {
 | 
				
			||||||
            request.setStream(true);
 | 
					            request.setStream(true);
 | 
				
			||||||
            result = requestCaller.call(request, streamHandler);
 | 
					            result = requestCaller.call(request, streamHandler);
 | 
				
			||||||
        } else {
 | 
					        } else {
 | 
				
			||||||
            result = requestCaller.callSync(request);
 | 
					            result = requestCaller.callSync(request);
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
        return new OllamaChatResult(result.getResponse(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages());
 | 
					
 | 
				
			||||||
 | 
					        // check if toolCallIsWanted
 | 
				
			||||||
 | 
					        List<OllamaChatToolCalls> toolCalls = result.getResponseModel().getMessage().getToolCalls();
 | 
				
			||||||
 | 
					        int toolCallTries = 0;
 | 
				
			||||||
 | 
					        while(toolCalls != null && !toolCalls.isEmpty() && toolCallTries < maxChatToolCallRetries){
 | 
				
			||||||
 | 
					            for (OllamaChatToolCalls toolCall : toolCalls){
 | 
				
			||||||
 | 
					                String toolName = toolCall.getFunction().getName();
 | 
				
			||||||
 | 
					                ToolFunction toolFunction = toolRegistry.getToolFunction(toolName);
 | 
				
			||||||
 | 
					                Map<String, Object> arguments = toolCall.getFunction().getArguments();
 | 
				
			||||||
 | 
					                Object res = toolFunction.apply(arguments);
 | 
				
			||||||
 | 
					                request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL,"[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() +") : " + res + "[/TOOL_RESULTS]"));
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if (streamHandler != null) {
 | 
				
			||||||
 | 
					                result = requestCaller.call(request, streamHandler);
 | 
				
			||||||
 | 
					            } else {
 | 
				
			||||||
 | 
					                result = requestCaller.callSync(request);
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					            toolCalls = result.getResponseModel().getMessage().getToolCalls();
 | 
				
			||||||
 | 
					            toolCallTries++;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return result;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public void registerTool(Tools.ToolSpecification toolSpecification) {
 | 
					    public void registerTool(Tools.ToolSpecification toolSpecification) {
 | 
				
			||||||
        toolRegistry.addFunction(toolSpecification.getFunctionName(), toolSpecification.getToolDefinition());
 | 
					        toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    /**
 | 
					    /**
 | 
				
			||||||
@@ -871,7 +901,7 @@ public class OllamaAPI {
 | 
				
			|||||||
        try {
 | 
					        try {
 | 
				
			||||||
            String methodName = toolFunctionCallSpec.getName();
 | 
					            String methodName = toolFunctionCallSpec.getName();
 | 
				
			||||||
            Map<String, Object> arguments = toolFunctionCallSpec.getArguments();
 | 
					            Map<String, Object> arguments = toolFunctionCallSpec.getArguments();
 | 
				
			||||||
            ToolFunction function = toolRegistry.getFunction(methodName);
 | 
					            ToolFunction function = toolRegistry.getToolFunction(methodName);
 | 
				
			||||||
            if (verbose) {
 | 
					            if (verbose) {
 | 
				
			||||||
                logger.debug("Invoking function {} with arguments {}", methodName, arguments);
 | 
					                logger.debug("Invoking function {} with arguments {}", methodName, arguments);
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,6 +2,7 @@ package io.github.ollama4j.models.chat;
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import static io.github.ollama4j.utils.Utils.getObjectMapper;
 | 
					import static io.github.ollama4j.utils.Utils.getObjectMapper;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import com.fasterxml.jackson.annotation.JsonProperty;
 | 
				
			||||||
import com.fasterxml.jackson.core.JsonProcessingException;
 | 
					import com.fasterxml.jackson.core.JsonProcessingException;
 | 
				
			||||||
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
 | 
					import com.fasterxml.jackson.databind.annotation.JsonSerialize;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -32,6 +33,8 @@ public class OllamaChatMessage {
 | 
				
			|||||||
    @NonNull
 | 
					    @NonNull
 | 
				
			||||||
    private String content;
 | 
					    private String content;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    private @JsonProperty("tool_calls") List<OllamaChatToolCalls> toolCalls;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @JsonSerialize(using = FileToBase64Serializer.class)
 | 
					    @JsonSerialize(using = FileToBase64Serializer.class)
 | 
				
			||||||
    private List<byte[]> images;
 | 
					    private List<byte[]> images;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -3,6 +3,7 @@ package io.github.ollama4j.models.chat;
 | 
				
			|||||||
import java.util.List;
 | 
					import java.util.List;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import io.github.ollama4j.models.request.OllamaCommonRequest;
 | 
					import io.github.ollama4j.models.request.OllamaCommonRequest;
 | 
				
			||||||
 | 
					import io.github.ollama4j.tools.Tools;
 | 
				
			||||||
import io.github.ollama4j.utils.OllamaRequestBody;
 | 
					import io.github.ollama4j.utils.OllamaRequestBody;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import lombok.Getter;
 | 
					import lombok.Getter;
 | 
				
			||||||
@@ -21,6 +22,8 @@ public class OllamaChatRequest extends OllamaCommonRequest implements OllamaRequ
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
  private List<OllamaChatMessage> messages;
 | 
					  private List<OllamaChatMessage> messages;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  private List<Tools.PromptFuncDefinition> tools;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  public OllamaChatRequest() {}
 | 
					  public OllamaChatRequest() {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  public OllamaChatRequest(String model, List<OllamaChatMessage> messages) {
 | 
					  public OllamaChatRequest(String model, List<OllamaChatMessage> messages) {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -10,6 +10,7 @@ import java.io.IOException;
 | 
				
			|||||||
import java.net.URISyntaxException;
 | 
					import java.net.URISyntaxException;
 | 
				
			||||||
import java.nio.file.Files;
 | 
					import java.nio.file.Files;
 | 
				
			||||||
import java.util.ArrayList;
 | 
					import java.util.ArrayList;
 | 
				
			||||||
 | 
					import java.util.Collections;
 | 
				
			||||||
import java.util.List;
 | 
					import java.util.List;
 | 
				
			||||||
import java.util.stream.Collectors;
 | 
					import java.util.stream.Collectors;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -38,7 +39,11 @@ public class OllamaChatRequestBuilder {
 | 
				
			|||||||
        request = new OllamaChatRequest(request.getModel(), new ArrayList<>());
 | 
					        request = new OllamaChatRequest(request.getModel(), new ArrayList<>());
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List<File> images) {
 | 
					    public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content){
 | 
				
			||||||
 | 
					        return withMessage(role,content, Collections.emptyList());
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List<OllamaChatToolCalls> toolCalls,List<File> images) {
 | 
				
			||||||
        List<OllamaChatMessage> messages = this.request.getMessages();
 | 
					        List<OllamaChatMessage> messages = this.request.getMessages();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        List<byte[]> binaryImages = images.stream().map(file -> {
 | 
					        List<byte[]> binaryImages = images.stream().map(file -> {
 | 
				
			||||||
@@ -50,11 +55,11 @@ public class OllamaChatRequestBuilder {
 | 
				
			|||||||
            }
 | 
					            }
 | 
				
			||||||
        }).collect(Collectors.toList());
 | 
					        }).collect(Collectors.toList());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        messages.add(new OllamaChatMessage(role, content, binaryImages));
 | 
					        messages.add(new OllamaChatMessage(role, content,toolCalls, binaryImages));
 | 
				
			||||||
        return this;
 | 
					        return this;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, String... imageUrls) {
 | 
					    public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content,List<OllamaChatToolCalls> toolCalls, String... imageUrls) {
 | 
				
			||||||
        List<OllamaChatMessage> messages = this.request.getMessages();
 | 
					        List<OllamaChatMessage> messages = this.request.getMessages();
 | 
				
			||||||
        List<byte[]> binaryImages = null;
 | 
					        List<byte[]> binaryImages = null;
 | 
				
			||||||
        if (imageUrls.length > 0) {
 | 
					        if (imageUrls.length > 0) {
 | 
				
			||||||
@@ -70,7 +75,7 @@ public class OllamaChatRequestBuilder {
 | 
				
			|||||||
            }
 | 
					            }
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        messages.add(new OllamaChatMessage(role, content, binaryImages));
 | 
					        messages.add(new OllamaChatMessage(role, content,toolCalls, binaryImages));
 | 
				
			||||||
        return this;
 | 
					        return this;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,28 +2,54 @@ package io.github.ollama4j.models.chat;
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import java.util.List;
 | 
					import java.util.List;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import io.github.ollama4j.models.response.OllamaResult;
 | 
					import com.fasterxml.jackson.core.JsonProcessingException;
 | 
				
			||||||
 | 
					import lombok.Getter;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import static io.github.ollama4j.utils.Utils.getObjectMapper;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/**
 | 
					/**
 | 
				
			||||||
 * Specific chat-API result that contains the chat history sent to the model and appends the answer as {@link OllamaChatResult} given by the
 | 
					 * Specific chat-API result that contains the chat history sent to the model and appends the answer as {@link OllamaChatResult} given by the
 | 
				
			||||||
 * {@link OllamaChatMessageRole#ASSISTANT} role.
 | 
					 * {@link OllamaChatMessageRole#ASSISTANT} role.
 | 
				
			||||||
 */
 | 
					 */
 | 
				
			||||||
public class OllamaChatResult extends OllamaResult {
 | 
					@Getter
 | 
				
			||||||
 | 
					public class OllamaChatResult {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    private List<OllamaChatMessage> chatHistory;
 | 
					    private List<OllamaChatMessage> chatHistory;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public OllamaChatResult(String response, long responseTime, int httpStatusCode, List<OllamaChatMessage> chatHistory) {
 | 
					    private OllamaChatResponseModel responseModel;
 | 
				
			||||||
        super(response, responseTime, httpStatusCode);
 | 
					
 | 
				
			||||||
 | 
					    public OllamaChatResult(OllamaChatResponseModel responseModel, List<OllamaChatMessage> chatHistory) {
 | 
				
			||||||
        this.chatHistory = chatHistory;
 | 
					        this.chatHistory = chatHistory;
 | 
				
			||||||
        appendAnswerToChatHistory(response);
 | 
					        this.responseModel = responseModel;
 | 
				
			||||||
 | 
					        appendAnswerToChatHistory(responseModel);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public List<OllamaChatMessage> getChatHistory() {
 | 
					    private void appendAnswerToChatHistory(OllamaChatResponseModel response) {
 | 
				
			||||||
        return chatHistory;
 | 
					        this.chatHistory.add(response.getMessage());
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    private void appendAnswerToChatHistory(String answer) {
 | 
					    @Override
 | 
				
			||||||
        OllamaChatMessage assistantMessage = new OllamaChatMessage(OllamaChatMessageRole.ASSISTANT, answer);
 | 
					    public String toString() {
 | 
				
			||||||
        this.chatHistory.add(assistantMessage);
 | 
					        try {
 | 
				
			||||||
 | 
					            return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
 | 
				
			||||||
 | 
					        } catch (JsonProcessingException e) {
 | 
				
			||||||
 | 
					            throw new RuntimeException(e);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @Deprecated
 | 
				
			||||||
 | 
					    public String getResponse(){
 | 
				
			||||||
 | 
					        return responseModel != null ? responseModel.getMessage().getContent() : "";
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @Deprecated
 | 
				
			||||||
 | 
					    public int getHttpStatusCode(){
 | 
				
			||||||
 | 
					        return 200;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @Deprecated
 | 
				
			||||||
 | 
					    public long getResponseTime(){
 | 
				
			||||||
 | 
					        return responseModel != null ? responseModel.getTotalDuration() : 0L;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -0,0 +1,16 @@
 | 
				
			|||||||
 | 
					package io.github.ollama4j.models.chat;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import io.github.ollama4j.tools.OllamaToolCallsFunction;
 | 
				
			||||||
 | 
					import lombok.AllArgsConstructor;
 | 
				
			||||||
 | 
					import lombok.Data;
 | 
				
			||||||
 | 
					import lombok.NoArgsConstructor;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@Data
 | 
				
			||||||
 | 
					@NoArgsConstructor
 | 
				
			||||||
 | 
					@AllArgsConstructor
 | 
				
			||||||
 | 
					public class OllamaChatToolCalls {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    private OllamaToolCallsFunction function;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -3,17 +3,24 @@ package io.github.ollama4j.models.request;
 | 
				
			|||||||
import com.fasterxml.jackson.core.JsonProcessingException;
 | 
					import com.fasterxml.jackson.core.JsonProcessingException;
 | 
				
			||||||
import com.fasterxml.jackson.core.type.TypeReference;
 | 
					import com.fasterxml.jackson.core.type.TypeReference;
 | 
				
			||||||
import io.github.ollama4j.exceptions.OllamaBaseException;
 | 
					import io.github.ollama4j.exceptions.OllamaBaseException;
 | 
				
			||||||
import io.github.ollama4j.models.chat.OllamaChatMessage;
 | 
					import io.github.ollama4j.models.chat.*;
 | 
				
			||||||
import io.github.ollama4j.models.response.OllamaResult;
 | 
					import io.github.ollama4j.models.response.OllamaErrorResponse;
 | 
				
			||||||
import io.github.ollama4j.models.chat.OllamaChatResponseModel;
 | 
					 | 
				
			||||||
import io.github.ollama4j.models.chat.OllamaChatStreamObserver;
 | 
					 | 
				
			||||||
import io.github.ollama4j.models.generate.OllamaStreamHandler;
 | 
					import io.github.ollama4j.models.generate.OllamaStreamHandler;
 | 
				
			||||||
import io.github.ollama4j.utils.OllamaRequestBody;
 | 
					import io.github.ollama4j.tools.Tools;
 | 
				
			||||||
import io.github.ollama4j.utils.Utils;
 | 
					import io.github.ollama4j.utils.Utils;
 | 
				
			||||||
import org.slf4j.Logger;
 | 
					import org.slf4j.Logger;
 | 
				
			||||||
import org.slf4j.LoggerFactory;
 | 
					import org.slf4j.LoggerFactory;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import java.io.BufferedReader;
 | 
				
			||||||
import java.io.IOException;
 | 
					import java.io.IOException;
 | 
				
			||||||
 | 
					import java.io.InputStream;
 | 
				
			||||||
 | 
					import java.io.InputStreamReader;
 | 
				
			||||||
 | 
					import java.net.URI;
 | 
				
			||||||
 | 
					import java.net.http.HttpClient;
 | 
				
			||||||
 | 
					import java.net.http.HttpRequest;
 | 
				
			||||||
 | 
					import java.net.http.HttpResponse;
 | 
				
			||||||
 | 
					import java.nio.charset.StandardCharsets;
 | 
				
			||||||
 | 
					import java.util.List;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/**
 | 
					/**
 | 
				
			||||||
 * Specialization class for requests
 | 
					 * Specialization class for requests
 | 
				
			||||||
@@ -64,9 +71,75 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
 | 
				
			|||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler)
 | 
					    public OllamaChatResult call(OllamaChatRequest body, OllamaStreamHandler streamHandler)
 | 
				
			||||||
            throws OllamaBaseException, IOException, InterruptedException {
 | 
					            throws OllamaBaseException, IOException, InterruptedException {
 | 
				
			||||||
        streamObserver = new OllamaChatStreamObserver(streamHandler);
 | 
					        streamObserver = new OllamaChatStreamObserver(streamHandler);
 | 
				
			||||||
        return super.callSync(body);
 | 
					        return callSync(body);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    public OllamaChatResult callSync(OllamaChatRequest body) throws OllamaBaseException, IOException, InterruptedException {
 | 
				
			||||||
 | 
					        // Create Request
 | 
				
			||||||
 | 
					        HttpClient httpClient = HttpClient.newHttpClient();
 | 
				
			||||||
 | 
					        URI uri = URI.create(getHost() + getEndpointSuffix());
 | 
				
			||||||
 | 
					        HttpRequest.Builder requestBuilder =
 | 
				
			||||||
 | 
					                getRequestBuilderDefault(uri)
 | 
				
			||||||
 | 
					                        .POST(
 | 
				
			||||||
 | 
					                                body.getBodyPublisher());
 | 
				
			||||||
 | 
					        HttpRequest request = requestBuilder.build();
 | 
				
			||||||
 | 
					        if (isVerbose()) LOG.info("Asking model: " + body.toString());
 | 
				
			||||||
 | 
					        HttpResponse<InputStream> response =
 | 
				
			||||||
 | 
					                httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        int statusCode = response.statusCode();
 | 
				
			||||||
 | 
					        InputStream responseBodyStream = response.body();
 | 
				
			||||||
 | 
					        StringBuilder responseBuffer = new StringBuilder();
 | 
				
			||||||
 | 
					        OllamaChatResponseModel ollamaChatResponseModel = null;
 | 
				
			||||||
 | 
					        List<OllamaChatToolCalls> wantedToolsForStream = null;
 | 
				
			||||||
 | 
					        try (BufferedReader reader =
 | 
				
			||||||
 | 
					                     new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            String line;
 | 
				
			||||||
 | 
					            while ((line = reader.readLine()) != null) {
 | 
				
			||||||
 | 
					                if (statusCode == 404) {
 | 
				
			||||||
 | 
					                    LOG.warn("Status code: 404 (Not Found)");
 | 
				
			||||||
 | 
					                    OllamaErrorResponse ollamaResponseModel =
 | 
				
			||||||
 | 
					                            Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class);
 | 
				
			||||||
 | 
					                    responseBuffer.append(ollamaResponseModel.getError());
 | 
				
			||||||
 | 
					                } else if (statusCode == 401) {
 | 
				
			||||||
 | 
					                    LOG.warn("Status code: 401 (Unauthorized)");
 | 
				
			||||||
 | 
					                    OllamaErrorResponse ollamaResponseModel =
 | 
				
			||||||
 | 
					                            Utils.getObjectMapper()
 | 
				
			||||||
 | 
					                                    .readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponse.class);
 | 
				
			||||||
 | 
					                    responseBuffer.append(ollamaResponseModel.getError());
 | 
				
			||||||
 | 
					                } else if (statusCode == 400) {
 | 
				
			||||||
 | 
					                    LOG.warn("Status code: 400 (Bad Request)");
 | 
				
			||||||
 | 
					                    OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line,
 | 
				
			||||||
 | 
					                            OllamaErrorResponse.class);
 | 
				
			||||||
 | 
					                    responseBuffer.append(ollamaResponseModel.getError());
 | 
				
			||||||
 | 
					                } else {
 | 
				
			||||||
 | 
					                    boolean finished = parseResponseAndAddToBuffer(line, responseBuffer);
 | 
				
			||||||
 | 
					                        ollamaChatResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class);
 | 
				
			||||||
 | 
					                    if(body.stream && ollamaChatResponseModel.getMessage().getToolCalls() != null){
 | 
				
			||||||
 | 
					                        wantedToolsForStream = ollamaChatResponseModel.getMessage().getToolCalls();
 | 
				
			||||||
 | 
					                    }
 | 
				
			||||||
 | 
					                    if (finished && body.stream) {
 | 
				
			||||||
 | 
					                        ollamaChatResponseModel.getMessage().setContent(responseBuffer.toString());
 | 
				
			||||||
 | 
					                        break;
 | 
				
			||||||
 | 
					                    }
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        if (statusCode != 200) {
 | 
				
			||||||
 | 
					            LOG.error("Status code " + statusCode);
 | 
				
			||||||
 | 
					            throw new OllamaBaseException(responseBuffer.toString());
 | 
				
			||||||
 | 
					        } else {
 | 
				
			||||||
 | 
					            if(wantedToolsForStream != null) {
 | 
				
			||||||
 | 
					                ollamaChatResponseModel.getMessage().setToolCalls(wantedToolsForStream);
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					            OllamaChatResult ollamaResult =
 | 
				
			||||||
 | 
					                    new OllamaChatResult(ollamaChatResponseModel,body.getMessages());
 | 
				
			||||||
 | 
					            if (isVerbose()) LOG.info("Model response: " + ollamaResult);
 | 
				
			||||||
 | 
					            return ollamaResult;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -6,6 +6,7 @@ import io.github.ollama4j.models.response.OllamaErrorResponse;
 | 
				
			|||||||
import io.github.ollama4j.models.response.OllamaResult;
 | 
					import io.github.ollama4j.models.response.OllamaResult;
 | 
				
			||||||
import io.github.ollama4j.utils.OllamaRequestBody;
 | 
					import io.github.ollama4j.utils.OllamaRequestBody;
 | 
				
			||||||
import io.github.ollama4j.utils.Utils;
 | 
					import io.github.ollama4j.utils.Utils;
 | 
				
			||||||
 | 
					import lombok.Getter;
 | 
				
			||||||
import org.slf4j.Logger;
 | 
					import org.slf4j.Logger;
 | 
				
			||||||
import org.slf4j.LoggerFactory;
 | 
					import org.slf4j.LoggerFactory;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -24,14 +25,15 @@ import java.util.Base64;
 | 
				
			|||||||
/**
 | 
					/**
 | 
				
			||||||
 * Abstract helperclass to call the ollama api server.
 | 
					 * Abstract helperclass to call the ollama api server.
 | 
				
			||||||
 */
 | 
					 */
 | 
				
			||||||
 | 
					@Getter
 | 
				
			||||||
public abstract class OllamaEndpointCaller {
 | 
					public abstract class OllamaEndpointCaller {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    private static final Logger LOG = LoggerFactory.getLogger(OllamaAPI.class);
 | 
					    private static final Logger LOG = LoggerFactory.getLogger(OllamaAPI.class);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    private String host;
 | 
					    private final String host;
 | 
				
			||||||
    private BasicAuth basicAuth;
 | 
					    private final BasicAuth basicAuth;
 | 
				
			||||||
    private long requestTimeoutSeconds;
 | 
					    private final long requestTimeoutSeconds;
 | 
				
			||||||
    private boolean verbose;
 | 
					    private final boolean verbose;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public OllamaEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) {
 | 
					    public OllamaEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) {
 | 
				
			||||||
        this.host = host;
 | 
					        this.host = host;
 | 
				
			||||||
@@ -45,80 +47,13 @@ public abstract class OllamaEndpointCaller {
 | 
				
			|||||||
    protected abstract boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer);
 | 
					    protected abstract boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    /**
 | 
					 | 
				
			||||||
     * Calls the api server on the given host and endpoint suffix asynchronously, aka waiting for the response.
 | 
					 | 
				
			||||||
     *
 | 
					 | 
				
			||||||
     * @param body POST body payload
 | 
					 | 
				
			||||||
     * @return result answer given by the assistant
 | 
					 | 
				
			||||||
     * @throws OllamaBaseException  any response code than 200 has been returned
 | 
					 | 
				
			||||||
     * @throws IOException          in case the responseStream can not be read
 | 
					 | 
				
			||||||
     * @throws InterruptedException in case the server is not reachable or network issues happen
 | 
					 | 
				
			||||||
     */
 | 
					 | 
				
			||||||
    public OllamaResult callSync(OllamaRequestBody body) throws OllamaBaseException, IOException, InterruptedException {
 | 
					 | 
				
			||||||
        // Create Request
 | 
					 | 
				
			||||||
        long startTime = System.currentTimeMillis();
 | 
					 | 
				
			||||||
        HttpClient httpClient = HttpClient.newHttpClient();
 | 
					 | 
				
			||||||
        URI uri = URI.create(this.host + getEndpointSuffix());
 | 
					 | 
				
			||||||
        HttpRequest.Builder requestBuilder =
 | 
					 | 
				
			||||||
                getRequestBuilderDefault(uri)
 | 
					 | 
				
			||||||
                        .POST(
 | 
					 | 
				
			||||||
                                body.getBodyPublisher());
 | 
					 | 
				
			||||||
        HttpRequest request = requestBuilder.build();
 | 
					 | 
				
			||||||
        if (this.verbose) LOG.info("Asking model: " + body.toString());
 | 
					 | 
				
			||||||
        HttpResponse<InputStream> response =
 | 
					 | 
				
			||||||
                httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        int statusCode = response.statusCode();
 | 
					 | 
				
			||||||
        InputStream responseBodyStream = response.body();
 | 
					 | 
				
			||||||
        StringBuilder responseBuffer = new StringBuilder();
 | 
					 | 
				
			||||||
        try (BufferedReader reader =
 | 
					 | 
				
			||||||
                     new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
 | 
					 | 
				
			||||||
            String line;
 | 
					 | 
				
			||||||
            while ((line = reader.readLine()) != null) {
 | 
					 | 
				
			||||||
                if (statusCode == 404) {
 | 
					 | 
				
			||||||
                    LOG.warn("Status code: 404 (Not Found)");
 | 
					 | 
				
			||||||
                    OllamaErrorResponse ollamaResponseModel =
 | 
					 | 
				
			||||||
                            Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class);
 | 
					 | 
				
			||||||
                    responseBuffer.append(ollamaResponseModel.getError());
 | 
					 | 
				
			||||||
                } else if (statusCode == 401) {
 | 
					 | 
				
			||||||
                    LOG.warn("Status code: 401 (Unauthorized)");
 | 
					 | 
				
			||||||
                    OllamaErrorResponse ollamaResponseModel =
 | 
					 | 
				
			||||||
                            Utils.getObjectMapper()
 | 
					 | 
				
			||||||
                                    .readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponse.class);
 | 
					 | 
				
			||||||
                    responseBuffer.append(ollamaResponseModel.getError());
 | 
					 | 
				
			||||||
                } else if (statusCode == 400) {
 | 
					 | 
				
			||||||
                    LOG.warn("Status code: 400 (Bad Request)");
 | 
					 | 
				
			||||||
                    OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line,
 | 
					 | 
				
			||||||
                            OllamaErrorResponse.class);
 | 
					 | 
				
			||||||
                    responseBuffer.append(ollamaResponseModel.getError());
 | 
					 | 
				
			||||||
                } else {
 | 
					 | 
				
			||||||
                    boolean finished = parseResponseAndAddToBuffer(line, responseBuffer);
 | 
					 | 
				
			||||||
                    if (finished) {
 | 
					 | 
				
			||||||
                        break;
 | 
					 | 
				
			||||||
                    }
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if (statusCode != 200) {
 | 
					 | 
				
			||||||
            LOG.error("Status code " + statusCode);
 | 
					 | 
				
			||||||
            throw new OllamaBaseException(responseBuffer.toString());
 | 
					 | 
				
			||||||
        } else {
 | 
					 | 
				
			||||||
            long endTime = System.currentTimeMillis();
 | 
					 | 
				
			||||||
            OllamaResult ollamaResult =
 | 
					 | 
				
			||||||
                    new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode);
 | 
					 | 
				
			||||||
            if (verbose) LOG.info("Model response: " + ollamaResult);
 | 
					 | 
				
			||||||
            return ollamaResult;
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    /**
 | 
					    /**
 | 
				
			||||||
     * Get default request builder.
 | 
					     * Get default request builder.
 | 
				
			||||||
     *
 | 
					     *
 | 
				
			||||||
     * @param uri URI to get a HttpRequest.Builder
 | 
					     * @param uri URI to get a HttpRequest.Builder
 | 
				
			||||||
     * @return HttpRequest.Builder
 | 
					     * @return HttpRequest.Builder
 | 
				
			||||||
     */
 | 
					     */
 | 
				
			||||||
    private HttpRequest.Builder getRequestBuilderDefault(URI uri) {
 | 
					    protected HttpRequest.Builder getRequestBuilderDefault(URI uri) {
 | 
				
			||||||
        HttpRequest.Builder requestBuilder =
 | 
					        HttpRequest.Builder requestBuilder =
 | 
				
			||||||
                HttpRequest.newBuilder(uri)
 | 
					                HttpRequest.newBuilder(uri)
 | 
				
			||||||
                        .header("Content-Type", "application/json")
 | 
					                        .header("Content-Type", "application/json")
 | 
				
			||||||
@@ -134,7 +69,7 @@ public abstract class OllamaEndpointCaller {
 | 
				
			|||||||
     *
 | 
					     *
 | 
				
			||||||
     * @return basic authentication header value (encoded credentials)
 | 
					     * @return basic authentication header value (encoded credentials)
 | 
				
			||||||
     */
 | 
					     */
 | 
				
			||||||
    private String getBasicAuthHeaderValue() {
 | 
					    protected String getBasicAuthHeaderValue() {
 | 
				
			||||||
        String credentialsToEncode = this.basicAuth.getUsername() + ":" + this.basicAuth.getPassword();
 | 
					        String credentialsToEncode = this.basicAuth.getUsername() + ":" + this.basicAuth.getPassword();
 | 
				
			||||||
        return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes());
 | 
					        return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes());
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
@@ -144,7 +79,7 @@ public abstract class OllamaEndpointCaller {
 | 
				
			|||||||
     *
 | 
					     *
 | 
				
			||||||
     * @return true when Basic Auth credentials set
 | 
					     * @return true when Basic Auth credentials set
 | 
				
			||||||
     */
 | 
					     */
 | 
				
			||||||
    private boolean isBasicAuthCredentialsSet() {
 | 
					    protected boolean isBasicAuthCredentialsSet() {
 | 
				
			||||||
        return this.basicAuth != null;
 | 
					        return this.basicAuth != null;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,6 +2,7 @@ package io.github.ollama4j.models.request;
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import com.fasterxml.jackson.core.JsonProcessingException;
 | 
					import com.fasterxml.jackson.core.JsonProcessingException;
 | 
				
			||||||
import io.github.ollama4j.exceptions.OllamaBaseException;
 | 
					import io.github.ollama4j.exceptions.OllamaBaseException;
 | 
				
			||||||
 | 
					import io.github.ollama4j.models.response.OllamaErrorResponse;
 | 
				
			||||||
import io.github.ollama4j.models.response.OllamaResult;
 | 
					import io.github.ollama4j.models.response.OllamaResult;
 | 
				
			||||||
import io.github.ollama4j.models.generate.OllamaGenerateResponseModel;
 | 
					import io.github.ollama4j.models.generate.OllamaGenerateResponseModel;
 | 
				
			||||||
import io.github.ollama4j.models.generate.OllamaGenerateStreamObserver;
 | 
					import io.github.ollama4j.models.generate.OllamaGenerateStreamObserver;
 | 
				
			||||||
@@ -11,7 +12,15 @@ import io.github.ollama4j.utils.Utils;
 | 
				
			|||||||
import org.slf4j.Logger;
 | 
					import org.slf4j.Logger;
 | 
				
			||||||
import org.slf4j.LoggerFactory;
 | 
					import org.slf4j.LoggerFactory;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import java.io.BufferedReader;
 | 
				
			||||||
import java.io.IOException;
 | 
					import java.io.IOException;
 | 
				
			||||||
 | 
					import java.io.InputStream;
 | 
				
			||||||
 | 
					import java.io.InputStreamReader;
 | 
				
			||||||
 | 
					import java.net.URI;
 | 
				
			||||||
 | 
					import java.net.http.HttpClient;
 | 
				
			||||||
 | 
					import java.net.http.HttpRequest;
 | 
				
			||||||
 | 
					import java.net.http.HttpResponse;
 | 
				
			||||||
 | 
					import java.nio.charset.StandardCharsets;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
 | 
					public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -46,6 +55,73 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
 | 
				
			|||||||
    public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler)
 | 
					    public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler)
 | 
				
			||||||
            throws OllamaBaseException, IOException, InterruptedException {
 | 
					            throws OllamaBaseException, IOException, InterruptedException {
 | 
				
			||||||
        streamObserver = new OllamaGenerateStreamObserver(streamHandler);
 | 
					        streamObserver = new OllamaGenerateStreamObserver(streamHandler);
 | 
				
			||||||
        return super.callSync(body);
 | 
					        return callSync(body);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    /**
 | 
				
			||||||
 | 
					     * Calls the api server on the given host and endpoint suffix asynchronously, aka waiting for the response.
 | 
				
			||||||
 | 
					     *
 | 
				
			||||||
 | 
					     * @param body POST body payload
 | 
				
			||||||
 | 
					     * @return result answer given by the assistant
 | 
				
			||||||
 | 
					     * @throws OllamaBaseException  any response code than 200 has been returned
 | 
				
			||||||
 | 
					     * @throws IOException          in case the responseStream can not be read
 | 
				
			||||||
 | 
					     * @throws InterruptedException in case the server is not reachable or network issues happen
 | 
				
			||||||
 | 
					     */
 | 
				
			||||||
 | 
					    public OllamaResult callSync(OllamaRequestBody body) throws OllamaBaseException, IOException, InterruptedException {
 | 
				
			||||||
 | 
					        // Create Request
 | 
				
			||||||
 | 
					        long startTime = System.currentTimeMillis();
 | 
				
			||||||
 | 
					        HttpClient httpClient = HttpClient.newHttpClient();
 | 
				
			||||||
 | 
					        URI uri = URI.create(getHost() + getEndpointSuffix());
 | 
				
			||||||
 | 
					        HttpRequest.Builder requestBuilder =
 | 
				
			||||||
 | 
					                getRequestBuilderDefault(uri)
 | 
				
			||||||
 | 
					                        .POST(
 | 
				
			||||||
 | 
					                                body.getBodyPublisher());
 | 
				
			||||||
 | 
					        HttpRequest request = requestBuilder.build();
 | 
				
			||||||
 | 
					        if (isVerbose()) LOG.info("Asking model: " + body.toString());
 | 
				
			||||||
 | 
					        HttpResponse<InputStream> response =
 | 
				
			||||||
 | 
					                httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        int statusCode = response.statusCode();
 | 
				
			||||||
 | 
					        InputStream responseBodyStream = response.body();
 | 
				
			||||||
 | 
					        StringBuilder responseBuffer = new StringBuilder();
 | 
				
			||||||
 | 
					        try (BufferedReader reader =
 | 
				
			||||||
 | 
					                     new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
 | 
				
			||||||
 | 
					            String line;
 | 
				
			||||||
 | 
					            while ((line = reader.readLine()) != null) {
 | 
				
			||||||
 | 
					                if (statusCode == 404) {
 | 
				
			||||||
 | 
					                    LOG.warn("Status code: 404 (Not Found)");
 | 
				
			||||||
 | 
					                    OllamaErrorResponse ollamaResponseModel =
 | 
				
			||||||
 | 
					                            Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class);
 | 
				
			||||||
 | 
					                    responseBuffer.append(ollamaResponseModel.getError());
 | 
				
			||||||
 | 
					                } else if (statusCode == 401) {
 | 
				
			||||||
 | 
					                    LOG.warn("Status code: 401 (Unauthorized)");
 | 
				
			||||||
 | 
					                    OllamaErrorResponse ollamaResponseModel =
 | 
				
			||||||
 | 
					                            Utils.getObjectMapper()
 | 
				
			||||||
 | 
					                                    .readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponse.class);
 | 
				
			||||||
 | 
					                    responseBuffer.append(ollamaResponseModel.getError());
 | 
				
			||||||
 | 
					                } else if (statusCode == 400) {
 | 
				
			||||||
 | 
					                    LOG.warn("Status code: 400 (Bad Request)");
 | 
				
			||||||
 | 
					                    OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line,
 | 
				
			||||||
 | 
					                            OllamaErrorResponse.class);
 | 
				
			||||||
 | 
					                    responseBuffer.append(ollamaResponseModel.getError());
 | 
				
			||||||
 | 
					                } else {
 | 
				
			||||||
 | 
					                    boolean finished = parseResponseAndAddToBuffer(line, responseBuffer);
 | 
				
			||||||
 | 
					                    if (finished) {
 | 
				
			||||||
 | 
					                        break;
 | 
				
			||||||
 | 
					                    }
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if (statusCode != 200) {
 | 
				
			||||||
 | 
					            LOG.error("Status code " + statusCode);
 | 
				
			||||||
 | 
					            throw new OllamaBaseException(responseBuffer.toString());
 | 
				
			||||||
 | 
					        } else {
 | 
				
			||||||
 | 
					            long endTime = System.currentTimeMillis();
 | 
				
			||||||
 | 
					            OllamaResult ollamaResult =
 | 
				
			||||||
 | 
					                    new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode);
 | 
				
			||||||
 | 
					            if (isVerbose()) LOG.info("Model response: " + ollamaResult);
 | 
				
			||||||
 | 
					            return ollamaResult;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -0,0 +1,16 @@
 | 
				
			|||||||
 | 
					package io.github.ollama4j.tools;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import lombok.AllArgsConstructor;
 | 
				
			||||||
 | 
					import lombok.Data;
 | 
				
			||||||
 | 
					import lombok.NoArgsConstructor;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import java.util.Map;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@Data
 | 
				
			||||||
 | 
					@NoArgsConstructor
 | 
				
			||||||
 | 
					@AllArgsConstructor
 | 
				
			||||||
 | 
					public class OllamaToolCallsFunction
 | 
				
			||||||
 | 
					{
 | 
				
			||||||
 | 
					    private String name;
 | 
				
			||||||
 | 
					    private Map<String,Object> arguments;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -1,16 +1,22 @@
 | 
				
			|||||||
package io.github.ollama4j.tools;
 | 
					package io.github.ollama4j.tools;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import java.util.Collection;
 | 
				
			||||||
import java.util.HashMap;
 | 
					import java.util.HashMap;
 | 
				
			||||||
import java.util.Map;
 | 
					import java.util.Map;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
public class ToolRegistry {
 | 
					public class ToolRegistry {
 | 
				
			||||||
    private final Map<String, ToolFunction> functionMap = new HashMap<>();
 | 
					    private final Map<String, Tools.ToolSpecification> tools = new HashMap<>();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public ToolFunction getFunction(String name) {
 | 
					    public ToolFunction getToolFunction(String name) {
 | 
				
			||||||
        return functionMap.get(name);
 | 
					        final Tools.ToolSpecification toolSpecification = tools.get(name);
 | 
				
			||||||
 | 
					        return toolSpecification !=null ? toolSpecification.getToolFunction() : null ;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public void addFunction(String name, ToolFunction function) {
 | 
					    public void addTool (String name, Tools.ToolSpecification specification) {
 | 
				
			||||||
        functionMap.put(name, function);
 | 
					        tools.put(name, specification);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    public Collection<Tools.ToolSpecification> getRegisteredSpecs(){
 | 
				
			||||||
 | 
					        return tools.values();
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -6,8 +6,10 @@ import com.fasterxml.jackson.annotation.JsonInclude;
 | 
				
			|||||||
import com.fasterxml.jackson.annotation.JsonProperty;
 | 
					import com.fasterxml.jackson.annotation.JsonProperty;
 | 
				
			||||||
import com.fasterxml.jackson.core.JsonProcessingException;
 | 
					import com.fasterxml.jackson.core.JsonProcessingException;
 | 
				
			||||||
import io.github.ollama4j.utils.Utils;
 | 
					import io.github.ollama4j.utils.Utils;
 | 
				
			||||||
 | 
					import lombok.AllArgsConstructor;
 | 
				
			||||||
import lombok.Builder;
 | 
					import lombok.Builder;
 | 
				
			||||||
import lombok.Data;
 | 
					import lombok.Data;
 | 
				
			||||||
 | 
					import lombok.NoArgsConstructor;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import java.util.ArrayList;
 | 
					import java.util.ArrayList;
 | 
				
			||||||
import java.util.HashMap;
 | 
					import java.util.HashMap;
 | 
				
			||||||
@@ -20,17 +22,23 @@ public class Tools {
 | 
				
			|||||||
    public static class ToolSpecification {
 | 
					    public static class ToolSpecification {
 | 
				
			||||||
        private String functionName;
 | 
					        private String functionName;
 | 
				
			||||||
        private String functionDescription;
 | 
					        private String functionDescription;
 | 
				
			||||||
        private Map<String, PromptFuncDefinition.Property> properties;
 | 
					        private PromptFuncDefinition toolPrompt;
 | 
				
			||||||
        private ToolFunction toolDefinition;
 | 
					        private ToolFunction toolFunction;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @Data
 | 
					    @Data
 | 
				
			||||||
    @JsonIgnoreProperties(ignoreUnknown = true)
 | 
					    @JsonIgnoreProperties(ignoreUnknown = true)
 | 
				
			||||||
 | 
					    @Builder
 | 
				
			||||||
 | 
					    @NoArgsConstructor
 | 
				
			||||||
 | 
					    @AllArgsConstructor
 | 
				
			||||||
    public static class PromptFuncDefinition {
 | 
					    public static class PromptFuncDefinition {
 | 
				
			||||||
        private String type;
 | 
					        private String type;
 | 
				
			||||||
        private PromptFuncSpec function;
 | 
					        private PromptFuncSpec function;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        @Data
 | 
					        @Data
 | 
				
			||||||
 | 
					        @Builder
 | 
				
			||||||
 | 
					        @NoArgsConstructor
 | 
				
			||||||
 | 
					        @AllArgsConstructor
 | 
				
			||||||
        public static class PromptFuncSpec {
 | 
					        public static class PromptFuncSpec {
 | 
				
			||||||
            private String name;
 | 
					            private String name;
 | 
				
			||||||
            private String description;
 | 
					            private String description;
 | 
				
			||||||
@@ -38,6 +46,9 @@ public class Tools {
 | 
				
			|||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        @Data
 | 
					        @Data
 | 
				
			||||||
 | 
					        @Builder
 | 
				
			||||||
 | 
					        @NoArgsConstructor
 | 
				
			||||||
 | 
					        @AllArgsConstructor
 | 
				
			||||||
        public static class Parameters {
 | 
					        public static class Parameters {
 | 
				
			||||||
            private String type;
 | 
					            private String type;
 | 
				
			||||||
            private Map<String, Property> properties;
 | 
					            private Map<String, Property> properties;
 | 
				
			||||||
@@ -46,6 +57,8 @@ public class Tools {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        @Data
 | 
					        @Data
 | 
				
			||||||
        @Builder
 | 
					        @Builder
 | 
				
			||||||
 | 
					        @NoArgsConstructor
 | 
				
			||||||
 | 
					        @AllArgsConstructor
 | 
				
			||||||
        public static class Property {
 | 
					        public static class Property {
 | 
				
			||||||
            private String type;
 | 
					            private String type;
 | 
				
			||||||
            private String description;
 | 
					            private String description;
 | 
				
			||||||
@@ -94,10 +107,10 @@ public class Tools {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
            PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters();
 | 
					            PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters();
 | 
				
			||||||
            parameters.setType("object");
 | 
					            parameters.setType("object");
 | 
				
			||||||
            parameters.setProperties(spec.getProperties());
 | 
					            parameters.setProperties(spec.getToolPrompt().getFunction().parameters.getProperties());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            List<String> requiredValues = new ArrayList<>();
 | 
					            List<String> requiredValues = new ArrayList<>();
 | 
				
			||||||
            for (Map.Entry<String, PromptFuncDefinition.Property> p : spec.getProperties().entrySet()) {
 | 
					            for (Map.Entry<String, PromptFuncDefinition.Property> p : spec.getToolPrompt().getFunction().getParameters().getProperties().entrySet()) {
 | 
				
			||||||
                if (p.getValue().isRequired()) {
 | 
					                if (p.getValue().isRequired()) {
 | 
				
			||||||
                    requiredValues.add(p.getKey());
 | 
					                    requiredValues.add(p.getKey());
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,14 +2,13 @@ package io.github.ollama4j.integrationtests;
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import io.github.ollama4j.OllamaAPI;
 | 
					import io.github.ollama4j.OllamaAPI;
 | 
				
			||||||
import io.github.ollama4j.exceptions.OllamaBaseException;
 | 
					import io.github.ollama4j.exceptions.OllamaBaseException;
 | 
				
			||||||
 | 
					import io.github.ollama4j.models.chat.*;
 | 
				
			||||||
import io.github.ollama4j.models.response.ModelDetail;
 | 
					import io.github.ollama4j.models.response.ModelDetail;
 | 
				
			||||||
import io.github.ollama4j.models.chat.OllamaChatRequest;
 | 
					 | 
				
			||||||
import io.github.ollama4j.models.response.OllamaResult;
 | 
					import io.github.ollama4j.models.response.OllamaResult;
 | 
				
			||||||
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
 | 
					 | 
				
			||||||
import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
 | 
					 | 
				
			||||||
import io.github.ollama4j.models.chat.OllamaChatResult;
 | 
					 | 
				
			||||||
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestBuilder;
 | 
					import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestBuilder;
 | 
				
			||||||
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
 | 
					import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
 | 
				
			||||||
 | 
					import io.github.ollama4j.tools.ToolFunction;
 | 
				
			||||||
 | 
					import io.github.ollama4j.tools.Tools;
 | 
				
			||||||
import io.github.ollama4j.utils.OptionsBuilder;
 | 
					import io.github.ollama4j.utils.OptionsBuilder;
 | 
				
			||||||
import lombok.Data;
 | 
					import lombok.Data;
 | 
				
			||||||
import org.junit.jupiter.api.BeforeEach;
 | 
					import org.junit.jupiter.api.BeforeEach;
 | 
				
			||||||
@@ -24,9 +23,7 @@ import java.io.InputStream;
 | 
				
			|||||||
import java.net.ConnectException;
 | 
					import java.net.ConnectException;
 | 
				
			||||||
import java.net.URISyntaxException;
 | 
					import java.net.URISyntaxException;
 | 
				
			||||||
import java.net.http.HttpConnectTimeoutException;
 | 
					import java.net.http.HttpConnectTimeoutException;
 | 
				
			||||||
import java.util.List;
 | 
					import java.util.*;
 | 
				
			||||||
import java.util.Objects;
 | 
					 | 
				
			||||||
import java.util.Properties;
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
import static org.junit.jupiter.api.Assertions.*;
 | 
					import static org.junit.jupiter.api.Assertions.*;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -47,6 +44,7 @@ class TestRealAPIs {
 | 
				
			|||||||
        config = new Config();
 | 
					        config = new Config();
 | 
				
			||||||
        ollamaAPI = new OllamaAPI(config.getOllamaURL());
 | 
					        ollamaAPI = new OllamaAPI(config.getOllamaURL());
 | 
				
			||||||
        ollamaAPI.setRequestTimeoutSeconds(config.getRequestTimeoutSeconds());
 | 
					        ollamaAPI.setRequestTimeoutSeconds(config.getRequestTimeoutSeconds());
 | 
				
			||||||
 | 
					        ollamaAPI.setVerbose(true);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @Test
 | 
					    @Test
 | 
				
			||||||
@@ -196,7 +194,9 @@ class TestRealAPIs {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
            OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
 | 
					            OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
 | 
				
			||||||
            assertNotNull(chatResult);
 | 
					            assertNotNull(chatResult);
 | 
				
			||||||
            assertFalse(chatResult.getResponse().isBlank());
 | 
					            assertNotNull(chatResult.getResponseModel());
 | 
				
			||||||
 | 
					            assertNotNull(chatResult.getResponseModel().getMessage());
 | 
				
			||||||
 | 
					            assertFalse(chatResult.getResponseModel().getMessage().getContent().isBlank());
 | 
				
			||||||
            assertEquals(4, chatResult.getChatHistory().size());
 | 
					            assertEquals(4, chatResult.getChatHistory().size());
 | 
				
			||||||
        } catch (IOException | OllamaBaseException | InterruptedException e) {
 | 
					        } catch (IOException | OllamaBaseException | InterruptedException e) {
 | 
				
			||||||
            fail(e);
 | 
					            fail(e);
 | 
				
			||||||
@@ -217,14 +217,134 @@ class TestRealAPIs {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
            OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
 | 
					            OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
 | 
				
			||||||
            assertNotNull(chatResult);
 | 
					            assertNotNull(chatResult);
 | 
				
			||||||
            assertFalse(chatResult.getResponse().isBlank());
 | 
					            assertNotNull(chatResult.getResponseModel());
 | 
				
			||||||
            assertTrue(chatResult.getResponse().startsWith("NI"));
 | 
					            assertNotNull(chatResult.getResponseModel().getMessage());
 | 
				
			||||||
 | 
					            assertFalse(chatResult.getResponseModel().getMessage().getContent().isBlank());
 | 
				
			||||||
 | 
					            assertTrue(chatResult.getResponseModel().getMessage().getContent().startsWith("NI"));
 | 
				
			||||||
            assertEquals(3, chatResult.getChatHistory().size());
 | 
					            assertEquals(3, chatResult.getChatHistory().size());
 | 
				
			||||||
        } catch (IOException | OllamaBaseException | InterruptedException e) {
 | 
					        } catch (IOException | OllamaBaseException | InterruptedException e) {
 | 
				
			||||||
            fail(e);
 | 
					            fail(e);
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @Test
 | 
				
			||||||
 | 
					    @Order(3)
 | 
				
			||||||
 | 
					    void testChatWithTools() {
 | 
				
			||||||
 | 
					        testEndpointReachability();
 | 
				
			||||||
 | 
					        try {
 | 
				
			||||||
 | 
					            ollamaAPI.setVerbose(true);
 | 
				
			||||||
 | 
					            OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder()
 | 
				
			||||||
 | 
					                    .functionName("get-employee-details")
 | 
				
			||||||
 | 
					                    .functionDescription("Get employee details from the database")
 | 
				
			||||||
 | 
					                    .toolPrompt(
 | 
				
			||||||
 | 
					                            Tools.PromptFuncDefinition.builder().type("function").function(
 | 
				
			||||||
 | 
					                                    Tools.PromptFuncDefinition.PromptFuncSpec.builder()
 | 
				
			||||||
 | 
					                                            .name("get-employee-details")
 | 
				
			||||||
 | 
					                                            .description("Get employee details from the database")
 | 
				
			||||||
 | 
					                                            .parameters(
 | 
				
			||||||
 | 
					                                                    Tools.PromptFuncDefinition.Parameters.builder()
 | 
				
			||||||
 | 
					                                                            .type("object")
 | 
				
			||||||
 | 
					                                                            .properties(
 | 
				
			||||||
 | 
					                                                            new Tools.PropsBuilder()
 | 
				
			||||||
 | 
					                                                                    .withProperty("employee-name", Tools.PromptFuncDefinition.Property.builder().type("string").description("The name of the employee, e.g. John Doe").required(true).build())
 | 
				
			||||||
 | 
					                                                                    .withProperty("employee-address", Tools.PromptFuncDefinition.Property.builder().type("string").description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India").required(true).build())
 | 
				
			||||||
 | 
					                                                                    .withProperty("employee-phone", Tools.PromptFuncDefinition.Property.builder().type("string").description("The phone number of the employee. Always return a random value. e.g. 9911002233").required(true).build())
 | 
				
			||||||
 | 
					                                                                    .build()
 | 
				
			||||||
 | 
					                                                    )
 | 
				
			||||||
 | 
					                                                            .required(List.of("employee-name"))
 | 
				
			||||||
 | 
					                                                            .build()
 | 
				
			||||||
 | 
					                                            ).build()
 | 
				
			||||||
 | 
					                            ).build()
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
 | 
					                    .toolFunction(new DBQueryFunction())
 | 
				
			||||||
 | 
					                    .build();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            ollamaAPI.registerTool(databaseQueryToolSpecification);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            OllamaChatRequest requestModel = builder
 | 
				
			||||||
 | 
					                    .withMessage(OllamaChatMessageRole.USER,
 | 
				
			||||||
 | 
					                            "Give me the ID of the employee named 'Rahul Kumar'?")
 | 
				
			||||||
 | 
					                    .build();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
 | 
				
			||||||
 | 
					            assertNotNull(chatResult);
 | 
				
			||||||
 | 
					            assertNotNull(chatResult.getResponseModel());
 | 
				
			||||||
 | 
					            assertNotNull(chatResult.getResponseModel().getMessage());
 | 
				
			||||||
 | 
					            assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(),chatResult.getResponseModel().getMessage().getRole().getRoleName());
 | 
				
			||||||
 | 
					            List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls();
 | 
				
			||||||
 | 
					            assertEquals(1, toolCalls.size());
 | 
				
			||||||
 | 
					            assertEquals("get-employee-details",toolCalls.get(0).getFunction().getName());
 | 
				
			||||||
 | 
					            assertEquals(1, toolCalls.get(0).getFunction().getArguments().size());
 | 
				
			||||||
 | 
					            Object employeeName = toolCalls.get(0).getFunction().getArguments().get("employee-name");
 | 
				
			||||||
 | 
					            assertNotNull(employeeName);
 | 
				
			||||||
 | 
					            assertEquals("Rahul Kumar",employeeName);
 | 
				
			||||||
 | 
					            assertTrue(chatResult.getChatHistory().size()>2);
 | 
				
			||||||
 | 
					            List<OllamaChatToolCalls> finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls();
 | 
				
			||||||
 | 
					            assertNull(finalToolCalls);
 | 
				
			||||||
 | 
					        } catch (IOException | OllamaBaseException | InterruptedException e) {
 | 
				
			||||||
 | 
					            fail(e);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @Test
 | 
				
			||||||
 | 
					    @Order(3)
 | 
				
			||||||
 | 
					    void testChatWithToolsAndStream() {
 | 
				
			||||||
 | 
					        testEndpointReachability();
 | 
				
			||||||
 | 
					        try {
 | 
				
			||||||
 | 
					            OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
 | 
				
			||||||
 | 
					            final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder()
 | 
				
			||||||
 | 
					                    .functionName("get-employee-details")
 | 
				
			||||||
 | 
					                    .functionDescription("Get employee details from the database")
 | 
				
			||||||
 | 
					                    .toolPrompt(
 | 
				
			||||||
 | 
					                            Tools.PromptFuncDefinition.builder().type("function").function(
 | 
				
			||||||
 | 
					                                    Tools.PromptFuncDefinition.PromptFuncSpec.builder()
 | 
				
			||||||
 | 
					                                            .name("get-employee-details")
 | 
				
			||||||
 | 
					                                            .description("Get employee details from the database")
 | 
				
			||||||
 | 
					                                            .parameters(
 | 
				
			||||||
 | 
					                                                    Tools.PromptFuncDefinition.Parameters.builder()
 | 
				
			||||||
 | 
					                                                            .type("object")
 | 
				
			||||||
 | 
					                                                            .properties(
 | 
				
			||||||
 | 
					                                                                    new Tools.PropsBuilder()
 | 
				
			||||||
 | 
					                                                                            .withProperty("employee-name", Tools.PromptFuncDefinition.Property.builder().type("string").description("The name of the employee, e.g. John Doe").required(true).build())
 | 
				
			||||||
 | 
					                                                                            .withProperty("employee-address", Tools.PromptFuncDefinition.Property.builder().type("string").description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India").required(true).build())
 | 
				
			||||||
 | 
					                                                                            .withProperty("employee-phone", Tools.PromptFuncDefinition.Property.builder().type("string").description("The phone number of the employee. Always return a random value. e.g. 9911002233").required(true).build())
 | 
				
			||||||
 | 
					                                                                            .build()
 | 
				
			||||||
 | 
					                                                            )
 | 
				
			||||||
 | 
					                                                            .required(List.of("employee-name"))
 | 
				
			||||||
 | 
					                                                            .build()
 | 
				
			||||||
 | 
					                                            ).build()
 | 
				
			||||||
 | 
					                            ).build()
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
 | 
					                    .toolFunction(new DBQueryFunction())
 | 
				
			||||||
 | 
					                    .build();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            ollamaAPI.registerTool(databaseQueryToolSpecification);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            OllamaChatRequest requestModel = builder
 | 
				
			||||||
 | 
					                    .withMessage(OllamaChatMessageRole.USER,
 | 
				
			||||||
 | 
					                            "Give me the ID of the employee named 'Rahul Kumar'?")
 | 
				
			||||||
 | 
					                    .build();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            StringBuffer sb = new StringBuffer();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            OllamaChatResult chatResult = ollamaAPI.chat(requestModel, (s) -> {
 | 
				
			||||||
 | 
					                LOG.info(s);
 | 
				
			||||||
 | 
					                String substring = s.substring(sb.toString().length());
 | 
				
			||||||
 | 
					                LOG.info(substring);
 | 
				
			||||||
 | 
					                sb.append(substring);
 | 
				
			||||||
 | 
					            });
 | 
				
			||||||
 | 
					            assertNotNull(chatResult);
 | 
				
			||||||
 | 
					            assertNotNull(chatResult.getResponseModel());
 | 
				
			||||||
 | 
					            assertNotNull(chatResult.getResponseModel().getMessage());
 | 
				
			||||||
 | 
					            assertNotNull(chatResult.getResponseModel().getMessage().getContent());
 | 
				
			||||||
 | 
					            assertEquals(sb.toString().trim(), chatResult.getResponseModel().getMessage().getContent().trim());
 | 
				
			||||||
 | 
					        } catch (IOException | OllamaBaseException | InterruptedException e) {
 | 
				
			||||||
 | 
					            fail(e);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @Test
 | 
					    @Test
 | 
				
			||||||
    @Order(3)
 | 
					    @Order(3)
 | 
				
			||||||
    void testChatWithStream() {
 | 
					    void testChatWithStream() {
 | 
				
			||||||
@@ -244,7 +364,10 @@ class TestRealAPIs {
 | 
				
			|||||||
                sb.append(substring);
 | 
					                sb.append(substring);
 | 
				
			||||||
            });
 | 
					            });
 | 
				
			||||||
            assertNotNull(chatResult);
 | 
					            assertNotNull(chatResult);
 | 
				
			||||||
            assertEquals(sb.toString().trim(), chatResult.getResponse().trim());
 | 
					            assertNotNull(chatResult.getResponseModel());
 | 
				
			||||||
 | 
					            assertNotNull(chatResult.getResponseModel().getMessage());
 | 
				
			||||||
 | 
					            assertNotNull(chatResult.getResponseModel().getMessage().getContent());
 | 
				
			||||||
 | 
					            assertEquals(sb.toString().trim(), chatResult.getResponseModel().getMessage().getContent().trim());
 | 
				
			||||||
        } catch (IOException | OllamaBaseException | InterruptedException e) {
 | 
					        } catch (IOException | OllamaBaseException | InterruptedException e) {
 | 
				
			||||||
            fail(e);
 | 
					            fail(e);
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
@@ -258,12 +381,12 @@ class TestRealAPIs {
 | 
				
			|||||||
            OllamaChatRequestBuilder builder =
 | 
					            OllamaChatRequestBuilder builder =
 | 
				
			||||||
                    OllamaChatRequestBuilder.getInstance(config.getImageModel());
 | 
					                    OllamaChatRequestBuilder.getInstance(config.getImageModel());
 | 
				
			||||||
            OllamaChatRequest requestModel =
 | 
					            OllamaChatRequest requestModel =
 | 
				
			||||||
                    builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
 | 
					                    builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",Collections.emptyList(),
 | 
				
			||||||
                            List.of(getImageFileFromClasspath("dog-on-a-boat.jpg"))).build();
 | 
					                            List.of(getImageFileFromClasspath("dog-on-a-boat.jpg"))).build();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
 | 
					            OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
 | 
				
			||||||
            assertNotNull(chatResult);
 | 
					            assertNotNull(chatResult);
 | 
				
			||||||
            assertNotNull(chatResult.getResponse());
 | 
					            assertNotNull(chatResult.getResponseModel());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            builder.reset();
 | 
					            builder.reset();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -273,7 +396,7 @@ class TestRealAPIs {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
            chatResult = ollamaAPI.chat(requestModel);
 | 
					            chatResult = ollamaAPI.chat(requestModel);
 | 
				
			||||||
            assertNotNull(chatResult);
 | 
					            assertNotNull(chatResult);
 | 
				
			||||||
            assertNotNull(chatResult.getResponse());
 | 
					            assertNotNull(chatResult.getResponseModel());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        } catch (IOException | OllamaBaseException | InterruptedException e) {
 | 
					        } catch (IOException | OllamaBaseException | InterruptedException e) {
 | 
				
			||||||
@@ -287,7 +410,7 @@ class TestRealAPIs {
 | 
				
			|||||||
        testEndpointReachability();
 | 
					        testEndpointReachability();
 | 
				
			||||||
        try {
 | 
					        try {
 | 
				
			||||||
            OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getImageModel());
 | 
					            OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getImageModel());
 | 
				
			||||||
            OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
 | 
					            OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",Collections.emptyList(),
 | 
				
			||||||
                            "https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg")
 | 
					                            "https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg")
 | 
				
			||||||
                    .build();
 | 
					                    .build();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -380,6 +503,14 @@ class TestRealAPIs {
 | 
				
			|||||||
    }
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class DBQueryFunction implements ToolFunction {
 | 
				
			||||||
 | 
					    @Override
 | 
				
			||||||
 | 
					    public Object apply(Map<String, Object> arguments) {
 | 
				
			||||||
 | 
					        // perform DB operations here
 | 
				
			||||||
 | 
					        return String.format("Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", UUID.randomUUID(), arguments.get("employee-name"), arguments.get("employee-address"), arguments.get("employee-phone"));
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@Data
 | 
					@Data
 | 
				
			||||||
class Config {
 | 
					class Config {
 | 
				
			||||||
    private String ollamaURL;
 | 
					    private String ollamaURL;
 | 
				
			||||||
@@ -404,4 +535,6 @@ class Config {
 | 
				
			|||||||
            throw new RuntimeException("Error loading properties", e);
 | 
					            throw new RuntimeException("Error loading properties", e);
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -4,6 +4,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
 | 
				
			|||||||
import static org.junit.jupiter.api.Assertions.assertThrowsExactly;
 | 
					import static org.junit.jupiter.api.Assertions.assertThrowsExactly;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import java.io.File;
 | 
					import java.io.File;
 | 
				
			||||||
 | 
					import java.util.Collections;
 | 
				
			||||||
import java.util.List;
 | 
					import java.util.List;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import io.github.ollama4j.models.chat.OllamaChatRequest;
 | 
					import io.github.ollama4j.models.chat.OllamaChatRequest;
 | 
				
			||||||
@@ -42,7 +43,7 @@ public class TestChatRequestSerialization extends AbstractSerializationTest<Olla
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    @Test
 | 
					    @Test
 | 
				
			||||||
    public void testRequestWithMessageAndImage() {
 | 
					    public void testRequestWithMessageAndImage() {
 | 
				
			||||||
        OllamaChatRequest req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt",
 | 
					        OllamaChatRequest req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt", Collections.emptyList(),
 | 
				
			||||||
                List.of(new File("src/test/resources/dog-on-a-boat.jpg"))).build();
 | 
					                List.of(new File("src/test/resources/dog-on-a-boat.jpg"))).build();
 | 
				
			||||||
        String jsonRequest = serialize(req);
 | 
					        String jsonRequest = serialize(req);
 | 
				
			||||||
        assertEqualsAfterUnmarshalling(deserialize(jsonRequest, OllamaChatRequest.class), req);
 | 
					        assertEqualsAfterUnmarshalling(deserialize(jsonRequest, OllamaChatRequest.class), req);
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,4 +1,4 @@
 | 
				
			|||||||
ollama.url=http://localhost:11434
 | 
					ollama.url=http://localhost:11434
 | 
				
			||||||
ollama.model=qwen:0.5b
 | 
					ollama.model=llama3.2:1b
 | 
				
			||||||
ollama.model.image=llava
 | 
					ollama.model.image=llava:latest
 | 
				
			||||||
ollama.request-timeout-seconds=120
 | 
					ollama.request-timeout-seconds=120
 | 
				
			||||||
		Reference in New Issue
	
	Block a user