From 25694a8bc9abf23b286f4a2474bbe0f164c06011 Mon Sep 17 00:00:00 2001 From: Markus Klenke Date: Sat, 7 Dec 2024 00:29:09 +0100 Subject: [PATCH] extends ollamaChatResult to have full access to OllamaChatResult --- .../java/io/github/ollama4j/OllamaAPI.java | 6 +- .../models/chat/OllamaChatRequestBuilder.java | 5 ++ .../models/chat/OllamaChatResult.java | 28 +++++-- .../request/OllamaChatEndpointCaller.java | 74 ++++++++++++++++- .../models/request/OllamaEndpointCaller.java | 83 ++----------------- .../request/OllamaGenerateEndpointCaller.java | 78 ++++++++++++++++- .../models/response/OllamaResult.java | 6 +- .../integrationtests/TestRealAPIs.java | 37 ++++++--- .../jackson/TestChatRequestSerialization.java | 3 +- src/test/resources/test-config.properties | 4 +- 10 files changed, 218 insertions(+), 106 deletions(-) diff --git a/src/main/java/io/github/ollama4j/OllamaAPI.java b/src/main/java/io/github/ollama4j/OllamaAPI.java index f595090..d76ecd9 100644 --- a/src/main/java/io/github/ollama4j/OllamaAPI.java +++ b/src/main/java/io/github/ollama4j/OllamaAPI.java @@ -602,7 +602,7 @@ public class OllamaAPI { OllamaResult result = generate(model, prompt, raw, options, null); toolResult.setModelResult(result); - String toolsResponse = result.getContent(); + String toolsResponse = result.getResponse(); if (toolsResponse.contains("[TOOL_CALLS]")) { toolsResponse = toolsResponse.replace("[TOOL_CALLS]", ""); } @@ -767,7 +767,7 @@ public class OllamaAPI { */ public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose); - OllamaResult result; + OllamaChatResult result; // add all registered tools to Request request.setTools(toolRegistry.getRegisteredSpecs().stream().map(Tools.ToolSpecification::getToolPrompt).collect(Collectors.toList())); @@ -779,7 +779,7 @@ public class OllamaAPI { result = requestCaller.callSync(request); } - return new OllamaChatResult(result.getContent(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages()); + return result; } public void registerTool(Tools.ToolSpecification toolSpecification) { diff --git a/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequestBuilder.java b/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequestBuilder.java index 3546ba8..9094546 100644 --- a/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequestBuilder.java +++ b/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequestBuilder.java @@ -10,6 +10,7 @@ import java.io.IOException; import java.net.URISyntaxException; import java.nio.file.Files; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.stream.Collectors; @@ -38,6 +39,10 @@ public class OllamaChatRequestBuilder { request = new OllamaChatRequest(request.getModel(), new ArrayList<>()); } + public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content){ + return withMessage(role,content, Collections.emptyList()); + } + public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List toolCalls,List images) { List messages = this.request.getMessages(); diff --git a/src/main/java/io/github/ollama4j/models/chat/OllamaChatResult.java b/src/main/java/io/github/ollama4j/models/chat/OllamaChatResult.java index b9616f3..0ea1c00 100644 --- a/src/main/java/io/github/ollama4j/models/chat/OllamaChatResult.java +++ b/src/main/java/io/github/ollama4j/models/chat/OllamaChatResult.java @@ -2,28 +2,40 @@ package io.github.ollama4j.models.chat; import java.util.List; +import com.fasterxml.jackson.core.JsonProcessingException; import io.github.ollama4j.models.response.OllamaResult; +import lombok.Getter; + +import static io.github.ollama4j.utils.Utils.getObjectMapper; /** * Specific chat-API result that contains the chat history sent to the model and appends the answer as {@link OllamaChatResult} given by the * {@link OllamaChatMessageRole#ASSISTANT} role. */ -public class OllamaChatResult extends OllamaResult { +@Getter +public class OllamaChatResult { + private List chatHistory; - public OllamaChatResult(String response, long responseTime, int httpStatusCode, List chatHistory) { - super(response, responseTime, httpStatusCode); + private OllamaChatResponseModel response; + + public OllamaChatResult(OllamaChatResponseModel response, List chatHistory) { this.chatHistory = chatHistory; + this.response = response; appendAnswerToChatHistory(response); } - public List getChatHistory() { - return chatHistory; + private void appendAnswerToChatHistory(OllamaChatResponseModel response) { + this.chatHistory.add(response.getMessage()); } - private void appendAnswerToChatHistory(String answer) { - OllamaChatMessage assistantMessage = new OllamaChatMessage(OllamaChatMessageRole.ASSISTANT, answer); - this.chatHistory.add(assistantMessage); + @Override + public String toString() { + try { + return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } } } diff --git a/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java b/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java index e3d3fc1..a43d77c 100644 --- a/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java +++ b/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java @@ -4,6 +4,9 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; import io.github.ollama4j.exceptions.OllamaBaseException; import io.github.ollama4j.models.chat.OllamaChatMessage; +import io.github.ollama4j.models.chat.OllamaChatRequest; +import io.github.ollama4j.models.chat.OllamaChatResult; +import io.github.ollama4j.models.response.OllamaErrorResponse; import io.github.ollama4j.models.response.OllamaResult; import io.github.ollama4j.models.chat.OllamaChatResponseModel; import io.github.ollama4j.models.chat.OllamaChatStreamObserver; @@ -13,7 +16,15 @@ import io.github.ollama4j.utils.Utils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.io.BufferedReader; import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.nio.charset.StandardCharsets; /** * Specialization class for requests @@ -64,9 +75,68 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller { } } - public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler) + public OllamaChatResult call(OllamaChatRequest body, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { streamObserver = new OllamaChatStreamObserver(streamHandler); - return super.callSync(body); + return callSync(body); + } + + public OllamaChatResult callSync(OllamaChatRequest body) throws OllamaBaseException, IOException, InterruptedException { + // Create Request + HttpClient httpClient = HttpClient.newHttpClient(); + URI uri = URI.create(getHost() + getEndpointSuffix()); + HttpRequest.Builder requestBuilder = + getRequestBuilderDefault(uri) + .POST( + body.getBodyPublisher()); + HttpRequest request = requestBuilder.build(); + if (isVerbose()) LOG.info("Asking model: " + body.toString()); + HttpResponse response = + httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream()); + + int statusCode = response.statusCode(); + InputStream responseBodyStream = response.body(); + StringBuilder responseBuffer = new StringBuilder(); + OllamaChatResponseModel ollamaChatResponseModel = null; + try (BufferedReader reader = + new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { + + String line; + while ((line = reader.readLine()) != null) { + if (statusCode == 404) { + LOG.warn("Status code: 404 (Not Found)"); + OllamaErrorResponse ollamaResponseModel = + Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class); + responseBuffer.append(ollamaResponseModel.getError()); + } else if (statusCode == 401) { + LOG.warn("Status code: 401 (Unauthorized)"); + OllamaErrorResponse ollamaResponseModel = + Utils.getObjectMapper() + .readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponse.class); + responseBuffer.append(ollamaResponseModel.getError()); + } else if (statusCode == 400) { + LOG.warn("Status code: 400 (Bad Request)"); + OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line, + OllamaErrorResponse.class); + responseBuffer.append(ollamaResponseModel.getError()); + } else { + boolean finished = parseResponseAndAddToBuffer(line, responseBuffer); + ollamaChatResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class); + if (finished && body.stream) { + ollamaChatResponseModel.getMessage().setContent(responseBuffer.toString()); + break; + } + } + } + } + if (statusCode != 200) { + LOG.error("Status code " + statusCode); + throw new OllamaBaseException(responseBuffer.toString()); + } else { + OllamaChatResult ollamaResult = + new OllamaChatResult(ollamaChatResponseModel,body.getMessages()); + if (isVerbose()) LOG.info("Model response: " + ollamaResult); + return ollamaResult; + } } } diff --git a/src/main/java/io/github/ollama4j/models/request/OllamaEndpointCaller.java b/src/main/java/io/github/ollama4j/models/request/OllamaEndpointCaller.java index 8529c18..e9d0e0d 100644 --- a/src/main/java/io/github/ollama4j/models/request/OllamaEndpointCaller.java +++ b/src/main/java/io/github/ollama4j/models/request/OllamaEndpointCaller.java @@ -6,6 +6,7 @@ import io.github.ollama4j.models.response.OllamaErrorResponse; import io.github.ollama4j.models.response.OllamaResult; import io.github.ollama4j.utils.OllamaRequestBody; import io.github.ollama4j.utils.Utils; +import lombok.Getter; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -24,14 +25,15 @@ import java.util.Base64; /** * Abstract helperclass to call the ollama api server. */ +@Getter public abstract class OllamaEndpointCaller { private static final Logger LOG = LoggerFactory.getLogger(OllamaAPI.class); - private String host; - private BasicAuth basicAuth; - private long requestTimeoutSeconds; - private boolean verbose; + private final String host; + private final BasicAuth basicAuth; + private final long requestTimeoutSeconds; + private final boolean verbose; public OllamaEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) { this.host = host; @@ -45,80 +47,13 @@ public abstract class OllamaEndpointCaller { protected abstract boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer); - /** - * Calls the api server on the given host and endpoint suffix asynchronously, aka waiting for the response. - * - * @param body POST body payload - * @return result answer given by the assistant - * @throws OllamaBaseException any response code than 200 has been returned - * @throws IOException in case the responseStream can not be read - * @throws InterruptedException in case the server is not reachable or network issues happen - */ - public OllamaResult callSync(OllamaRequestBody body) throws OllamaBaseException, IOException, InterruptedException { - // Create Request - long startTime = System.currentTimeMillis(); - HttpClient httpClient = HttpClient.newHttpClient(); - URI uri = URI.create(this.host + getEndpointSuffix()); - HttpRequest.Builder requestBuilder = - getRequestBuilderDefault(uri) - .POST( - body.getBodyPublisher()); - HttpRequest request = requestBuilder.build(); - if (this.verbose) LOG.info("Asking model: " + body.toString()); - HttpResponse response = - httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream()); - - int statusCode = response.statusCode(); - InputStream responseBodyStream = response.body(); - StringBuilder responseBuffer = new StringBuilder(); - try (BufferedReader reader = - new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { - String line; - while ((line = reader.readLine()) != null) { - if (statusCode == 404) { - LOG.warn("Status code: 404 (Not Found)"); - OllamaErrorResponse ollamaResponseModel = - Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class); - responseBuffer.append(ollamaResponseModel.getError()); - } else if (statusCode == 401) { - LOG.warn("Status code: 401 (Unauthorized)"); - OllamaErrorResponse ollamaResponseModel = - Utils.getObjectMapper() - .readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponse.class); - responseBuffer.append(ollamaResponseModel.getError()); - } else if (statusCode == 400) { - LOG.warn("Status code: 400 (Bad Request)"); - OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line, - OllamaErrorResponse.class); - responseBuffer.append(ollamaResponseModel.getError()); - } else { - boolean finished = parseResponseAndAddToBuffer(line, responseBuffer); - if (finished) { - break; - } - } - } - } - - if (statusCode != 200) { - LOG.error("Status code " + statusCode); - throw new OllamaBaseException(responseBuffer.toString()); - } else { - long endTime = System.currentTimeMillis(); - OllamaResult ollamaResult = - new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode); - if (verbose) LOG.info("Model response: " + ollamaResult); - return ollamaResult; - } - } - /** * Get default request builder. * * @param uri URI to get a HttpRequest.Builder * @return HttpRequest.Builder */ - private HttpRequest.Builder getRequestBuilderDefault(URI uri) { + protected HttpRequest.Builder getRequestBuilderDefault(URI uri) { HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(uri) .header("Content-Type", "application/json") @@ -134,7 +69,7 @@ public abstract class OllamaEndpointCaller { * * @return basic authentication header value (encoded credentials) */ - private String getBasicAuthHeaderValue() { + protected String getBasicAuthHeaderValue() { String credentialsToEncode = this.basicAuth.getUsername() + ":" + this.basicAuth.getPassword(); return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes()); } @@ -144,7 +79,7 @@ public abstract class OllamaEndpointCaller { * * @return true when Basic Auth credentials set */ - private boolean isBasicAuthCredentialsSet() { + protected boolean isBasicAuthCredentialsSet() { return this.basicAuth != null; } diff --git a/src/main/java/io/github/ollama4j/models/request/OllamaGenerateEndpointCaller.java b/src/main/java/io/github/ollama4j/models/request/OllamaGenerateEndpointCaller.java index f4afb2c..00b2b12 100644 --- a/src/main/java/io/github/ollama4j/models/request/OllamaGenerateEndpointCaller.java +++ b/src/main/java/io/github/ollama4j/models/request/OllamaGenerateEndpointCaller.java @@ -2,6 +2,7 @@ package io.github.ollama4j.models.request; import com.fasterxml.jackson.core.JsonProcessingException; import io.github.ollama4j.exceptions.OllamaBaseException; +import io.github.ollama4j.models.response.OllamaErrorResponse; import io.github.ollama4j.models.response.OllamaResult; import io.github.ollama4j.models.generate.OllamaGenerateResponseModel; import io.github.ollama4j.models.generate.OllamaGenerateStreamObserver; @@ -11,7 +12,15 @@ import io.github.ollama4j.utils.Utils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.io.BufferedReader; import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.nio.charset.StandardCharsets; public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller { @@ -46,6 +55,73 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller { public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { streamObserver = new OllamaGenerateStreamObserver(streamHandler); - return super.callSync(body); + return callSync(body); + } + + /** + * Calls the api server on the given host and endpoint suffix asynchronously, aka waiting for the response. + * + * @param body POST body payload + * @return result answer given by the assistant + * @throws OllamaBaseException any response code than 200 has been returned + * @throws IOException in case the responseStream can not be read + * @throws InterruptedException in case the server is not reachable or network issues happen + */ + public OllamaResult callSync(OllamaRequestBody body) throws OllamaBaseException, IOException, InterruptedException { + // Create Request + long startTime = System.currentTimeMillis(); + HttpClient httpClient = HttpClient.newHttpClient(); + URI uri = URI.create(getHost() + getEndpointSuffix()); + HttpRequest.Builder requestBuilder = + getRequestBuilderDefault(uri) + .POST( + body.getBodyPublisher()); + HttpRequest request = requestBuilder.build(); + if (isVerbose()) LOG.info("Asking model: " + body.toString()); + HttpResponse response = + httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream()); + + int statusCode = response.statusCode(); + InputStream responseBodyStream = response.body(); + StringBuilder responseBuffer = new StringBuilder(); + try (BufferedReader reader = + new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { + String line; + while ((line = reader.readLine()) != null) { + if (statusCode == 404) { + LOG.warn("Status code: 404 (Not Found)"); + OllamaErrorResponse ollamaResponseModel = + Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class); + responseBuffer.append(ollamaResponseModel.getError()); + } else if (statusCode == 401) { + LOG.warn("Status code: 401 (Unauthorized)"); + OllamaErrorResponse ollamaResponseModel = + Utils.getObjectMapper() + .readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponse.class); + responseBuffer.append(ollamaResponseModel.getError()); + } else if (statusCode == 400) { + LOG.warn("Status code: 400 (Bad Request)"); + OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line, + OllamaErrorResponse.class); + responseBuffer.append(ollamaResponseModel.getError()); + } else { + boolean finished = parseResponseAndAddToBuffer(line, responseBuffer); + if (finished) { + break; + } + } + } + } + + if (statusCode != 200) { + LOG.error("Status code " + statusCode); + throw new OllamaBaseException(responseBuffer.toString()); + } else { + long endTime = System.currentTimeMillis(); + OllamaResult ollamaResult = + new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode); + if (isVerbose()) LOG.info("Model response: " + ollamaResult); + return ollamaResult; + } } } diff --git a/src/main/java/io/github/ollama4j/models/response/OllamaResult.java b/src/main/java/io/github/ollama4j/models/response/OllamaResult.java index 8465cb6..beb01ec 100644 --- a/src/main/java/io/github/ollama4j/models/response/OllamaResult.java +++ b/src/main/java/io/github/ollama4j/models/response/OllamaResult.java @@ -17,7 +17,7 @@ public class OllamaResult { * * @return String completion/response text */ - private final String content; + private final String response; /** * -- GETTER -- @@ -35,8 +35,8 @@ public class OllamaResult { */ private long responseTime = 0; - public OllamaResult(String content, long responseTime, int httpStatusCode) { - this.content = content; + public OllamaResult(String response, long responseTime, int httpStatusCode) { + this.response = response; this.responseTime = responseTime; this.httpStatusCode = httpStatusCode; } diff --git a/src/test/java/io/github/ollama4j/integrationtests/TestRealAPIs.java b/src/test/java/io/github/ollama4j/integrationtests/TestRealAPIs.java index 702d0a2..d18187b 100644 --- a/src/test/java/io/github/ollama4j/integrationtests/TestRealAPIs.java +++ b/src/test/java/io/github/ollama4j/integrationtests/TestRealAPIs.java @@ -2,12 +2,9 @@ package io.github.ollama4j.integrationtests; import io.github.ollama4j.OllamaAPI; import io.github.ollama4j.exceptions.OllamaBaseException; +import io.github.ollama4j.models.chat.*; import io.github.ollama4j.models.response.ModelDetail; -import io.github.ollama4j.models.chat.OllamaChatRequest; import io.github.ollama4j.models.response.OllamaResult; -import io.github.ollama4j.models.chat.OllamaChatMessageRole; -import io.github.ollama4j.models.chat.OllamaChatRequestBuilder; -import io.github.ollama4j.models.chat.OllamaChatResult; import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestBuilder; import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel; import io.github.ollama4j.tools.ToolFunction; @@ -47,6 +44,7 @@ class TestRealAPIs { config = new Config(); ollamaAPI = new OllamaAPI(config.getOllamaURL()); ollamaAPI.setRequestTimeoutSeconds(config.getRequestTimeoutSeconds()); + ollamaAPI.setVerbose(true); } @Test @@ -196,7 +194,9 @@ class TestRealAPIs { OllamaChatResult chatResult = ollamaAPI.chat(requestModel); assertNotNull(chatResult); - assertFalse(chatResult.getResponse().isBlank()); + assertNotNull(chatResult.getResponse()); + assertNotNull(chatResult.getResponse().getMessage()); + assertFalse(chatResult.getResponse().getMessage().getContent().isBlank()); assertEquals(4, chatResult.getChatHistory().size()); } catch (IOException | OllamaBaseException | InterruptedException e) { fail(e); @@ -217,8 +217,10 @@ class TestRealAPIs { OllamaChatResult chatResult = ollamaAPI.chat(requestModel); assertNotNull(chatResult); - assertFalse(chatResult.getResponse().isBlank()); - assertTrue(chatResult.getResponse().startsWith("NI")); + assertNotNull(chatResult.getResponse()); + assertNotNull(chatResult.getResponse().getMessage()); + assertFalse(chatResult.getResponse().getMessage().getContent().isBlank()); + assertTrue(chatResult.getResponse().getMessage().getContent().startsWith("NI")); assertEquals(3, chatResult.getChatHistory().size()); } catch (IOException | OllamaBaseException | InterruptedException e) { fail(e); @@ -267,9 +269,17 @@ class TestRealAPIs { .build(); OllamaChatResult chatResult = ollamaAPI.chat(requestModel); - System.err.println("Response: " + chatResult); assertNotNull(chatResult); - assertFalse(chatResult.getResponse().isBlank()); + assertNotNull(chatResult.getResponse()); + assertNotNull(chatResult.getResponse().getMessage()); + assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(),chatResult.getResponse().getMessage().getRole().getRoleName()); + List toolCalls = chatResult.getResponse().getMessage().getToolCalls(); + assertEquals(1, toolCalls.size()); + assertEquals("get-employee-details",toolCalls.get(0).getFunction().getName()); + assertEquals(1, toolCalls.get(0).getFunction().getArguments().size()); + String employeeName = toolCalls.get(0).getFunction().getArguments().get("employee-name"); + assertNotNull(employeeName); + assertEquals("Rahul Kumar",employeeName); assertEquals(2, chatResult.getChatHistory().size()); } catch (IOException | OllamaBaseException | InterruptedException e) { fail(e); @@ -295,7 +305,10 @@ class TestRealAPIs { sb.append(substring); }); assertNotNull(chatResult); - assertEquals(sb.toString().trim(), chatResult.getResponse().trim()); + assertNotNull(chatResult.getResponse()); + assertNotNull(chatResult.getResponse().getMessage()); + assertNotNull(chatResult.getResponse().getMessage().getContent()); + assertEquals(sb.toString().trim(), chatResult.getResponse().getMessage().getContent().trim()); } catch (IOException | OllamaBaseException | InterruptedException e) { fail(e); } @@ -309,7 +322,7 @@ class TestRealAPIs { OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getImageModel()); OllamaChatRequest requestModel = - builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?", + builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",Collections.emptyList(), List.of(getImageFileFromClasspath("dog-on-a-boat.jpg"))).build(); OllamaChatResult chatResult = ollamaAPI.chat(requestModel); @@ -338,7 +351,7 @@ class TestRealAPIs { testEndpointReachability(); try { OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getImageModel()); - OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?", + OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",Collections.emptyList(), "https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg") .build(); diff --git a/src/test/java/io/github/ollama4j/unittests/jackson/TestChatRequestSerialization.java b/src/test/java/io/github/ollama4j/unittests/jackson/TestChatRequestSerialization.java index 2ce210c..db33889 100644 --- a/src/test/java/io/github/ollama4j/unittests/jackson/TestChatRequestSerialization.java +++ b/src/test/java/io/github/ollama4j/unittests/jackson/TestChatRequestSerialization.java @@ -4,6 +4,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrowsExactly; import java.io.File; +import java.util.Collections; import java.util.List; import io.github.ollama4j.models.chat.OllamaChatRequest; @@ -42,7 +43,7 @@ public class TestChatRequestSerialization extends AbstractSerializationTest