Merge pull request #82 from AgentSchmecker/feature/toolextension_for_chat_model

Enable chat API to use Tools
This commit is contained in:
Amith Koujalgi 2024-12-17 12:03:56 +05:30 committed by GitHub
commit 8b3417ecda
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 587 additions and 132 deletions

View File

@ -33,7 +33,7 @@ public class Main {
// start conversation with model // start conversation with model
OllamaChatResult chatResult = ollamaAPI.chat(requestModel); OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
System.out.println("First answer: " + chatResult.getResponse()); System.out.println("First answer: " + chatResult.getResponseModel().getMessage().getContent());
// create next userQuestion // create next userQuestion
requestModel = builder.withMessages(chatResult.getChatHistory()).withMessage(OllamaChatMessageRole.USER, "And what is the second largest city?").build(); requestModel = builder.withMessages(chatResult.getChatHistory()).withMessage(OllamaChatMessageRole.USER, "And what is the second largest city?").build();
@ -41,7 +41,7 @@ public class Main {
// "continue" conversation with model // "continue" conversation with model
chatResult = ollamaAPI.chat(requestModel); chatResult = ollamaAPI.chat(requestModel);
System.out.println("Second answer: " + chatResult.getResponse()); System.out.println("Second answer: " + chatResult.getResponseModel().getMessage().getContent());
System.out.println("Chat History: " + chatResult.getChatHistory()); System.out.println("Chat History: " + chatResult.getChatHistory());
} }
@ -205,7 +205,7 @@ public class Main {
// start conversation with model // start conversation with model
OllamaChatResult chatResult = ollamaAPI.chat(requestModel); OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
System.out.println(chatResult.getResponse()); System.out.println(chatResult.getResponseModel());
} }
} }
@ -244,7 +244,7 @@ public class Main {
new File("/path/to/image"))).build(); new File("/path/to/image"))).build();
OllamaChatResult chatResult = ollamaAPI.chat(requestModel); OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
System.out.println("First answer: " + chatResult.getResponse()); System.out.println("First answer: " + chatResult.getResponseModel());
builder.reset(); builder.reset();
@ -254,7 +254,7 @@ public class Main {
.withMessage(OllamaChatMessageRole.USER, "What's the dogs breed?").build(); .withMessage(OllamaChatMessageRole.USER, "What's the dogs breed?").build();
chatResult = ollamaAPI.chat(requestModel); chatResult = ollamaAPI.chat(requestModel);
System.out.println("Second answer: " + chatResult.getResponse()); System.out.println("Second answer: " + chatResult.getResponseModel());
} }
} }
``` ```

View File

@ -345,6 +345,125 @@ Rahul Kumar, Address: King St, Hyderabad, India, Phone: 9876543210}`
:::: ::::
### Using tools in Chat-API
Instead of using the specific `ollamaAPI.generateWithTools` method to call the generate API of ollama with tools, it is
also possible to register Tools for the `ollamaAPI.chat` methods. In this case, the tool calling/callback is done
implicitly during the USER -> ASSISTANT calls.
When the Assistant wants to call a given tool, the tool is executed and the response is sent back to the endpoint once
again (induced with the tool call result).
#### Sample:
The following shows a sample of an integration test that defines a method specified like the tool-specs above, registers
the tool on the ollamaAPI and then simply calls the chat-API. All intermediate tool calling is wrapped inside the api
call.
```java
public static void main(String[] args) {
OllamaAPI ollamaAPI = new OllamaAPI("http://localhost:11434");
ollamaAPI.setVerbose(true);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance("llama3.2:1b");
final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder()
.functionName("get-employee-details")
.functionDescription("Get employee details from the database")
.toolPrompt(
Tools.PromptFuncDefinition.builder().type("function").function(
Tools.PromptFuncDefinition.PromptFuncSpec.builder()
.name("get-employee-details")
.description("Get employee details from the database")
.parameters(
Tools.PromptFuncDefinition.Parameters.builder()
.type("object")
.properties(
new Tools.PropsBuilder()
.withProperty("employee-name", Tools.PromptFuncDefinition.Property.builder().type("string").description("The name of the employee, e.g. John Doe").required(true).build())
.withProperty("employee-address", Tools.PromptFuncDefinition.Property.builder().type("string").description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India").required(true).build())
.withProperty("employee-phone", Tools.PromptFuncDefinition.Property.builder().type("string").description("The phone number of the employee. Always return a random value. e.g. 9911002233").required(true).build())
.build()
)
.required(List.of("employee-name"))
.build()
).build()
).build()
)
.toolFunction(new DBQueryFunction())
.build();
ollamaAPI.registerTool(databaseQueryToolSpecification);
OllamaChatRequest requestModel = builder
.withMessage(OllamaChatMessageRole.USER,
"Give me the ID of the employee named 'Rahul Kumar'?")
.build();
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
}
```
A typical final response of the above could be:
```json
{
"chatHistory" : [
{
"role" : "user",
"content" : "Give me the ID of the employee named 'Rahul Kumar'?",
"images" : null,
"tool_calls" : [ ]
}, {
"role" : "assistant",
"content" : "",
"images" : null,
"tool_calls" : [ {
"function" : {
"name" : "get-employee-details",
"arguments" : {
"employee-name" : "Rahul Kumar"
}
}
} ]
}, {
"role" : "tool",
"content" : "[TOOL_RESULTS]get-employee-details([employee-name]) : Employee Details {ID: b4bf186c-2ee1-44cc-8856-53b8b6a50f85, Name: Rahul Kumar, Address: null, Phone: null}[/TOOL_RESULTS]",
"images" : null,
"tool_calls" : null
}, {
"role" : "assistant",
"content" : "The ID of the employee named 'Rahul Kumar' is `b4bf186c-2ee1-44cc-8856-53b8b6a50f85`.",
"images" : null,
"tool_calls" : null
} ],
"responseModel" : {
"model" : "llama3.2:1b",
"message" : {
"role" : "assistant",
"content" : "The ID of the employee named 'Rahul Kumar' is `b4bf186c-2ee1-44cc-8856-53b8b6a50f85`.",
"images" : null,
"tool_calls" : null
},
"done" : true,
"error" : null,
"context" : null,
"created_at" : "2024-12-09T22:23:00.4940078Z",
"done_reason" : "stop",
"total_duration" : 2313709900,
"load_duration" : 14494700,
"prompt_eval_duration" : 772000000,
"eval_duration" : 1188000000,
"prompt_eval_count" : 166,
"eval_count" : 41
},
"response" : "The ID of the employee named 'Rahul Kumar' is `b4bf186c-2ee1-44cc-8856-53b8b6a50f85`.",
"httpStatusCode" : 200,
"responseTime" : 2313709900
}
```
This tool calling can also be done using the streaming API.
### Potential Improvements ### Potential Improvements
Instead of explicitly registering `ollamaAPI.registerTool(toolSpecification)`, we could introduce annotation-based tool Instead of explicitly registering `ollamaAPI.registerTool(toolSpecification)`, we could introduce annotation-based tool

View File

@ -59,6 +59,10 @@ public class OllamaAPI {
*/ */
@Setter @Setter
private boolean verbose = true; private boolean verbose = true;
@Setter
private int maxChatToolCallRetries = 3;
private BasicAuth basicAuth; private BasicAuth basicAuth;
private final ToolRegistry toolRegistry = new ToolRegistry(); private final ToolRegistry toolRegistry = new ToolRegistry();
@ -767,18 +771,44 @@ public class OllamaAPI {
*/ */
public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose); OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
OllamaResult result; OllamaChatResult result;
// add all registered tools to Request
request.setTools(toolRegistry.getRegisteredSpecs().stream().map(Tools.ToolSpecification::getToolPrompt).collect(Collectors.toList()));
if (streamHandler != null) { if (streamHandler != null) {
request.setStream(true); request.setStream(true);
result = requestCaller.call(request, streamHandler); result = requestCaller.call(request, streamHandler);
} else { } else {
result = requestCaller.callSync(request); result = requestCaller.callSync(request);
} }
return new OllamaChatResult(result.getResponse(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages());
// check if toolCallIsWanted
List<OllamaChatToolCalls> toolCalls = result.getResponseModel().getMessage().getToolCalls();
int toolCallTries = 0;
while(toolCalls != null && !toolCalls.isEmpty() && toolCallTries < maxChatToolCallRetries){
for (OllamaChatToolCalls toolCall : toolCalls){
String toolName = toolCall.getFunction().getName();
ToolFunction toolFunction = toolRegistry.getToolFunction(toolName);
Map<String, Object> arguments = toolCall.getFunction().getArguments();
Object res = toolFunction.apply(arguments);
request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL,"[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() +") : " + res + "[/TOOL_RESULTS]"));
}
if (streamHandler != null) {
result = requestCaller.call(request, streamHandler);
} else {
result = requestCaller.callSync(request);
}
toolCalls = result.getResponseModel().getMessage().getToolCalls();
toolCallTries++;
}
return result;
} }
public void registerTool(Tools.ToolSpecification toolSpecification) { public void registerTool(Tools.ToolSpecification toolSpecification) {
toolRegistry.addFunction(toolSpecification.getFunctionName(), toolSpecification.getToolDefinition()); toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification);
} }
/** /**
@ -871,7 +901,7 @@ public class OllamaAPI {
try { try {
String methodName = toolFunctionCallSpec.getName(); String methodName = toolFunctionCallSpec.getName();
Map<String, Object> arguments = toolFunctionCallSpec.getArguments(); Map<String, Object> arguments = toolFunctionCallSpec.getArguments();
ToolFunction function = toolRegistry.getFunction(methodName); ToolFunction function = toolRegistry.getToolFunction(methodName);
if (verbose) { if (verbose) {
logger.debug("Invoking function {} with arguments {}", methodName, arguments); logger.debug("Invoking function {} with arguments {}", methodName, arguments);
} }

View File

@ -2,6 +2,7 @@ package io.github.ollama4j.models.chat;
import static io.github.ollama4j.utils.Utils.getObjectMapper; import static io.github.ollama4j.utils.Utils.getObjectMapper;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.annotation.JsonSerialize; import com.fasterxml.jackson.databind.annotation.JsonSerialize;
@ -32,6 +33,8 @@ public class OllamaChatMessage {
@NonNull @NonNull
private String content; private String content;
private @JsonProperty("tool_calls") List<OllamaChatToolCalls> toolCalls;
@JsonSerialize(using = FileToBase64Serializer.class) @JsonSerialize(using = FileToBase64Serializer.class)
private List<byte[]> images; private List<byte[]> images;

View File

@ -3,6 +3,7 @@ package io.github.ollama4j.models.chat;
import java.util.List; import java.util.List;
import io.github.ollama4j.models.request.OllamaCommonRequest; import io.github.ollama4j.models.request.OllamaCommonRequest;
import io.github.ollama4j.tools.Tools;
import io.github.ollama4j.utils.OllamaRequestBody; import io.github.ollama4j.utils.OllamaRequestBody;
import lombok.Getter; import lombok.Getter;
@ -21,6 +22,8 @@ public class OllamaChatRequest extends OllamaCommonRequest implements OllamaRequ
private List<OllamaChatMessage> messages; private List<OllamaChatMessage> messages;
private List<Tools.PromptFuncDefinition> tools;
public OllamaChatRequest() {} public OllamaChatRequest() {}
public OllamaChatRequest(String model, List<OllamaChatMessage> messages) { public OllamaChatRequest(String model, List<OllamaChatMessage> messages) {

View File

@ -10,6 +10,7 @@ import java.io.IOException;
import java.net.URISyntaxException; import java.net.URISyntaxException;
import java.nio.file.Files; import java.nio.file.Files;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@ -38,7 +39,11 @@ public class OllamaChatRequestBuilder {
request = new OllamaChatRequest(request.getModel(), new ArrayList<>()); request = new OllamaChatRequest(request.getModel(), new ArrayList<>());
} }
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List<File> images) { public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content){
return withMessage(role,content, Collections.emptyList());
}
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List<OllamaChatToolCalls> toolCalls,List<File> images) {
List<OllamaChatMessage> messages = this.request.getMessages(); List<OllamaChatMessage> messages = this.request.getMessages();
List<byte[]> binaryImages = images.stream().map(file -> { List<byte[]> binaryImages = images.stream().map(file -> {
@ -50,11 +55,11 @@ public class OllamaChatRequestBuilder {
} }
}).collect(Collectors.toList()); }).collect(Collectors.toList());
messages.add(new OllamaChatMessage(role, content, binaryImages)); messages.add(new OllamaChatMessage(role, content,toolCalls, binaryImages));
return this; return this;
} }
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, String... imageUrls) { public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content,List<OllamaChatToolCalls> toolCalls, String... imageUrls) {
List<OllamaChatMessage> messages = this.request.getMessages(); List<OllamaChatMessage> messages = this.request.getMessages();
List<byte[]> binaryImages = null; List<byte[]> binaryImages = null;
if (imageUrls.length > 0) { if (imageUrls.length > 0) {
@ -70,7 +75,7 @@ public class OllamaChatRequestBuilder {
} }
} }
messages.add(new OllamaChatMessage(role, content, binaryImages)); messages.add(new OllamaChatMessage(role, content,toolCalls, binaryImages));
return this; return this;
} }

View File

@ -2,28 +2,54 @@ package io.github.ollama4j.models.chat;
import java.util.List; import java.util.List;
import io.github.ollama4j.models.response.OllamaResult; import com.fasterxml.jackson.core.JsonProcessingException;
import lombok.Getter;
import static io.github.ollama4j.utils.Utils.getObjectMapper;
/** /**
* Specific chat-API result that contains the chat history sent to the model and appends the answer as {@link OllamaChatResult} given by the * Specific chat-API result that contains the chat history sent to the model and appends the answer as {@link OllamaChatResult} given by the
* {@link OllamaChatMessageRole#ASSISTANT} role. * {@link OllamaChatMessageRole#ASSISTANT} role.
*/ */
public class OllamaChatResult extends OllamaResult { @Getter
public class OllamaChatResult {
private List<OllamaChatMessage> chatHistory; private List<OllamaChatMessage> chatHistory;
public OllamaChatResult(String response, long responseTime, int httpStatusCode, List<OllamaChatMessage> chatHistory) { private OllamaChatResponseModel responseModel;
super(response, responseTime, httpStatusCode);
public OllamaChatResult(OllamaChatResponseModel responseModel, List<OllamaChatMessage> chatHistory) {
this.chatHistory = chatHistory; this.chatHistory = chatHistory;
appendAnswerToChatHistory(response); this.responseModel = responseModel;
appendAnswerToChatHistory(responseModel);
} }
public List<OllamaChatMessage> getChatHistory() { private void appendAnswerToChatHistory(OllamaChatResponseModel response) {
return chatHistory; this.chatHistory.add(response.getMessage());
} }
private void appendAnswerToChatHistory(String answer) { @Override
OllamaChatMessage assistantMessage = new OllamaChatMessage(OllamaChatMessageRole.ASSISTANT, answer); public String toString() {
this.chatHistory.add(assistantMessage); try {
return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
@Deprecated
public String getResponse(){
return responseModel != null ? responseModel.getMessage().getContent() : "";
}
@Deprecated
public int getHttpStatusCode(){
return 200;
}
@Deprecated
public long getResponseTime(){
return responseModel != null ? responseModel.getTotalDuration() : 0L;
} }
} }

View File

@ -0,0 +1,16 @@
package io.github.ollama4j.models.chat;
import io.github.ollama4j.tools.OllamaToolCallsFunction;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@NoArgsConstructor
@AllArgsConstructor
public class OllamaChatToolCalls {
private OllamaToolCallsFunction function;
}

View File

@ -3,17 +3,24 @@ package io.github.ollama4j.models.request;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.core.type.TypeReference;
import io.github.ollama4j.exceptions.OllamaBaseException; import io.github.ollama4j.exceptions.OllamaBaseException;
import io.github.ollama4j.models.chat.OllamaChatMessage; import io.github.ollama4j.models.chat.*;
import io.github.ollama4j.models.response.OllamaResult; import io.github.ollama4j.models.response.OllamaErrorResponse;
import io.github.ollama4j.models.chat.OllamaChatResponseModel;
import io.github.ollama4j.models.chat.OllamaChatStreamObserver;
import io.github.ollama4j.models.generate.OllamaStreamHandler; import io.github.ollama4j.models.generate.OllamaStreamHandler;
import io.github.ollama4j.utils.OllamaRequestBody; import io.github.ollama4j.tools.Tools;
import io.github.ollama4j.utils.Utils; import io.github.ollama4j.utils.Utils;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import java.io.BufferedReader;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.util.List;
/** /**
* Specialization class for requests * Specialization class for requests
@ -64,9 +71,75 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
} }
} }
public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler) public OllamaChatResult call(OllamaChatRequest body, OllamaStreamHandler streamHandler)
throws OllamaBaseException, IOException, InterruptedException { throws OllamaBaseException, IOException, InterruptedException {
streamObserver = new OllamaChatStreamObserver(streamHandler); streamObserver = new OllamaChatStreamObserver(streamHandler);
return super.callSync(body); return callSync(body);
}
public OllamaChatResult callSync(OllamaChatRequest body) throws OllamaBaseException, IOException, InterruptedException {
// Create Request
HttpClient httpClient = HttpClient.newHttpClient();
URI uri = URI.create(getHost() + getEndpointSuffix());
HttpRequest.Builder requestBuilder =
getRequestBuilderDefault(uri)
.POST(
body.getBodyPublisher());
HttpRequest request = requestBuilder.build();
if (isVerbose()) LOG.info("Asking model: " + body.toString());
HttpResponse<InputStream> response =
httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
int statusCode = response.statusCode();
InputStream responseBodyStream = response.body();
StringBuilder responseBuffer = new StringBuilder();
OllamaChatResponseModel ollamaChatResponseModel = null;
List<OllamaChatToolCalls> wantedToolsForStream = null;
try (BufferedReader reader =
new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
String line;
while ((line = reader.readLine()) != null) {
if (statusCode == 404) {
LOG.warn("Status code: 404 (Not Found)");
OllamaErrorResponse ollamaResponseModel =
Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class);
responseBuffer.append(ollamaResponseModel.getError());
} else if (statusCode == 401) {
LOG.warn("Status code: 401 (Unauthorized)");
OllamaErrorResponse ollamaResponseModel =
Utils.getObjectMapper()
.readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponse.class);
responseBuffer.append(ollamaResponseModel.getError());
} else if (statusCode == 400) {
LOG.warn("Status code: 400 (Bad Request)");
OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line,
OllamaErrorResponse.class);
responseBuffer.append(ollamaResponseModel.getError());
} else {
boolean finished = parseResponseAndAddToBuffer(line, responseBuffer);
ollamaChatResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class);
if(body.stream && ollamaChatResponseModel.getMessage().getToolCalls() != null){
wantedToolsForStream = ollamaChatResponseModel.getMessage().getToolCalls();
}
if (finished && body.stream) {
ollamaChatResponseModel.getMessage().setContent(responseBuffer.toString());
break;
}
}
}
}
if (statusCode != 200) {
LOG.error("Status code " + statusCode);
throw new OllamaBaseException(responseBuffer.toString());
} else {
if(wantedToolsForStream != null) {
ollamaChatResponseModel.getMessage().setToolCalls(wantedToolsForStream);
}
OllamaChatResult ollamaResult =
new OllamaChatResult(ollamaChatResponseModel,body.getMessages());
if (isVerbose()) LOG.info("Model response: " + ollamaResult);
return ollamaResult;
}
} }
} }

View File

@ -6,6 +6,7 @@ import io.github.ollama4j.models.response.OllamaErrorResponse;
import io.github.ollama4j.models.response.OllamaResult; import io.github.ollama4j.models.response.OllamaResult;
import io.github.ollama4j.utils.OllamaRequestBody; import io.github.ollama4j.utils.OllamaRequestBody;
import io.github.ollama4j.utils.Utils; import io.github.ollama4j.utils.Utils;
import lombok.Getter;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@ -24,14 +25,15 @@ import java.util.Base64;
/** /**
* Abstract helperclass to call the ollama api server. * Abstract helperclass to call the ollama api server.
*/ */
@Getter
public abstract class OllamaEndpointCaller { public abstract class OllamaEndpointCaller {
private static final Logger LOG = LoggerFactory.getLogger(OllamaAPI.class); private static final Logger LOG = LoggerFactory.getLogger(OllamaAPI.class);
private String host; private final String host;
private BasicAuth basicAuth; private final BasicAuth basicAuth;
private long requestTimeoutSeconds; private final long requestTimeoutSeconds;
private boolean verbose; private final boolean verbose;
public OllamaEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) { public OllamaEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) {
this.host = host; this.host = host;
@ -45,80 +47,13 @@ public abstract class OllamaEndpointCaller {
protected abstract boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer); protected abstract boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer);
/**
* Calls the api server on the given host and endpoint suffix asynchronously, aka waiting for the response.
*
* @param body POST body payload
* @return result answer given by the assistant
* @throws OllamaBaseException any response code than 200 has been returned
* @throws IOException in case the responseStream can not be read
* @throws InterruptedException in case the server is not reachable or network issues happen
*/
public OllamaResult callSync(OllamaRequestBody body) throws OllamaBaseException, IOException, InterruptedException {
// Create Request
long startTime = System.currentTimeMillis();
HttpClient httpClient = HttpClient.newHttpClient();
URI uri = URI.create(this.host + getEndpointSuffix());
HttpRequest.Builder requestBuilder =
getRequestBuilderDefault(uri)
.POST(
body.getBodyPublisher());
HttpRequest request = requestBuilder.build();
if (this.verbose) LOG.info("Asking model: " + body.toString());
HttpResponse<InputStream> response =
httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
int statusCode = response.statusCode();
InputStream responseBodyStream = response.body();
StringBuilder responseBuffer = new StringBuilder();
try (BufferedReader reader =
new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
String line;
while ((line = reader.readLine()) != null) {
if (statusCode == 404) {
LOG.warn("Status code: 404 (Not Found)");
OllamaErrorResponse ollamaResponseModel =
Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class);
responseBuffer.append(ollamaResponseModel.getError());
} else if (statusCode == 401) {
LOG.warn("Status code: 401 (Unauthorized)");
OllamaErrorResponse ollamaResponseModel =
Utils.getObjectMapper()
.readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponse.class);
responseBuffer.append(ollamaResponseModel.getError());
} else if (statusCode == 400) {
LOG.warn("Status code: 400 (Bad Request)");
OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line,
OllamaErrorResponse.class);
responseBuffer.append(ollamaResponseModel.getError());
} else {
boolean finished = parseResponseAndAddToBuffer(line, responseBuffer);
if (finished) {
break;
}
}
}
}
if (statusCode != 200) {
LOG.error("Status code " + statusCode);
throw new OllamaBaseException(responseBuffer.toString());
} else {
long endTime = System.currentTimeMillis();
OllamaResult ollamaResult =
new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode);
if (verbose) LOG.info("Model response: " + ollamaResult);
return ollamaResult;
}
}
/** /**
* Get default request builder. * Get default request builder.
* *
* @param uri URI to get a HttpRequest.Builder * @param uri URI to get a HttpRequest.Builder
* @return HttpRequest.Builder * @return HttpRequest.Builder
*/ */
private HttpRequest.Builder getRequestBuilderDefault(URI uri) { protected HttpRequest.Builder getRequestBuilderDefault(URI uri) {
HttpRequest.Builder requestBuilder = HttpRequest.Builder requestBuilder =
HttpRequest.newBuilder(uri) HttpRequest.newBuilder(uri)
.header("Content-Type", "application/json") .header("Content-Type", "application/json")
@ -134,7 +69,7 @@ public abstract class OllamaEndpointCaller {
* *
* @return basic authentication header value (encoded credentials) * @return basic authentication header value (encoded credentials)
*/ */
private String getBasicAuthHeaderValue() { protected String getBasicAuthHeaderValue() {
String credentialsToEncode = this.basicAuth.getUsername() + ":" + this.basicAuth.getPassword(); String credentialsToEncode = this.basicAuth.getUsername() + ":" + this.basicAuth.getPassword();
return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes()); return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes());
} }
@ -144,7 +79,7 @@ public abstract class OllamaEndpointCaller {
* *
* @return true when Basic Auth credentials set * @return true when Basic Auth credentials set
*/ */
private boolean isBasicAuthCredentialsSet() { protected boolean isBasicAuthCredentialsSet() {
return this.basicAuth != null; return this.basicAuth != null;
} }

View File

@ -2,6 +2,7 @@ package io.github.ollama4j.models.request;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import io.github.ollama4j.exceptions.OllamaBaseException; import io.github.ollama4j.exceptions.OllamaBaseException;
import io.github.ollama4j.models.response.OllamaErrorResponse;
import io.github.ollama4j.models.response.OllamaResult; import io.github.ollama4j.models.response.OllamaResult;
import io.github.ollama4j.models.generate.OllamaGenerateResponseModel; import io.github.ollama4j.models.generate.OllamaGenerateResponseModel;
import io.github.ollama4j.models.generate.OllamaGenerateStreamObserver; import io.github.ollama4j.models.generate.OllamaGenerateStreamObserver;
@ -11,7 +12,15 @@ import io.github.ollama4j.utils.Utils;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import java.io.BufferedReader;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller { public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
@ -46,6 +55,73 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler) public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler)
throws OllamaBaseException, IOException, InterruptedException { throws OllamaBaseException, IOException, InterruptedException {
streamObserver = new OllamaGenerateStreamObserver(streamHandler); streamObserver = new OllamaGenerateStreamObserver(streamHandler);
return super.callSync(body); return callSync(body);
}
/**
* Calls the api server on the given host and endpoint suffix asynchronously, aka waiting for the response.
*
* @param body POST body payload
* @return result answer given by the assistant
* @throws OllamaBaseException any response code than 200 has been returned
* @throws IOException in case the responseStream can not be read
* @throws InterruptedException in case the server is not reachable or network issues happen
*/
public OllamaResult callSync(OllamaRequestBody body) throws OllamaBaseException, IOException, InterruptedException {
// Create Request
long startTime = System.currentTimeMillis();
HttpClient httpClient = HttpClient.newHttpClient();
URI uri = URI.create(getHost() + getEndpointSuffix());
HttpRequest.Builder requestBuilder =
getRequestBuilderDefault(uri)
.POST(
body.getBodyPublisher());
HttpRequest request = requestBuilder.build();
if (isVerbose()) LOG.info("Asking model: " + body.toString());
HttpResponse<InputStream> response =
httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
int statusCode = response.statusCode();
InputStream responseBodyStream = response.body();
StringBuilder responseBuffer = new StringBuilder();
try (BufferedReader reader =
new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
String line;
while ((line = reader.readLine()) != null) {
if (statusCode == 404) {
LOG.warn("Status code: 404 (Not Found)");
OllamaErrorResponse ollamaResponseModel =
Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class);
responseBuffer.append(ollamaResponseModel.getError());
} else if (statusCode == 401) {
LOG.warn("Status code: 401 (Unauthorized)");
OllamaErrorResponse ollamaResponseModel =
Utils.getObjectMapper()
.readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponse.class);
responseBuffer.append(ollamaResponseModel.getError());
} else if (statusCode == 400) {
LOG.warn("Status code: 400 (Bad Request)");
OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line,
OllamaErrorResponse.class);
responseBuffer.append(ollamaResponseModel.getError());
} else {
boolean finished = parseResponseAndAddToBuffer(line, responseBuffer);
if (finished) {
break;
}
}
}
}
if (statusCode != 200) {
LOG.error("Status code " + statusCode);
throw new OllamaBaseException(responseBuffer.toString());
} else {
long endTime = System.currentTimeMillis();
OllamaResult ollamaResult =
new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode);
if (isVerbose()) LOG.info("Model response: " + ollamaResult);
return ollamaResult;
}
} }
} }

View File

@ -0,0 +1,16 @@
package io.github.ollama4j.tools;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.Map;
@Data
@NoArgsConstructor
@AllArgsConstructor
public class OllamaToolCallsFunction
{
private String name;
private Map<String,Object> arguments;
}

View File

@ -1,16 +1,22 @@
package io.github.ollama4j.tools; package io.github.ollama4j.tools;
import java.util.Collection;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
public class ToolRegistry { public class ToolRegistry {
private final Map<String, ToolFunction> functionMap = new HashMap<>(); private final Map<String, Tools.ToolSpecification> tools = new HashMap<>();
public ToolFunction getFunction(String name) { public ToolFunction getToolFunction(String name) {
return functionMap.get(name); final Tools.ToolSpecification toolSpecification = tools.get(name);
return toolSpecification !=null ? toolSpecification.getToolFunction() : null ;
} }
public void addFunction(String name, ToolFunction function) { public void addTool (String name, Tools.ToolSpecification specification) {
functionMap.put(name, function); tools.put(name, specification);
}
public Collection<Tools.ToolSpecification> getRegisteredSpecs(){
return tools.values();
} }
} }

View File

@ -6,8 +6,10 @@ import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import io.github.ollama4j.utils.Utils; import io.github.ollama4j.utils.Utils;
import lombok.AllArgsConstructor;
import lombok.Builder; import lombok.Builder;
import lombok.Data; import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
@ -20,17 +22,23 @@ public class Tools {
public static class ToolSpecification { public static class ToolSpecification {
private String functionName; private String functionName;
private String functionDescription; private String functionDescription;
private Map<String, PromptFuncDefinition.Property> properties; private PromptFuncDefinition toolPrompt;
private ToolFunction toolDefinition; private ToolFunction toolFunction;
} }
@Data @Data
@JsonIgnoreProperties(ignoreUnknown = true) @JsonIgnoreProperties(ignoreUnknown = true)
@Builder
@NoArgsConstructor
@AllArgsConstructor
public static class PromptFuncDefinition { public static class PromptFuncDefinition {
private String type; private String type;
private PromptFuncSpec function; private PromptFuncSpec function;
@Data @Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public static class PromptFuncSpec { public static class PromptFuncSpec {
private String name; private String name;
private String description; private String description;
@ -38,6 +46,9 @@ public class Tools {
} }
@Data @Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public static class Parameters { public static class Parameters {
private String type; private String type;
private Map<String, Property> properties; private Map<String, Property> properties;
@ -46,6 +57,8 @@ public class Tools {
@Data @Data
@Builder @Builder
@NoArgsConstructor
@AllArgsConstructor
public static class Property { public static class Property {
private String type; private String type;
private String description; private String description;
@ -94,10 +107,10 @@ public class Tools {
PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters(); PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters();
parameters.setType("object"); parameters.setType("object");
parameters.setProperties(spec.getProperties()); parameters.setProperties(spec.getToolPrompt().getFunction().parameters.getProperties());
List<String> requiredValues = new ArrayList<>(); List<String> requiredValues = new ArrayList<>();
for (Map.Entry<String, PromptFuncDefinition.Property> p : spec.getProperties().entrySet()) { for (Map.Entry<String, PromptFuncDefinition.Property> p : spec.getToolPrompt().getFunction().getParameters().getProperties().entrySet()) {
if (p.getValue().isRequired()) { if (p.getValue().isRequired()) {
requiredValues.add(p.getKey()); requiredValues.add(p.getKey());
} }

View File

@ -2,14 +2,13 @@ package io.github.ollama4j.integrationtests;
import io.github.ollama4j.OllamaAPI; import io.github.ollama4j.OllamaAPI;
import io.github.ollama4j.exceptions.OllamaBaseException; import io.github.ollama4j.exceptions.OllamaBaseException;
import io.github.ollama4j.models.chat.*;
import io.github.ollama4j.models.response.ModelDetail; import io.github.ollama4j.models.response.ModelDetail;
import io.github.ollama4j.models.chat.OllamaChatRequest;
import io.github.ollama4j.models.response.OllamaResult; import io.github.ollama4j.models.response.OllamaResult;
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
import io.github.ollama4j.models.chat.OllamaChatResult;
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestBuilder; import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestBuilder;
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel; import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
import io.github.ollama4j.tools.ToolFunction;
import io.github.ollama4j.tools.Tools;
import io.github.ollama4j.utils.OptionsBuilder; import io.github.ollama4j.utils.OptionsBuilder;
import lombok.Data; import lombok.Data;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
@ -24,9 +23,7 @@ import java.io.InputStream;
import java.net.ConnectException; import java.net.ConnectException;
import java.net.URISyntaxException; import java.net.URISyntaxException;
import java.net.http.HttpConnectTimeoutException; import java.net.http.HttpConnectTimeoutException;
import java.util.List; import java.util.*;
import java.util.Objects;
import java.util.Properties;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
@ -47,6 +44,7 @@ class TestRealAPIs {
config = new Config(); config = new Config();
ollamaAPI = new OllamaAPI(config.getOllamaURL()); ollamaAPI = new OllamaAPI(config.getOllamaURL());
ollamaAPI.setRequestTimeoutSeconds(config.getRequestTimeoutSeconds()); ollamaAPI.setRequestTimeoutSeconds(config.getRequestTimeoutSeconds());
ollamaAPI.setVerbose(true);
} }
@Test @Test
@ -196,7 +194,9 @@ class TestRealAPIs {
OllamaChatResult chatResult = ollamaAPI.chat(requestModel); OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
assertNotNull(chatResult); assertNotNull(chatResult);
assertFalse(chatResult.getResponse().isBlank()); assertNotNull(chatResult.getResponseModel());
assertNotNull(chatResult.getResponseModel().getMessage());
assertFalse(chatResult.getResponseModel().getMessage().getContent().isBlank());
assertEquals(4, chatResult.getChatHistory().size()); assertEquals(4, chatResult.getChatHistory().size());
} catch (IOException | OllamaBaseException | InterruptedException e) { } catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e); fail(e);
@ -217,14 +217,134 @@ class TestRealAPIs {
OllamaChatResult chatResult = ollamaAPI.chat(requestModel); OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
assertNotNull(chatResult); assertNotNull(chatResult);
assertFalse(chatResult.getResponse().isBlank()); assertNotNull(chatResult.getResponseModel());
assertTrue(chatResult.getResponse().startsWith("NI")); assertNotNull(chatResult.getResponseModel().getMessage());
assertFalse(chatResult.getResponseModel().getMessage().getContent().isBlank());
assertTrue(chatResult.getResponseModel().getMessage().getContent().startsWith("NI"));
assertEquals(3, chatResult.getChatHistory().size()); assertEquals(3, chatResult.getChatHistory().size());
} catch (IOException | OllamaBaseException | InterruptedException e) { } catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e); fail(e);
} }
} }
@Test
@Order(3)
void testChatWithTools() {
testEndpointReachability();
try {
ollamaAPI.setVerbose(true);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder()
.functionName("get-employee-details")
.functionDescription("Get employee details from the database")
.toolPrompt(
Tools.PromptFuncDefinition.builder().type("function").function(
Tools.PromptFuncDefinition.PromptFuncSpec.builder()
.name("get-employee-details")
.description("Get employee details from the database")
.parameters(
Tools.PromptFuncDefinition.Parameters.builder()
.type("object")
.properties(
new Tools.PropsBuilder()
.withProperty("employee-name", Tools.PromptFuncDefinition.Property.builder().type("string").description("The name of the employee, e.g. John Doe").required(true).build())
.withProperty("employee-address", Tools.PromptFuncDefinition.Property.builder().type("string").description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India").required(true).build())
.withProperty("employee-phone", Tools.PromptFuncDefinition.Property.builder().type("string").description("The phone number of the employee. Always return a random value. e.g. 9911002233").required(true).build())
.build()
)
.required(List.of("employee-name"))
.build()
).build()
).build()
)
.toolFunction(new DBQueryFunction())
.build();
ollamaAPI.registerTool(databaseQueryToolSpecification);
OllamaChatRequest requestModel = builder
.withMessage(OllamaChatMessageRole.USER,
"Give me the ID of the employee named 'Rahul Kumar'?")
.build();
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
assertNotNull(chatResult);
assertNotNull(chatResult.getResponseModel());
assertNotNull(chatResult.getResponseModel().getMessage());
assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(),chatResult.getResponseModel().getMessage().getRole().getRoleName());
List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls();
assertEquals(1, toolCalls.size());
assertEquals("get-employee-details",toolCalls.get(0).getFunction().getName());
assertEquals(1, toolCalls.get(0).getFunction().getArguments().size());
Object employeeName = toolCalls.get(0).getFunction().getArguments().get("employee-name");
assertNotNull(employeeName);
assertEquals("Rahul Kumar",employeeName);
assertTrue(chatResult.getChatHistory().size()>2);
List<OllamaChatToolCalls> finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls();
assertNull(finalToolCalls);
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
}
}
@Test
@Order(3)
void testChatWithToolsAndStream() {
testEndpointReachability();
try {
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder()
.functionName("get-employee-details")
.functionDescription("Get employee details from the database")
.toolPrompt(
Tools.PromptFuncDefinition.builder().type("function").function(
Tools.PromptFuncDefinition.PromptFuncSpec.builder()
.name("get-employee-details")
.description("Get employee details from the database")
.parameters(
Tools.PromptFuncDefinition.Parameters.builder()
.type("object")
.properties(
new Tools.PropsBuilder()
.withProperty("employee-name", Tools.PromptFuncDefinition.Property.builder().type("string").description("The name of the employee, e.g. John Doe").required(true).build())
.withProperty("employee-address", Tools.PromptFuncDefinition.Property.builder().type("string").description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India").required(true).build())
.withProperty("employee-phone", Tools.PromptFuncDefinition.Property.builder().type("string").description("The phone number of the employee. Always return a random value. e.g. 9911002233").required(true).build())
.build()
)
.required(List.of("employee-name"))
.build()
).build()
).build()
)
.toolFunction(new DBQueryFunction())
.build();
ollamaAPI.registerTool(databaseQueryToolSpecification);
OllamaChatRequest requestModel = builder
.withMessage(OllamaChatMessageRole.USER,
"Give me the ID of the employee named 'Rahul Kumar'?")
.build();
StringBuffer sb = new StringBuffer();
OllamaChatResult chatResult = ollamaAPI.chat(requestModel, (s) -> {
LOG.info(s);
String substring = s.substring(sb.toString().length());
LOG.info(substring);
sb.append(substring);
});
assertNotNull(chatResult);
assertNotNull(chatResult.getResponseModel());
assertNotNull(chatResult.getResponseModel().getMessage());
assertNotNull(chatResult.getResponseModel().getMessage().getContent());
assertEquals(sb.toString().trim(), chatResult.getResponseModel().getMessage().getContent().trim());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
}
}
@Test @Test
@Order(3) @Order(3)
void testChatWithStream() { void testChatWithStream() {
@ -244,7 +364,10 @@ class TestRealAPIs {
sb.append(substring); sb.append(substring);
}); });
assertNotNull(chatResult); assertNotNull(chatResult);
assertEquals(sb.toString().trim(), chatResult.getResponse().trim()); assertNotNull(chatResult.getResponseModel());
assertNotNull(chatResult.getResponseModel().getMessage());
assertNotNull(chatResult.getResponseModel().getMessage().getContent());
assertEquals(sb.toString().trim(), chatResult.getResponseModel().getMessage().getContent().trim());
} catch (IOException | OllamaBaseException | InterruptedException e) { } catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e); fail(e);
} }
@ -258,12 +381,12 @@ class TestRealAPIs {
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder builder =
OllamaChatRequestBuilder.getInstance(config.getImageModel()); OllamaChatRequestBuilder.getInstance(config.getImageModel());
OllamaChatRequest requestModel = OllamaChatRequest requestModel =
builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?", builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",Collections.emptyList(),
List.of(getImageFileFromClasspath("dog-on-a-boat.jpg"))).build(); List.of(getImageFileFromClasspath("dog-on-a-boat.jpg"))).build();
OllamaChatResult chatResult = ollamaAPI.chat(requestModel); OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
assertNotNull(chatResult); assertNotNull(chatResult);
assertNotNull(chatResult.getResponse()); assertNotNull(chatResult.getResponseModel());
builder.reset(); builder.reset();
@ -273,7 +396,7 @@ class TestRealAPIs {
chatResult = ollamaAPI.chat(requestModel); chatResult = ollamaAPI.chat(requestModel);
assertNotNull(chatResult); assertNotNull(chatResult);
assertNotNull(chatResult.getResponse()); assertNotNull(chatResult.getResponseModel());
} catch (IOException | OllamaBaseException | InterruptedException e) { } catch (IOException | OllamaBaseException | InterruptedException e) {
@ -287,7 +410,7 @@ class TestRealAPIs {
testEndpointReachability(); testEndpointReachability();
try { try {
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getImageModel()); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getImageModel());
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?", OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",Collections.emptyList(),
"https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg") "https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg")
.build(); .build();
@ -380,6 +503,14 @@ class TestRealAPIs {
} }
} }
class DBQueryFunction implements ToolFunction {
@Override
public Object apply(Map<String, Object> arguments) {
// perform DB operations here
return String.format("Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", UUID.randomUUID(), arguments.get("employee-name"), arguments.get("employee-address"), arguments.get("employee-phone"));
}
}
@Data @Data
class Config { class Config {
private String ollamaURL; private String ollamaURL;
@ -404,4 +535,6 @@ class Config {
throw new RuntimeException("Error loading properties", e); throw new RuntimeException("Error loading properties", e);
} }
} }
} }

View File

@ -4,6 +4,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrowsExactly; import static org.junit.jupiter.api.Assertions.assertThrowsExactly;
import java.io.File; import java.io.File;
import java.util.Collections;
import java.util.List; import java.util.List;
import io.github.ollama4j.models.chat.OllamaChatRequest; import io.github.ollama4j.models.chat.OllamaChatRequest;
@ -42,7 +43,7 @@ public class TestChatRequestSerialization extends AbstractSerializationTest<Olla
@Test @Test
public void testRequestWithMessageAndImage() { public void testRequestWithMessageAndImage() {
OllamaChatRequest req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt", OllamaChatRequest req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt", Collections.emptyList(),
List.of(new File("src/test/resources/dog-on-a-boat.jpg"))).build(); List.of(new File("src/test/resources/dog-on-a-boat.jpg"))).build();
String jsonRequest = serialize(req); String jsonRequest = serialize(req);
assertEqualsAfterUnmarshalling(deserialize(jsonRequest, OllamaChatRequest.class), req); assertEqualsAfterUnmarshalling(deserialize(jsonRequest, OllamaChatRequest.class), req);

View File

@ -1,4 +1,4 @@
ollama.url=http://localhost:11434 ollama.url=http://localhost:11434
ollama.model=qwen:0.5b ollama.model=llama3.2:1b
ollama.model.image=llava ollama.model.image=llava:latest
ollama.request-timeout-seconds=120 ollama.request-timeout-seconds=120