extends OllamaChatResult to give full access to the OllamaChatResponseModel

Markus Klenke 2024-12-07 00:29:09 +01:00
parent 12bb10392e
commit 25694a8bc9
10 changed files with 218 additions and 106 deletions
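In short: OllamaAPI.chat(...) no longer flattens the answer into a plain OllamaResult; it returns an OllamaChatResult that wraps the full OllamaChatResponseModel. A minimal consumer sketch against the new API (host and model name are placeholders, a running Ollama server is assumed; this example is not part of the commit):

```java
import io.github.ollama4j.OllamaAPI;
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
import io.github.ollama4j.models.chat.OllamaChatRequest;
import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
import io.github.ollama4j.models.chat.OllamaChatResult;

public class ChatResultDemo {
    public static void main(String[] args) throws Exception {
        // Placeholder host and model; adjust to your environment.
        OllamaAPI ollamaAPI = new OllamaAPI("http://localhost:11434");
        ollamaAPI.setRequestTimeoutSeconds(120);

        OllamaChatRequest requestModel = OllamaChatRequestBuilder
                .getInstance("llama3.2:1b")
                .withMessage(OllamaChatMessageRole.USER, "Why is the sky blue?")
                .build();

        OllamaChatResult chatResult = ollamaAPI.chat(requestModel);

        // Previously chatResult.getResponse() was the answer String; now it is the
        // full response model, so the text lives on the assistant message.
        System.out.println(chatResult.getResponse().getMessage().getContent());
        System.out.println("Messages in history: " + chatResult.getChatHistory().size());
    }
}
```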

View File

@@ -602,7 +602,7 @@ public class OllamaAPI {
OllamaResult result = generate(model, prompt, raw, options, null);
toolResult.setModelResult(result);
- String toolsResponse = result.getContent();
+ String toolsResponse = result.getResponse();
if (toolsResponse.contains("[TOOL_CALLS]")) {
toolsResponse = toolsResponse.replace("[TOOL_CALLS]", "");
}
@@ -767,7 +767,7 @@ public class OllamaAPI {
*/
public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
- OllamaResult result;
+ OllamaChatResult result;
// add all registered tools to Request
request.setTools(toolRegistry.getRegisteredSpecs().stream().map(Tools.ToolSpecification::getToolPrompt).collect(Collectors.toList()));
@@ -779,7 +779,7 @@ public class OllamaAPI {
result = requestCaller.callSync(request);
}
- return new OllamaChatResult(result.getContent(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages());
+ return result;
}
public void registerTool(Tools.ToolSpecification toolSpecification) {

View File

@@ -10,6 +10,7 @@ import java.io.IOException;
import java.net.URISyntaxException;
import java.nio.file.Files;
import java.util.ArrayList;
+ import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
@@ -38,6 +39,10 @@ public class OllamaChatRequestBuilder {
request = new OllamaChatRequest(request.getModel(), new ArrayList<>());
}
+ public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content){
+ return withMessage(role,content, Collections.emptyList());
+ }
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List<OllamaChatToolCalls> toolCalls,List<File> images) {
List<OllamaChatMessage> messages = this.request.getMessages();
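The new two-argument overload simply delegates with an empty tool-call list, so text-only callers don't have to pass Collections.emptyList() themselves. A usage sketch (model name and file path are illustrative only, not part of the commit):

```java
import java.io.File;
import java.util.Collections;
import java.util.List;

import io.github.ollama4j.models.chat.OllamaChatMessageRole;
import io.github.ollama4j.models.chat.OllamaChatRequest;
import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;

public class BuilderOverloadDemo {
    public static void main(String[] args) {
        // Plain text message: the new convenience overload.
        OllamaChatRequest textOnly = OllamaChatRequestBuilder.getInstance("llama3.2:1b")
                .withMessage(OllamaChatMessageRole.USER, "Describe this project in one sentence.")
                .build();

        // Message with images: the tool-call list must now be passed explicitly (empty here).
        OllamaChatRequest withImage = OllamaChatRequestBuilder.getInstance("llava:latest")
                .withMessage(OllamaChatMessageRole.USER, "What's in the picture?", Collections.emptyList(),
                        List.of(new File("src/test/resources/dog-on-a-boat.jpg")))
                .build();

        System.out.println(textOnly);
        System.out.println(withImage);
    }
}
```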

View File

@@ -2,28 +2,40 @@ package io.github.ollama4j.models.chat;
import java.util.List;
+ import com.fasterxml.jackson.core.JsonProcessingException;
import io.github.ollama4j.models.response.OllamaResult;
+ import lombok.Getter;
+ import static io.github.ollama4j.utils.Utils.getObjectMapper;
/**
 * Specific chat-API result that contains the chat history sent to the model and appends the answer as {@link OllamaChatResult} given by the
 * {@link OllamaChatMessageRole#ASSISTANT} role.
 */
- public class OllamaChatResult extends OllamaResult {
+ @Getter
+ public class OllamaChatResult {
private List<OllamaChatMessage> chatHistory;
- public OllamaChatResult(String response, long responseTime, int httpStatusCode, List<OllamaChatMessage> chatHistory) {
- super(response, responseTime, httpStatusCode);
+ private OllamaChatResponseModel response;
+ public OllamaChatResult(OllamaChatResponseModel response, List<OllamaChatMessage> chatHistory) {
this.chatHistory = chatHistory;
+ this.response = response;
appendAnswerToChatHistory(response);
}
- public List<OllamaChatMessage> getChatHistory() {
- return chatHistory;
+ private void appendAnswerToChatHistory(OllamaChatResponseModel response) {
+ this.chatHistory.add(response.getMessage());
}
- private void appendAnswerToChatHistory(String answer) {
- OllamaChatMessage assistantMessage = new OllamaChatMessage(OllamaChatMessageRole.ASSISTANT, answer);
- this.chatHistory.add(assistantMessage);
+ @Override
+ public String toString() {
+ try {
+ return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
+ } catch (JsonProcessingException e) {
+ throw new RuntimeException(e);
+ }
}
}
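With Lombok's @Getter on the class, getChatHistory() and getResponse() are generated, and toString() now renders the whole result as pretty-printed JSON. A hedged inspection sketch, using only accessors that appear elsewhere in this commit:

```java
import java.util.List;

import io.github.ollama4j.models.chat.OllamaChatMessageRole;
import io.github.ollama4j.models.chat.OllamaChatResult;
import io.github.ollama4j.models.chat.OllamaChatToolCalls;

public class ChatResultInspection {
    static void inspect(OllamaChatResult chatResult) {
        // toString() pretty-prints the result as JSON via Jackson.
        System.err.println("Response: " + chatResult);

        // The appended answer is the assistant message from the response model.
        String role = chatResult.getResponse().getMessage().getRole().getRoleName();
        System.out.println("Answered by role: " + role
                + " (expected: " + OllamaChatMessageRole.ASSISTANT.getRoleName() + ")");

        // Tool calls are available when the model invoked a registered tool.
        List<OllamaChatToolCalls> toolCalls = chatResult.getResponse().getMessage().getToolCalls();
        if (toolCalls != null && !toolCalls.isEmpty()) {
            System.out.println("First tool called: " + toolCalls.get(0).getFunction().getName());
        }
    }
}
```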

View File

@@ -4,6 +4,9 @@ import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import io.github.ollama4j.exceptions.OllamaBaseException;
import io.github.ollama4j.models.chat.OllamaChatMessage;
+ import io.github.ollama4j.models.chat.OllamaChatRequest;
+ import io.github.ollama4j.models.chat.OllamaChatResult;
+ import io.github.ollama4j.models.response.OllamaErrorResponse;
import io.github.ollama4j.models.response.OllamaResult;
import io.github.ollama4j.models.chat.OllamaChatResponseModel;
import io.github.ollama4j.models.chat.OllamaChatStreamObserver;
@@ -13,7 +16,15 @@ import io.github.ollama4j.utils.Utils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+ import java.io.BufferedReader;
import java.io.IOException;
+ import java.io.InputStream;
+ import java.io.InputStreamReader;
+ import java.net.URI;
+ import java.net.http.HttpClient;
+ import java.net.http.HttpRequest;
+ import java.net.http.HttpResponse;
+ import java.nio.charset.StandardCharsets;
/**
 * Specialization class for requests
@@ -64,9 +75,68 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
}
}
- public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler)
+ public OllamaChatResult call(OllamaChatRequest body, OllamaStreamHandler streamHandler)
throws OllamaBaseException, IOException, InterruptedException {
streamObserver = new OllamaChatStreamObserver(streamHandler);
- return super.callSync(body);
+ return callSync(body);
+ }
+ public OllamaChatResult callSync(OllamaChatRequest body) throws OllamaBaseException, IOException, InterruptedException {
+ // Create Request
+ HttpClient httpClient = HttpClient.newHttpClient();
+ URI uri = URI.create(getHost() + getEndpointSuffix());
+ HttpRequest.Builder requestBuilder =
+ getRequestBuilderDefault(uri)
+ .POST(
+ body.getBodyPublisher());
+ HttpRequest request = requestBuilder.build();
+ if (isVerbose()) LOG.info("Asking model: " + body.toString());
+ HttpResponse<InputStream> response =
+ httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
+ int statusCode = response.statusCode();
+ InputStream responseBodyStream = response.body();
+ StringBuilder responseBuffer = new StringBuilder();
+ OllamaChatResponseModel ollamaChatResponseModel = null;
+ try (BufferedReader reader =
+ new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
+ String line;
+ while ((line = reader.readLine()) != null) {
+ if (statusCode == 404) {
+ LOG.warn("Status code: 404 (Not Found)");
+ OllamaErrorResponse ollamaResponseModel =
+ Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class);
+ responseBuffer.append(ollamaResponseModel.getError());
+ } else if (statusCode == 401) {
+ LOG.warn("Status code: 401 (Unauthorized)");
+ OllamaErrorResponse ollamaResponseModel =
+ Utils.getObjectMapper()
+ .readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponse.class);
+ responseBuffer.append(ollamaResponseModel.getError());
+ } else if (statusCode == 400) {
+ LOG.warn("Status code: 400 (Bad Request)");
+ OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line,
+ OllamaErrorResponse.class);
+ responseBuffer.append(ollamaResponseModel.getError());
+ } else {
+ boolean finished = parseResponseAndAddToBuffer(line, responseBuffer);
+ ollamaChatResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class);
+ if (finished && body.stream) {
+ ollamaChatResponseModel.getMessage().setContent(responseBuffer.toString());
+ break;
+ }
+ }
+ }
+ }
+ if (statusCode != 200) {
+ LOG.error("Status code " + statusCode);
+ throw new OllamaBaseException(responseBuffer.toString());
+ } else {
+ OllamaChatResult ollamaResult =
+ new OllamaChatResult(ollamaChatResponseModel,body.getMessages());
+ if (isVerbose()) LOG.info("Model response: " + ollamaResult);
+ return ollamaResult;
+ }
}
}
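The chat caller now owns its synchronous path: each streamed line still goes through parseResponseAndAddToBuffer(), and when the final chunk of a streamed request arrives the buffered text is written back into the message before the OllamaChatResult is built. A sketch of a streaming caller against the new signature (host and model are placeholders; the exact chunking delivered to the handler follows the library's OllamaStreamHandler contract):

```java
import io.github.ollama4j.OllamaAPI;
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
import io.github.ollama4j.models.chat.OllamaChatRequest;
import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
import io.github.ollama4j.models.chat.OllamaChatResult;

public class StreamingChatDemo {
    public static void main(String[] args) throws Exception {
        OllamaAPI ollamaAPI = new OllamaAPI("http://localhost:11434"); // placeholder host

        OllamaChatRequest requestModel = OllamaChatRequestBuilder
                .getInstance("llama3.2:1b") // placeholder model
                .withMessage(OllamaChatMessageRole.USER, "Tell me a two-sentence story.")
                .build();

        // The handler is invoked as the response streams in.
        OllamaChatResult chatResult = ollamaAPI.chat(requestModel, partial -> System.out.println(partial));

        // The finished message carries the complete buffered content.
        System.out.println(chatResult.getResponse().getMessage().getContent());
    }
}
```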

View File

@@ -6,6 +6,7 @@ import io.github.ollama4j.models.response.OllamaErrorResponse;
import io.github.ollama4j.models.response.OllamaResult;
import io.github.ollama4j.utils.OllamaRequestBody;
import io.github.ollama4j.utils.Utils;
+ import lombok.Getter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -24,14 +25,15 @@ import java.util.Base64;
/**
 * Abstract helperclass to call the ollama api server.
 */
+ @Getter
public abstract class OllamaEndpointCaller {
private static final Logger LOG = LoggerFactory.getLogger(OllamaAPI.class);
- private String host;
- private BasicAuth basicAuth;
- private long requestTimeoutSeconds;
- private boolean verbose;
+ private final String host;
+ private final BasicAuth basicAuth;
+ private final long requestTimeoutSeconds;
+ private final boolean verbose;
public OllamaEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) {
this.host = host;
@@ -45,80 +47,13 @@ public abstract class OllamaEndpointCaller {
protected abstract boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer);
- /**
- * Calls the api server on the given host and endpoint suffix asynchronously, aka waiting for the response.
- *
- * @param body POST body payload
- * @return result answer given by the assistant
- * @throws OllamaBaseException any response code than 200 has been returned
- * @throws IOException in case the responseStream can not be read
- * @throws InterruptedException in case the server is not reachable or network issues happen
- */
- public OllamaResult callSync(OllamaRequestBody body) throws OllamaBaseException, IOException, InterruptedException {
- // Create Request
- long startTime = System.currentTimeMillis();
- HttpClient httpClient = HttpClient.newHttpClient();
- URI uri = URI.create(this.host + getEndpointSuffix());
- HttpRequest.Builder requestBuilder =
- getRequestBuilderDefault(uri)
- .POST(
- body.getBodyPublisher());
- HttpRequest request = requestBuilder.build();
- if (this.verbose) LOG.info("Asking model: " + body.toString());
- HttpResponse<InputStream> response =
- httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
- int statusCode = response.statusCode();
- InputStream responseBodyStream = response.body();
- StringBuilder responseBuffer = new StringBuilder();
- try (BufferedReader reader =
- new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
- String line;
- while ((line = reader.readLine()) != null) {
- if (statusCode == 404) {
- LOG.warn("Status code: 404 (Not Found)");
- OllamaErrorResponse ollamaResponseModel =
- Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class);
- responseBuffer.append(ollamaResponseModel.getError());
- } else if (statusCode == 401) {
- LOG.warn("Status code: 401 (Unauthorized)");
- OllamaErrorResponse ollamaResponseModel =
- Utils.getObjectMapper()
- .readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponse.class);
- responseBuffer.append(ollamaResponseModel.getError());
- } else if (statusCode == 400) {
- LOG.warn("Status code: 400 (Bad Request)");
- OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line,
- OllamaErrorResponse.class);
- responseBuffer.append(ollamaResponseModel.getError());
- } else {
- boolean finished = parseResponseAndAddToBuffer(line, responseBuffer);
- if (finished) {
- break;
- }
- }
- }
- }
- if (statusCode != 200) {
- LOG.error("Status code " + statusCode);
- throw new OllamaBaseException(responseBuffer.toString());
- } else {
- long endTime = System.currentTimeMillis();
- OllamaResult ollamaResult =
- new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode);
- if (verbose) LOG.info("Model response: " + ollamaResult);
- return ollamaResult;
- }
- }
/**
 * Get default request builder.
 *
 * @param uri URI to get a HttpRequest.Builder
 * @return HttpRequest.Builder
 */
- private HttpRequest.Builder getRequestBuilderDefault(URI uri) {
+ protected HttpRequest.Builder getRequestBuilderDefault(URI uri) {
HttpRequest.Builder requestBuilder =
HttpRequest.newBuilder(uri)
.header("Content-Type", "application/json")
@@ -134,7 +69,7 @@ public abstract class OllamaEndpointCaller {
 *
 * @return basic authentication header value (encoded credentials)
 */
- private String getBasicAuthHeaderValue() {
+ protected String getBasicAuthHeaderValue() {
String credentialsToEncode = this.basicAuth.getUsername() + ":" + this.basicAuth.getPassword();
return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes());
}
@@ -144,7 +79,7 @@ public abstract class OllamaEndpointCaller {
 *
 * @return true when Basic Auth credentials set
 */
- private boolean isBasicAuthCredentialsSet() {
+ protected boolean isBasicAuthCredentialsSet() {
return this.basicAuth != null;
}
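Design note: the shared callSync() is removed from the abstract base because the chat and generate endpoints now need different return types (OllamaChatResult vs. OllamaResult). So that each subclass can build and send its own request, the host/auth/verbosity state is exposed through Lombok-generated getters and the request-building helpers are widened from private to protected.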

View File

@@ -2,6 +2,7 @@ package io.github.ollama4j.models.request;
import com.fasterxml.jackson.core.JsonProcessingException;
import io.github.ollama4j.exceptions.OllamaBaseException;
+ import io.github.ollama4j.models.response.OllamaErrorResponse;
import io.github.ollama4j.models.response.OllamaResult;
import io.github.ollama4j.models.generate.OllamaGenerateResponseModel;
import io.github.ollama4j.models.generate.OllamaGenerateStreamObserver;
@@ -11,7 +12,15 @@ import io.github.ollama4j.utils.Utils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+ import java.io.BufferedReader;
import java.io.IOException;
+ import java.io.InputStream;
+ import java.io.InputStreamReader;
+ import java.net.URI;
+ import java.net.http.HttpClient;
+ import java.net.http.HttpRequest;
+ import java.net.http.HttpResponse;
+ import java.nio.charset.StandardCharsets;
public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
@@ -46,6 +55,73 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler)
throws OllamaBaseException, IOException, InterruptedException {
streamObserver = new OllamaGenerateStreamObserver(streamHandler);
- return super.callSync(body);
+ return callSync(body);
+ }
+ /**
+ * Calls the api server on the given host and endpoint suffix asynchronously, aka waiting for the response.
+ *
+ * @param body POST body payload
+ * @return result answer given by the assistant
+ * @throws OllamaBaseException any response code than 200 has been returned
+ * @throws IOException in case the responseStream can not be read
+ * @throws InterruptedException in case the server is not reachable or network issues happen
+ */
+ public OllamaResult callSync(OllamaRequestBody body) throws OllamaBaseException, IOException, InterruptedException {
+ // Create Request
+ long startTime = System.currentTimeMillis();
+ HttpClient httpClient = HttpClient.newHttpClient();
+ URI uri = URI.create(getHost() + getEndpointSuffix());
+ HttpRequest.Builder requestBuilder =
+ getRequestBuilderDefault(uri)
+ .POST(
+ body.getBodyPublisher());
+ HttpRequest request = requestBuilder.build();
+ if (isVerbose()) LOG.info("Asking model: " + body.toString());
+ HttpResponse<InputStream> response =
+ httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
+ int statusCode = response.statusCode();
+ InputStream responseBodyStream = response.body();
+ StringBuilder responseBuffer = new StringBuilder();
+ try (BufferedReader reader =
+ new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
+ String line;
+ while ((line = reader.readLine()) != null) {
+ if (statusCode == 404) {
+ LOG.warn("Status code: 404 (Not Found)");
+ OllamaErrorResponse ollamaResponseModel =
+ Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class);
+ responseBuffer.append(ollamaResponseModel.getError());
+ } else if (statusCode == 401) {
+ LOG.warn("Status code: 401 (Unauthorized)");
+ OllamaErrorResponse ollamaResponseModel =
+ Utils.getObjectMapper()
+ .readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponse.class);
+ responseBuffer.append(ollamaResponseModel.getError());
+ } else if (statusCode == 400) {
+ LOG.warn("Status code: 400 (Bad Request)");
+ OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line,
+ OllamaErrorResponse.class);
+ responseBuffer.append(ollamaResponseModel.getError());
+ } else {
+ boolean finished = parseResponseAndAddToBuffer(line, responseBuffer);
+ if (finished) {
+ break;
+ }
+ }
+ }
+ }
+ if (statusCode != 200) {
+ LOG.error("Status code " + statusCode);
+ throw new OllamaBaseException(responseBuffer.toString());
+ } else {
+ long endTime = System.currentTimeMillis();
+ OllamaResult ollamaResult =
+ new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode);
+ if (isVerbose()) LOG.info("Model response: " + ollamaResult);
+ return ollamaResult;
+ }
}
}
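Note: the generate caller receives a near-verbatim copy of the callSync() that was deleted from the base class; the only differences are that it reads host and verbosity through the new getHost()/isVerbose() getters and still returns a plain OllamaResult with the measured response time.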

View File

@@ -17,7 +17,7 @@ public class OllamaResult {
 *
 * @return String completion/response text
 */
- private final String content;
+ private final String response;
/**
 * -- GETTER --
@@ -35,8 +35,8 @@ public class OllamaResult {
 */
private long responseTime = 0;
- public OllamaResult(String content, long responseTime, int httpStatusCode) {
- this.content = content;
+ public OllamaResult(String response, long responseTime, int httpStatusCode) {
+ this.response = response;
this.responseTime = responseTime;
this.httpStatusCode = httpStatusCode;
}
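The field rename means downstream code now reads result.getResponse() where it previously read result.getContent(); the other accessors are unchanged. A small sketch using only the constructor and getters visible in this commit (the sample values are illustrative):

```java
import io.github.ollama4j.models.response.OllamaResult;

public class ResultRenameDemo {
    static void printResult(OllamaResult result) {
        // Old accessor (pre-commit): result.getContent()
        // New accessor: result.getResponse()
        System.out.println(result.getResponse());
        System.out.println("HTTP " + result.getHttpStatusCode() + " in " + result.getResponseTime() + " ms");
    }

    public static void main(String[] args) {
        printResult(new OllamaResult("Hello from the model", 42L, 200));
    }
}
```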

View File

@@ -2,12 +2,9 @@ package io.github.ollama4j.integrationtests;
import io.github.ollama4j.OllamaAPI;
import io.github.ollama4j.exceptions.OllamaBaseException;
+ import io.github.ollama4j.models.chat.*;
import io.github.ollama4j.models.response.ModelDetail;
- import io.github.ollama4j.models.chat.OllamaChatRequest;
import io.github.ollama4j.models.response.OllamaResult;
- import io.github.ollama4j.models.chat.OllamaChatMessageRole;
- import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
- import io.github.ollama4j.models.chat.OllamaChatResult;
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestBuilder;
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
import io.github.ollama4j.tools.ToolFunction;
@@ -47,6 +44,7 @@ class TestRealAPIs {
config = new Config();
ollamaAPI = new OllamaAPI(config.getOllamaURL());
ollamaAPI.setRequestTimeoutSeconds(config.getRequestTimeoutSeconds());
+ ollamaAPI.setVerbose(true);
}
@Test
@@ -196,7 +194,9 @@ class TestRealAPIs {
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
assertNotNull(chatResult);
- assertFalse(chatResult.getResponse().isBlank());
+ assertNotNull(chatResult.getResponse());
+ assertNotNull(chatResult.getResponse().getMessage());
+ assertFalse(chatResult.getResponse().getMessage().getContent().isBlank());
assertEquals(4, chatResult.getChatHistory().size());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
@@ -217,8 +217,10 @@ class TestRealAPIs {
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
assertNotNull(chatResult);
- assertFalse(chatResult.getResponse().isBlank());
- assertTrue(chatResult.getResponse().startsWith("NI"));
+ assertNotNull(chatResult.getResponse());
+ assertNotNull(chatResult.getResponse().getMessage());
+ assertFalse(chatResult.getResponse().getMessage().getContent().isBlank());
+ assertTrue(chatResult.getResponse().getMessage().getContent().startsWith("NI"));
assertEquals(3, chatResult.getChatHistory().size());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
@@ -267,9 +269,17 @@ class TestRealAPIs {
.build();
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
+ System.err.println("Response: " + chatResult);
assertNotNull(chatResult);
- assertFalse(chatResult.getResponse().isBlank());
+ assertNotNull(chatResult.getResponse());
+ assertNotNull(chatResult.getResponse().getMessage());
+ assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(),chatResult.getResponse().getMessage().getRole().getRoleName());
+ List<OllamaChatToolCalls> toolCalls = chatResult.getResponse().getMessage().getToolCalls();
+ assertEquals(1, toolCalls.size());
+ assertEquals("get-employee-details",toolCalls.get(0).getFunction().getName());
+ assertEquals(1, toolCalls.get(0).getFunction().getArguments().size());
+ String employeeName = toolCalls.get(0).getFunction().getArguments().get("employee-name");
+ assertNotNull(employeeName);
+ assertEquals("Rahul Kumar",employeeName);
assertEquals(2, chatResult.getChatHistory().size());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
@@ -295,7 +305,10 @@ class TestRealAPIs {
sb.append(substring);
});
assertNotNull(chatResult);
- assertEquals(sb.toString().trim(), chatResult.getResponse().trim());
+ assertNotNull(chatResult.getResponse());
+ assertNotNull(chatResult.getResponse().getMessage());
+ assertNotNull(chatResult.getResponse().getMessage().getContent());
+ assertEquals(sb.toString().trim(), chatResult.getResponse().getMessage().getContent().trim());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
}
@@ -309,7 +322,7 @@ class TestRealAPIs {
OllamaChatRequestBuilder builder =
OllamaChatRequestBuilder.getInstance(config.getImageModel());
OllamaChatRequest requestModel =
- builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
+ builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",Collections.emptyList(),
List.of(getImageFileFromClasspath("dog-on-a-boat.jpg"))).build();
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
@@ -338,7 +351,7 @@ class TestRealAPIs {
testEndpointReachability();
try {
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getImageModel());
- OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
+ OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",Collections.emptyList(),
"https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg")
.build();

View File

@@ -4,6 +4,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrowsExactly;
import java.io.File;
+ import java.util.Collections;
import java.util.List;
import io.github.ollama4j.models.chat.OllamaChatRequest;
@@ -42,7 +43,7 @@ public class TestChatRequestSerialization extends AbstractSerializationTest<Olla
@Test
public void testRequestWithMessageAndImage() {
- OllamaChatRequest req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt",
+ OllamaChatRequest req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt", Collections.emptyList(),
List.of(new File("src/test/resources/dog-on-a-boat.jpg"))).build();
String jsonRequest = serialize(req);
assertEqualsAfterUnmarshalling(deserialize(jsonRequest, OllamaChatRequest.class), req);

View File

@@ -1,4 +1,4 @@
ollama.url=http://localhost:11434
- ollama.model=qwen:0.5b
- ollama.model.image=llava
+ ollama.model=llama3.2:1b
+ ollama.model.image=llava:latest
ollama.request-timeout-seconds=120