Enhance OllamaAPI with 'think' parameter and response handling

- Added 'think' parameter to the generate methods in OllamaAPI to enable step-by-step reasoning for model responses.
- Updated OllamaGenerateRequest and OllamaGenerateResponseModel to include 'thinking' field.
- Modified response handling in OllamaGenerateStreamObserver to incorporate 'thinking' responses.
- Updated integration tests to validate the new functionality, including tests for generating responses with thinking enabled.
- Refactored related methods and classes for consistency and clarity.
This commit is contained in:
amithkoujalgi 2025-08-28 10:03:07 +05:30
parent 3efd7712be
commit 14642e9856
No known key found for this signature in database
GPG Key ID: E29A37746AF94B70
13 changed files with 342 additions and 399 deletions

View File

@ -749,6 +749,8 @@ public class OllamaAPI {
* *
* @param model the ollama model to ask the question to * @param model the ollama model to ask the question to
* @param prompt the prompt/question text * @param prompt the prompt/question text
 * @param raw if true, no formatting will be applied to the prompt. You may choose to use the raw parameter if you are specifying a full templated prompt in your request to the API
 * @param think if true, the model will "think" step-by-step before generating the final response
* @param options the Options object - <a * @param options the Options object - <a
* href= * href=
* "https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values">More * "https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values">More
@ -761,14 +763,42 @@ public class OllamaAPI {
* @throws IOException if an I/O error occurs during the HTTP request * @throws IOException if an I/O error occurs during the HTTP request
* @throws InterruptedException if the operation is interrupted * @throws InterruptedException if the operation is interrupted
*/ */
public OllamaResult generate(String model, String prompt, boolean raw, Options options, public OllamaResult generate(String model, String prompt, boolean raw, boolean think, Options options,
OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt); OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt);
ollamaRequestModel.setRaw(raw); ollamaRequestModel.setRaw(raw);
ollamaRequestModel.setThink(think);
ollamaRequestModel.setOptions(options.getOptionsMap()); ollamaRequestModel.setOptions(options.getOptionsMap());
return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler); return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler);
} }
/**
* Generates response using the specified AI model and prompt (in blocking
* mode).
* <p>
* Uses {@link #generate(String, String, boolean, boolean, Options, OllamaStreamHandler)}
*
* @param model The name or identifier of the AI model to use for generating
* the response.
* @param prompt The input text or prompt to provide to the AI model.
* @param raw In some cases, you may wish to bypass the templating system
* and provide a full prompt. In this case, you can use the raw
* parameter to disable templating. Also note that raw mode will
* not return a context.
* @param think If set to true, the model will "think" step-by-step before
* generating the final response.
* @param options Additional options or configurations to use when generating
* the response.
* @return {@link OllamaResult}
* @throws OllamaBaseException if the response indicates an error status
* @throws IOException if an I/O error occurs during the HTTP request
* @throws InterruptedException if the operation is interrupted
*/
public OllamaResult generate(String model, String prompt, boolean raw, boolean think, Options options)
throws OllamaBaseException, IOException, InterruptedException {
return generate(model, prompt, raw, think, options, null);
}
/** /**
* Generates structured output from the specified AI model and prompt. * Generates structured output from the specified AI model and prompt.
* *
@ -809,7 +839,7 @@ public class OllamaAPI {
if (statusCode == 200) { if (statusCode == 200) {
OllamaStructuredResult structuredResult = Utils.getObjectMapper().readValue(responseBody, OllamaStructuredResult structuredResult = Utils.getObjectMapper().readValue(responseBody,
OllamaStructuredResult.class); OllamaStructuredResult.class);
OllamaResult ollamaResult = new OllamaResult(structuredResult.getResponse(), OllamaResult ollamaResult = new OllamaResult(structuredResult.getResponse(), structuredResult.getThinking(),
structuredResult.getResponseTime(), statusCode); structuredResult.getResponseTime(), statusCode);
return ollamaResult; return ollamaResult;
} else { } else {
@ -817,31 +847,6 @@ public class OllamaAPI {
} }
} }
/**
* Generates response using the specified AI model and prompt (in blocking
* mode).
* <p>
* Uses {@link #generate(String, String, boolean, Options, OllamaStreamHandler)}
*
* @param model The name or identifier of the AI model to use for generating
* the response.
* @param prompt The input text or prompt to provide to the AI model.
* @param raw In some cases, you may wish to bypass the templating system
* and provide a full prompt. In this case, you can use the raw
* parameter to disable templating. Also note that raw mode will
* not return a context.
* @param options Additional options or configurations to use when generating
* the response.
* @return {@link OllamaResult}
* @throws OllamaBaseException if the response indicates an error status
* @throws IOException if an I/O error occurs during the HTTP request
* @throws InterruptedException if the operation is interrupted
*/
public OllamaResult generate(String model, String prompt, boolean raw, Options options)
throws OllamaBaseException, IOException, InterruptedException {
return generate(model, prompt, raw, options, null);
}
/** /**
* Generates response using the specified AI model and prompt (in blocking * Generates response using the specified AI model and prompt (in blocking
* mode), and then invokes a set of tools * mode), and then invokes a set of tools
@ -850,6 +855,8 @@ public class OllamaAPI {
* @param model The name or identifier of the AI model to use for generating * @param model The name or identifier of the AI model to use for generating
* the response. * the response.
* @param prompt The input text or prompt to provide to the AI model. * @param prompt The input text or prompt to provide to the AI model.
* @param think If set to true, the model will "think" step-by-step before
* generating the final response.
* @param options Additional options or configurations to use when generating * @param options Additional options or configurations to use when generating
* the response. * the response.
* @return {@link OllamaToolsResult} An OllamaToolsResult object containing the * @return {@link OllamaToolsResult} An OllamaToolsResult object containing the
@ -859,7 +866,7 @@ public class OllamaAPI {
* @throws IOException if an I/O error occurs during the HTTP request * @throws IOException if an I/O error occurs during the HTTP request
* @throws InterruptedException if the operation is interrupted * @throws InterruptedException if the operation is interrupted
*/ */
public OllamaToolsResult generateWithTools(String model, String prompt, Options options) public OllamaToolsResult generateWithTools(String model, String prompt, boolean think, Options options)
throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException { throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
boolean raw = true; boolean raw = true;
OllamaToolsResult toolResult = new OllamaToolsResult(); OllamaToolsResult toolResult = new OllamaToolsResult();
@ -874,7 +881,7 @@ public class OllamaAPI {
prompt = promptBuilder.build(); prompt = promptBuilder.build();
} }
OllamaResult result = generate(model, prompt, raw, options, null); OllamaResult result = generate(model, prompt, raw, think, options, null);
toolResult.setModelResult(result); toolResult.setModelResult(result);
String toolsResponse = result.getResponse(); String toolsResponse = result.getResponse();
@ -1023,19 +1030,25 @@ public class OllamaAPI {
/** /**
* Synchronously generates a response using a list of image byte arrays. * Synchronously generates a response using a list of image byte arrays.
* <p> * <p>
* This method encodes the provided byte arrays into Base64 and sends them to the Ollama server. * This method encodes the provided byte arrays into Base64 and sends them to
* the Ollama server.
* *
* @param model the Ollama model to use for generating the response * @param model the Ollama model to use for generating the response
* @param prompt the prompt or question text to send to the model * @param prompt the prompt or question text to send to the model
* @param images the list of image data as byte arrays * @param images the list of image data as byte arrays
* @param options the Options object - <a href="https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values">More details on the options</a> * @param options the Options object - <a href=
* @param streamHandler optional callback that will be invoked with each streamed response; if null, streaming is disabled * "https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values">More
* @return OllamaResult containing the response text and the time taken for the response * details on the options</a>
* @param streamHandler optional callback that will be invoked with each
* streamed response; if null, streaming is disabled
* @return OllamaResult containing the response text and the time taken for the
* response
* @throws OllamaBaseException if the response indicates an error status * @throws OllamaBaseException if the response indicates an error status
* @throws IOException if an I/O error occurs during the HTTP request * @throws IOException if an I/O error occurs during the HTTP request
* @throws InterruptedException if the operation is interrupted * @throws InterruptedException if the operation is interrupted
*/ */
public OllamaResult generateWithImages(String model, String prompt, List<byte[]> images, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { public OllamaResult generateWithImages(String model, String prompt, List<byte[]> images, Options options,
OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
List<String> encodedImages = new ArrayList<>(); List<String> encodedImages = new ArrayList<>();
for (byte[] image : images) { for (byte[] image : images) {
encodedImages.add(encodeByteArrayToBase64(image)); encodedImages.add(encodeByteArrayToBase64(image));
@ -1046,15 +1059,18 @@ public class OllamaAPI {
} }
/** /**
* Convenience method to call the Ollama API using image byte arrays without streaming responses. * Convenience method to call the Ollama API using image byte arrays without
* streaming responses.
* <p> * <p>
* Uses {@link #generateWithImages(String, String, List, Options, OllamaStreamHandler)} * Uses
* {@link #generateWithImages(String, String, List, Options, OllamaStreamHandler)}
* *
* @throws OllamaBaseException if the response indicates an error status * @throws OllamaBaseException if the response indicates an error status
* @throws IOException if an I/O error occurs during the HTTP request * @throws IOException if an I/O error occurs during the HTTP request
* @throws InterruptedException if the operation is interrupted * @throws InterruptedException if the operation is interrupted
*/ */
public OllamaResult generateWithImages(String model, String prompt, List<byte[]> images, Options options) throws OllamaBaseException, IOException, InterruptedException { public OllamaResult generateWithImages(String model, String prompt, List<byte[]> images, Options options)
throws OllamaBaseException, IOException, InterruptedException {
return generateWithImages(model, prompt, images, options, null); return generateWithImages(model, prompt, images, options, null);
} }
@ -1069,10 +1085,12 @@ public class OllamaAPI {
* history including the newly acquired assistant response. * history including the newly acquired assistant response.
* @throws OllamaBaseException any response code than 200 has been returned * @throws OllamaBaseException any response code than 200 has been returned
* @throws IOException in case the responseStream can not be read * @throws IOException in case the responseStream can not be read
* @throws InterruptedException in case the server is not reachable or network * @throws InterruptedException in case the server is not reachable or
* network
* issues happen * issues happen
* @throws OllamaBaseException if the response indicates an error status * @throws OllamaBaseException if the response indicates an error status
* @throws IOException if an I/O error occurs during the HTTP request * @throws IOException if an I/O error occurs during the HTTP
* request
* @throws InterruptedException if the operation is interrupted * @throws InterruptedException if the operation is interrupted
* @throws ToolInvocationException if the tool invocation fails * @throws ToolInvocationException if the tool invocation fails
*/ */
@ -1092,10 +1110,12 @@ public class OllamaAPI {
* @return {@link OllamaChatResult} * @return {@link OllamaChatResult}
* @throws OllamaBaseException any response code than 200 has been returned * @throws OllamaBaseException any response code than 200 has been returned
* @throws IOException in case the responseStream can not be read * @throws IOException in case the responseStream can not be read
* @throws InterruptedException in case the server is not reachable or network * @throws InterruptedException in case the server is not reachable or
* network
* issues happen * issues happen
* @throws OllamaBaseException if the response indicates an error status * @throws OllamaBaseException if the response indicates an error status
* @throws IOException if an I/O error occurs during the HTTP request * @throws IOException if an I/O error occurs during the HTTP
* request
* @throws InterruptedException if the operation is interrupted * @throws InterruptedException if the operation is interrupted
* @throws ToolInvocationException if the tool invocation fails * @throws ToolInvocationException if the tool invocation fails
*/ */
@ -1117,10 +1137,12 @@ public class OllamaAPI {
* @return {@link OllamaChatResult} * @return {@link OllamaChatResult}
* @throws OllamaBaseException any response code than 200 has been returned * @throws OllamaBaseException any response code than 200 has been returned
* @throws IOException in case the responseStream can not be read * @throws IOException in case the responseStream can not be read
* @throws InterruptedException in case the server is not reachable or network * @throws InterruptedException in case the server is not reachable or
* network
* issues happen * issues happen
* @throws OllamaBaseException if the response indicates an error status * @throws OllamaBaseException if the response indicates an error status
* @throws IOException if an I/O error occurs during the HTTP request * @throws IOException if an I/O error occurs during the HTTP
* request
* @throws InterruptedException if the operation is interrupted * @throws InterruptedException if the operation is interrupted
* @throws ToolInvocationException if the tool invocation fails * @throws ToolInvocationException if the tool invocation fails
*/ */

View File

@ -39,11 +39,17 @@ public class OllamaChatRequestBuilder {
request = new OllamaChatRequest(request.getModel(), new ArrayList<>()); request = new OllamaChatRequest(request.getModel(), new ArrayList<>());
} }
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content){ public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content) {
return withMessage(role,content, Collections.emptyList()); return withMessage(role, content, Collections.emptyList());
} }
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List<OllamaChatToolCalls> toolCalls,List<File> images) { public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List<OllamaChatToolCalls> toolCalls) {
List<OllamaChatMessage> messages = this.request.getMessages();
messages.add(new OllamaChatMessage(role, content, toolCalls, null));
return this;
}
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List<OllamaChatToolCalls> toolCalls, List<File> images) {
List<OllamaChatMessage> messages = this.request.getMessages(); List<OllamaChatMessage> messages = this.request.getMessages();
List<byte[]> binaryImages = images.stream().map(file -> { List<byte[]> binaryImages = images.stream().map(file -> {
@ -55,11 +61,11 @@ public class OllamaChatRequestBuilder {
} }
}).collect(Collectors.toList()); }).collect(Collectors.toList());
messages.add(new OllamaChatMessage(role, content,toolCalls, binaryImages)); messages.add(new OllamaChatMessage(role, content, toolCalls, binaryImages));
return this; return this;
} }
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content,List<OllamaChatToolCalls> toolCalls, String... imageUrls) { public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List<OllamaChatToolCalls> toolCalls, String... imageUrls) {
List<OllamaChatMessage> messages = this.request.getMessages(); List<OllamaChatMessage> messages = this.request.getMessages();
List<byte[]> binaryImages = null; List<byte[]> binaryImages = null;
if (imageUrls.length > 0) { if (imageUrls.length > 0) {
@ -75,7 +81,7 @@ public class OllamaChatRequestBuilder {
} }
} }
messages.add(new OllamaChatMessage(role, content,toolCalls, binaryImages)); messages.add(new OllamaChatMessage(role, content, toolCalls, binaryImages));
return this; return this;
} }

View File

@ -19,6 +19,7 @@ public class OllamaGenerateRequest extends OllamaCommonRequest implements Ollama
private String system; private String system;
private String context; private String context;
private boolean raw; private boolean raw;
private boolean think;
public OllamaGenerateRequest() { public OllamaGenerateRequest() {
} }

View File

@ -12,6 +12,7 @@ public class OllamaGenerateResponseModel {
private String model; private String model;
private @JsonProperty("created_at") String createdAt; private @JsonProperty("created_at") String createdAt;
private String response; private String response;
private String thinking;
private boolean done; private boolean done;
private List<Integer> context; private List<Integer> context;
private @JsonProperty("total_duration") Long totalDuration; private @JsonProperty("total_duration") Long totalDuration;

View File

@ -21,9 +21,17 @@ public class OllamaGenerateStreamObserver {
} }
protected void handleCurrentResponsePart(OllamaGenerateResponseModel currentResponsePart) { protected void handleCurrentResponsePart(OllamaGenerateResponseModel currentResponsePart) {
message = message + currentResponsePart.getResponse(); String response = currentResponsePart.getResponse();
String thinking = currentResponsePart.getThinking();
boolean hasResponse = response != null && !response.trim().isEmpty();
boolean hasThinking = thinking != null && !thinking.trim().isEmpty();
if (!hasResponse && hasThinking) {
message = message + thinking;
} else if (hasResponse) {
message = message + response;
}
streamHandler.accept(message); streamHandler.accept(message);
} }
} }

View File

@ -51,13 +51,13 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
* @return TRUE, if ollama-Response has 'done' state * @return TRUE, if ollama-Response has 'done' state
*/ */
@Override @Override
protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) { protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer, StringBuilder thinkingBuffer) {
try { try {
OllamaChatResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class); OllamaChatResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class);
// it seems that under heavy load ollama responds with an empty chat message part in the streamed response // it seems that under heavy load ollama responds with an empty chat message part in the streamed response
// thus, we null check the message and hope that the next streamed response has some message content again // thus, we null check the message and hope that the next streamed response has some message content again
OllamaChatMessage message = ollamaResponseModel.getMessage(); OllamaChatMessage message = ollamaResponseModel.getMessage();
if(message != null) { if (message != null) {
responseBuffer.append(message.getContent()); responseBuffer.append(message.getContent());
if (tokenHandler != null) { if (tokenHandler != null) {
tokenHandler.accept(ollamaResponseModel); tokenHandler.accept(ollamaResponseModel);
@ -92,6 +92,7 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
int statusCode = response.statusCode(); int statusCode = response.statusCode();
InputStream responseBodyStream = response.body(); InputStream responseBodyStream = response.body();
StringBuilder responseBuffer = new StringBuilder(); StringBuilder responseBuffer = new StringBuilder();
StringBuilder thinkingBuffer = new StringBuilder();
OllamaChatResponseModel ollamaChatResponseModel = null; OllamaChatResponseModel ollamaChatResponseModel = null;
List<OllamaChatToolCalls> wantedToolsForStream = null; List<OllamaChatToolCalls> wantedToolsForStream = null;
try (BufferedReader reader = try (BufferedReader reader =
@ -115,10 +116,15 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line,
OllamaErrorResponse.class); OllamaErrorResponse.class);
responseBuffer.append(ollamaResponseModel.getError()); responseBuffer.append(ollamaResponseModel.getError());
} else if (statusCode == 500) {
LOG.warn("Status code: 500 (Internal Server Error)");
OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line,
OllamaErrorResponse.class);
responseBuffer.append(ollamaResponseModel.getError());
} else { } else {
boolean finished = parseResponseAndAddToBuffer(line, responseBuffer); boolean finished = parseResponseAndAddToBuffer(line, responseBuffer, thinkingBuffer);
ollamaChatResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class); ollamaChatResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class);
if(body.stream && ollamaChatResponseModel.getMessage().getToolCalls() != null){ if (body.stream && ollamaChatResponseModel.getMessage().getToolCalls() != null) {
wantedToolsForStream = ollamaChatResponseModel.getMessage().getToolCalls(); wantedToolsForStream = ollamaChatResponseModel.getMessage().getToolCalls();
} }
if (finished && body.stream) { if (finished && body.stream) {
@ -132,11 +138,11 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
LOG.error("Status code " + statusCode); LOG.error("Status code " + statusCode);
throw new OllamaBaseException(responseBuffer.toString()); throw new OllamaBaseException(responseBuffer.toString());
} else { } else {
if(wantedToolsForStream != null) { if (wantedToolsForStream != null) {
ollamaChatResponseModel.getMessage().setToolCalls(wantedToolsForStream); ollamaChatResponseModel.getMessage().setToolCalls(wantedToolsForStream);
} }
OllamaChatResult ollamaResult = OllamaChatResult ollamaResult =
new OllamaChatResult(ollamaChatResponseModel,body.getMessages()); new OllamaChatResult(ollamaChatResponseModel, body.getMessages());
if (isVerbose()) LOG.info("Model response: " + ollamaResult); if (isVerbose()) LOG.info("Model response: " + ollamaResult);
return ollamaResult; return ollamaResult;
} }

View File

@ -32,7 +32,7 @@ public abstract class OllamaEndpointCaller {
protected abstract String getEndpointSuffix(); protected abstract String getEndpointSuffix();
protected abstract boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer); protected abstract boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer, StringBuilder thinkingBuffer);
/** /**

View File

@ -38,10 +38,15 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
} }
@Override @Override
protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) { protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer, StringBuilder thinkingBuffer) {
try { try {
OllamaGenerateResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class); OllamaGenerateResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class);
if (ollamaResponseModel.getResponse() != null) {
responseBuffer.append(ollamaResponseModel.getResponse()); responseBuffer.append(ollamaResponseModel.getResponse());
}
if (ollamaResponseModel.getThinking() != null) {
thinkingBuffer.append(ollamaResponseModel.getThinking());
}
if (streamObserver != null) { if (streamObserver != null) {
streamObserver.notify(ollamaResponseModel); streamObserver.notify(ollamaResponseModel);
} }
@ -84,6 +89,7 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
int statusCode = response.statusCode(); int statusCode = response.statusCode();
InputStream responseBodyStream = response.body(); InputStream responseBodyStream = response.body();
StringBuilder responseBuffer = new StringBuilder(); StringBuilder responseBuffer = new StringBuilder();
StringBuilder thinkingBuffer = new StringBuilder();
try (BufferedReader reader = try (BufferedReader reader =
new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
String line; String line;
@ -105,7 +111,7 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
OllamaErrorResponse.class); OllamaErrorResponse.class);
responseBuffer.append(ollamaResponseModel.getError()); responseBuffer.append(ollamaResponseModel.getError());
} else { } else {
boolean finished = parseResponseAndAddToBuffer(line, responseBuffer); boolean finished = parseResponseAndAddToBuffer(line, responseBuffer, thinkingBuffer);
if (finished) { if (finished) {
break; break;
} }
@ -119,7 +125,7 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
} else { } else {
long endTime = System.currentTimeMillis(); long endTime = System.currentTimeMillis();
OllamaResult ollamaResult = OllamaResult ollamaResult =
new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode); new OllamaResult(responseBuffer.toString().trim(), thinkingBuffer.toString().trim(), endTime - startTime, statusCode);
if (isVerbose()) LOG.info("Model response: " + ollamaResult); if (isVerbose()) LOG.info("Model response: " + ollamaResult);
return ollamaResult; return ollamaResult;
} }

View File

@ -12,7 +12,9 @@ import static io.github.ollama4j.utils.Utils.getObjectMapper;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
/** The type Ollama result. */ /**
* The type Ollama result.
*/
@Getter @Getter
@SuppressWarnings("unused") @SuppressWarnings("unused")
@Data @Data
@ -25,6 +27,7 @@ public class OllamaResult {
* @return String completion/response text * @return String completion/response text
*/ */
private final String response; private final String response;
private final String thinking;
/** /**
* -- GETTER -- * -- GETTER --
@ -42,8 +45,9 @@ public class OllamaResult {
*/ */
private long responseTime = 0; private long responseTime = 0;
public OllamaResult(String response, long responseTime, int httpStatusCode) { public OllamaResult(String response, String thinking, long responseTime, int httpStatusCode) {
this.response = response; this.response = response;
this.thinking = thinking;
this.responseTime = responseTime; this.responseTime = responseTime;
this.httpStatusCode = httpStatusCode; this.httpStatusCode = httpStatusCode;
} }
@ -53,6 +57,7 @@ public class OllamaResult {
try { try {
Map<String, Object> responseMap = new HashMap<>(); Map<String, Object> responseMap = new HashMap<>();
responseMap.put("response", this.response); responseMap.put("response", this.response);
responseMap.put("thinking", this.thinking);
responseMap.put("httpStatusCode", this.httpStatusCode); responseMap.put("httpStatusCode", this.httpStatusCode);
responseMap.put("responseTime", this.responseTime); responseMap.put("responseTime", this.responseTime);
return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(responseMap); return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(responseMap);

View File

@ -21,6 +21,7 @@ import lombok.NoArgsConstructor;
@JsonIgnoreProperties(ignoreUnknown = true) @JsonIgnoreProperties(ignoreUnknown = true)
public class OllamaStructuredResult { public class OllamaStructuredResult {
private String response; private String response;
private String thinking;
private int httpStatusCode; private int httpStatusCode;

View File

@ -52,18 +52,19 @@ public class OllamaAPIIntegrationTest {
private static final String CHAT_MODEL_SYSTEM_PROMPT = "llama3.2:1b"; private static final String CHAT_MODEL_SYSTEM_PROMPT = "llama3.2:1b";
private static final String CHAT_MODEL_LLAMA3 = "llama3"; private static final String CHAT_MODEL_LLAMA3 = "llama3";
private static final String IMAGE_MODEL_LLAVA = "llava"; private static final String IMAGE_MODEL_LLAVA = "llava";
private static final String THINKING_MODEL_GPT_OSS = "gpt-oss:20b";
@BeforeAll @BeforeAll
public static void setUp() { public static void setUp() {
try { try {
boolean useExternalOllamaHost = Boolean.parseBoolean(System.getenv("USE_EXTERNAL_OLLAMA_HOST")); boolean useExternalOllamaHost = Boolean.parseBoolean(System.getenv("USE_EXTERNAL_OLLAMA_HOST"));
String ollamaHost = System.getenv("OLLAMA_HOST"); String ollamaHost = System.getenv("OLLAMA_HOST");
if (useExternalOllamaHost) { if (useExternalOllamaHost) {
LOG.info("Using external Ollama host..."); LOG.info("Using external Ollama host...");
api = new OllamaAPI(ollamaHost); api = new OllamaAPI(ollamaHost);
} else { } else {
throw new RuntimeException( throw new RuntimeException("USE_EXTERNAL_OLLAMA_HOST is not set so, we will be using Testcontainers Ollama host for the tests now. If you would like to use an external host, please set the env var to USE_EXTERNAL_OLLAMA_HOST=true and set the env var OLLAMA_HOST=http://localhost:11435 or a different host/port.");
"USE_EXTERNAL_OLLAMA_HOST is not set so, we will be using Testcontainers Ollama host for the tests now. If you would like to use an external host, please set the env var to USE_EXTERNAL_OLLAMA_HOST=true and set the env var OLLAMA_HOST=http://localhost:11435 or a different host/port.");
} }
} catch (Exception e) { } catch (Exception e) {
String ollamaVersion = "0.6.1"; String ollamaVersion = "0.6.1";
@ -102,8 +103,7 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(2) @Order(2)
public void testListModelsAPI() public void testListModelsAPI() throws URISyntaxException, IOException, OllamaBaseException, InterruptedException {
throws URISyntaxException, IOException, OllamaBaseException, InterruptedException {
api.pullModel(EMBEDDING_MODEL_MINILM); api.pullModel(EMBEDDING_MODEL_MINILM);
// Fetch the list of models // Fetch the list of models
List<Model> models = api.listModels(); List<Model> models = api.listModels();
@ -115,8 +115,7 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(2) @Order(2)
void testListModelsFromLibrary() void testListModelsFromLibrary() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
List<LibraryModel> models = api.listModelsFromLibrary(); List<LibraryModel> models = api.listModelsFromLibrary();
assertNotNull(models); assertNotNull(models);
assertFalse(models.isEmpty()); assertFalse(models.isEmpty());
@ -124,8 +123,7 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(3) @Order(3)
public void testPullModelAPI() public void testPullModelAPI() throws URISyntaxException, IOException, OllamaBaseException, InterruptedException {
throws URISyntaxException, IOException, OllamaBaseException, InterruptedException {
api.pullModel(EMBEDDING_MODEL_MINILM); api.pullModel(EMBEDDING_MODEL_MINILM);
List<Model> models = api.listModels(); List<Model> models = api.listModels();
assertNotNull(models, "Models should not be null"); assertNotNull(models, "Models should not be null");
@ -145,16 +143,14 @@ public class OllamaAPIIntegrationTest {
@Order(5) @Order(5)
public void testEmbeddings() throws Exception { public void testEmbeddings() throws Exception {
api.pullModel(EMBEDDING_MODEL_MINILM); api.pullModel(EMBEDDING_MODEL_MINILM);
OllamaEmbedResponseModel embeddings = api.embed(EMBEDDING_MODEL_MINILM, OllamaEmbedResponseModel embeddings = api.embed(EMBEDDING_MODEL_MINILM, Arrays.asList("Why is the sky blue?", "Why is the grass green?"));
Arrays.asList("Why is the sky blue?", "Why is the grass green?"));
assertNotNull(embeddings, "Embeddings should not be null"); assertNotNull(embeddings, "Embeddings should not be null");
assertFalse(embeddings.getEmbeddings().isEmpty(), "Embeddings should not be empty"); assertFalse(embeddings.getEmbeddings().isEmpty(), "Embeddings should not be empty");
} }
@Test @Test
@Order(6) @Order(6)
void testAskModelWithStructuredOutput() void testAskModelWithStructuredOutput() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
api.pullModel(CHAT_MODEL_LLAMA3); api.pullModel(CHAT_MODEL_LLAMA3);
int timeHour = 6; int timeHour = 6;
@ -186,10 +182,8 @@ public class OllamaAPIIntegrationTest {
assertNotNull(result.getResponse()); assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty()); assertFalse(result.getResponse().isEmpty());
assertEquals(timeHour, assertEquals(timeHour, result.getStructuredResponse().get("timeHour"));
result.getStructuredResponse().get("timeHour")); assertEquals(isNightTime, result.getStructuredResponse().get("isNightTime"));
assertEquals(isNightTime,
result.getStructuredResponse().get("isNightTime"));
TimeOfDay timeOfDay = result.as(TimeOfDay.class); TimeOfDay timeOfDay = result.as(TimeOfDay.class);
@ -199,12 +193,11 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(6) @Order(6)
void testAskModelWithDefaultOptions() void testAskModelWithDefaultOptions() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
api.pullModel(CHAT_MODEL_QWEN_SMALL); api.pullModel(CHAT_MODEL_QWEN_SMALL);
OllamaResult result = api.generate(CHAT_MODEL_QWEN_SMALL, boolean raw = false;
"What is the capital of France? And what's France's connection with Mona Lisa?", false, boolean thinking = false;
new OptionsBuilder().build()); OllamaResult result = api.generate(CHAT_MODEL_QWEN_SMALL, "What is the capital of France? And what's France's connection with Mona Lisa?", raw, thinking, new OptionsBuilder().build());
assertNotNull(result); assertNotNull(result);
assertNotNull(result.getResponse()); assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty()); assertFalse(result.getResponse().isEmpty());
@ -212,13 +205,12 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(7) @Order(7)
void testAskModelWithDefaultOptionsStreamed() void testAskModelWithDefaultOptionsStreamed() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
api.pullModel(CHAT_MODEL_QWEN_SMALL); api.pullModel(CHAT_MODEL_QWEN_SMALL);
boolean raw = false;
boolean thinking = false;
StringBuffer sb = new StringBuffer(); StringBuffer sb = new StringBuffer();
OllamaResult result = api.generate(CHAT_MODEL_QWEN_SMALL, OllamaResult result = api.generate(CHAT_MODEL_QWEN_SMALL, "What is the capital of France? And what's France's connection with Mona Lisa?", raw, thinking, new OptionsBuilder().build(), (s) -> {
"What is the capital of France? And what's France's connection with Mona Lisa?", false,
new OptionsBuilder().build(), (s) -> {
LOG.info(s); LOG.info(s);
String substring = s.substring(sb.toString().length(), s.length()); String substring = s.substring(sb.toString().length(), s.length());
LOG.info(substring); LOG.info(substring);
@ -233,17 +225,12 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(8) @Order(8)
void testAskModelWithOptions() void testAskModelWithOptions() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException {
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException {
api.pullModel(CHAT_MODEL_INSTRUCT); api.pullModel(CHAT_MODEL_INSTRUCT);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_INSTRUCT); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_INSTRUCT);
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, "You are a helpful assistant who can generate random person's first and last names in the format [First name, Last name].").build();
"You are a helpful assistant who can generate random person's first and last names in the format [First name, Last name].") requestModel = builder.withMessages(requestModel.getMessages()).withMessage(OllamaChatMessageRole.USER, "Give me a cool name").withOptions(new OptionsBuilder().setTemperature(0.5f).build()).build();
.build();
requestModel = builder.withMessages(requestModel.getMessages())
.withMessage(OllamaChatMessageRole.USER, "Give me a cool name")
.withOptions(new OptionsBuilder().setTemperature(0.5f).build()).build();
OllamaChatResult chatResult = api.chat(requestModel); OllamaChatResult chatResult = api.chat(requestModel);
assertNotNull(chatResult); assertNotNull(chatResult);
@ -253,14 +240,10 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(9) @Order(9)
void testChatWithSystemPrompt() void testChatWithSystemPrompt() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException {
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException { api.pullModel(CHAT_MODEL_LLAMA3);
api.pullModel(CHAT_MODEL_SYSTEM_PROMPT); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_LLAMA3);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT); OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, "You are a silent bot that only says 'Shush'. Do not say anything else under any circumstances!").withMessage(OllamaChatMessageRole.USER, "What's something that's brown and sticky?").withOptions(new OptionsBuilder().setTemperature(0.8f).build()).build();
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM,
"You are a silent bot that only says 'Shush'. Do not say anything else under any circumstances!")
.withMessage(OllamaChatMessageRole.USER, "What's something that's brown and sticky?")
.withOptions(new OptionsBuilder().setTemperature(0.8f).build()).build();
OllamaChatResult chatResult = api.chat(requestModel); OllamaChatResult chatResult = api.chat(requestModel);
assertNotNull(chatResult); assertNotNull(chatResult);
@ -278,56 +261,40 @@ public class OllamaAPIIntegrationTest {
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_LLAMA3); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_LLAMA3);
// Create the initial user question // Create the initial user question
OllamaChatRequest requestModel = builder OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What is 1+1? Answer only in numbers.").build();
.withMessage(OllamaChatMessageRole.USER, "What is 1+1? Answer only in numbers.")
.build();
// Start conversation with model // Start conversation with model
OllamaChatResult chatResult = api.chat(requestModel); OllamaChatResult chatResult = api.chat(requestModel);
assertTrue(chatResult.getChatHistory().stream().anyMatch(chat -> chat.getContent().contains("2")), assertTrue(chatResult.getChatHistory().stream().anyMatch(chat -> chat.getContent().contains("2")), "Expected chat history to contain '2'");
"Expected chat history to contain '2'");
// Create the next user question: second largest city // Create the next user question: second largest city
requestModel = builder.withMessages(chatResult.getChatHistory()) requestModel = builder.withMessages(chatResult.getChatHistory()).withMessage(OllamaChatMessageRole.USER, "And what is its squared value?").build();
.withMessage(OllamaChatMessageRole.USER, "And what is its squared value?").build();
// Continue conversation with model // Continue conversation with model
chatResult = api.chat(requestModel); chatResult = api.chat(requestModel);
assertTrue(chatResult.getChatHistory().stream().anyMatch(chat -> chat.getContent().contains("4")), assertTrue(chatResult.getChatHistory().stream().anyMatch(chat -> chat.getContent().contains("4")), "Expected chat history to contain '4'");
"Expected chat history to contain '4'");
// Create the next user question: the third question // Create the next user question: the third question
requestModel = builder.withMessages(chatResult.getChatHistory()) requestModel = builder.withMessages(chatResult.getChatHistory()).withMessage(OllamaChatMessageRole.USER, "What is the largest value between 2, 4 and 6?").build();
.withMessage(OllamaChatMessageRole.USER,
"What is the largest value between 2, 4 and 6?")
.build();
// Continue conversation with the model for the third question // Continue conversation with the model for the third question
chatResult = api.chat(requestModel); chatResult = api.chat(requestModel);
// verify the result // verify the result
assertNotNull(chatResult, "Chat result should not be null"); assertNotNull(chatResult, "Chat result should not be null");
assertTrue(chatResult.getChatHistory().size() > 2, assertTrue(chatResult.getChatHistory().size() > 2, "Chat history should contain more than two messages");
"Chat history should contain more than two messages"); assertTrue(chatResult.getChatHistory().get(chatResult.getChatHistory().size() - 1).getContent().contains("6"), "Response should contain '6'");
assertTrue(chatResult.getChatHistory().get(chatResult.getChatHistory().size() - 1).getContent()
.contains("6"),
"Response should contain '6'");
} }
@Test @Test
@Order(10) @Order(10)
void testChatWithImageFromURL() void testChatWithImageFromURL() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException, ToolInvocationException {
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException, ToolInvocationException {
api.pullModel(IMAGE_MODEL_LLAVA); api.pullModel(IMAGE_MODEL_LLAVA);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(IMAGE_MODEL_LLAVA); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(IMAGE_MODEL_LLAVA);
OllamaChatRequest requestModel = builder OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?", Collections.emptyList(), "https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg").build();
.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
Collections.emptyList(),
"https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg")
.build();
api.registerAnnotatedTools(new OllamaAPIIntegrationTest()); api.registerAnnotatedTools(new OllamaAPIIntegrationTest());
OllamaChatResult chatResult = api.chat(requestModel); OllamaChatResult chatResult = api.chat(requestModel);
@ -336,22 +303,17 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(10) @Order(10)
void testChatWithImageFromFileWithHistoryRecognition() void testChatWithImageFromFileWithHistoryRecognition() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException {
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException {
api.pullModel(IMAGE_MODEL_LLAVA); api.pullModel(IMAGE_MODEL_LLAVA);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(IMAGE_MODEL_LLAVA); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(IMAGE_MODEL_LLAVA);
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?", Collections.emptyList(), List.of(getImageFileFromClasspath("emoji-smile.jpeg"))).build();
"What's in the picture?",
Collections.emptyList(), List.of(getImageFileFromClasspath("emoji-smile.jpeg")))
.build();
OllamaChatResult chatResult = api.chat(requestModel); OllamaChatResult chatResult = api.chat(requestModel);
assertNotNull(chatResult); assertNotNull(chatResult);
assertNotNull(chatResult.getResponseModel()); assertNotNull(chatResult.getResponseModel());
builder.reset(); builder.reset();
requestModel = builder.withMessages(chatResult.getChatHistory()) requestModel = builder.withMessages(chatResult.getChatHistory()).withMessage(OllamaChatMessageRole.USER, "What's the color?").build();
.withMessage(OllamaChatMessageRole.USER, "What's the color?").build();
chatResult = api.chat(requestModel); chatResult = api.chat(requestModel);
assertNotNull(chatResult); assertNotNull(chatResult);
@ -360,71 +322,24 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(11) @Order(11)
void testChatWithExplicitToolDefinition() void testChatWithExplicitToolDefinition() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException {
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException { api.pullModel(CHAT_MODEL_QWEN_SMALL);
api.pullModel(CHAT_MODEL_SYSTEM_PROMPT); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_QWEN_SMALL);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT);
final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder() final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder().functionName("get-employee-details").functionDescription("Get employee details from the database").toolPrompt(Tools.PromptFuncDefinition.builder().type("function").function(Tools.PromptFuncDefinition.PromptFuncSpec.builder().name("get-employee-details").description("Get employee details from the database").parameters(Tools.PromptFuncDefinition.Parameters.builder().type("object").properties(new Tools.PropsBuilder().withProperty("employee-name", Tools.PromptFuncDefinition.Property.builder().type("string").description("The name of the employee, e.g. John Doe").required(true).build()).withProperty("employee-address", Tools.PromptFuncDefinition.Property.builder().type("string").description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India").required(true).build()).withProperty("employee-phone", Tools.PromptFuncDefinition.Property.builder().type("string").description("The phone number of the employee. Always return a random value. e.g. 9911002233").required(true).build()).build()).required(List.of("employee-name")).build()).build()).build()).toolFunction(arguments -> {
.functionName("get-employee-details")
.functionDescription("Get employee details from the database")
.toolPrompt(Tools.PromptFuncDefinition.builder().type("function")
.function(Tools.PromptFuncDefinition.PromptFuncSpec.builder()
.name("get-employee-details")
.description("Get employee details from the database")
.parameters(Tools.PromptFuncDefinition.Parameters
.builder().type("object")
.properties(new Tools.PropsBuilder()
.withProperty("employee-name",
Tools.PromptFuncDefinition.Property
.builder()
.type("string")
.description("The name of the employee, e.g. John Doe")
.required(true)
.build())
.withProperty("employee-address",
Tools.PromptFuncDefinition.Property
.builder()
.type("string")
.description(
"The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India")
.required(true)
.build())
.withProperty("employee-phone",
Tools.PromptFuncDefinition.Property
.builder()
.type("string")
.description(
"The phone number of the employee. Always return a random value. e.g. 9911002233")
.required(true)
.build())
.build())
.required(List.of("employee-name"))
.build())
.build())
.build())
.toolFunction(arguments -> {
// perform DB operations here // perform DB operations here
return String.format( return String.format("Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", UUID.randomUUID(), arguments.get("employee-name"), arguments.get("employee-address"), arguments.get("employee-phone"));
"Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}",
UUID.randomUUID(), arguments.get("employee-name"),
arguments.get("employee-address"),
arguments.get("employee-phone"));
}).build(); }).build();
api.registerTool(databaseQueryToolSpecification); api.registerTool(databaseQueryToolSpecification);
OllamaChatRequest requestModel = builder OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "Give me the ID of the employee named 'Rahul Kumar'?").build();
.withMessage(OllamaChatMessageRole.USER,
"Give me the ID of the employee named 'Rahul Kumar'?")
.build();
OllamaChatResult chatResult = api.chat(requestModel); OllamaChatResult chatResult = api.chat(requestModel);
assertNotNull(chatResult); assertNotNull(chatResult);
assertNotNull(chatResult.getResponseModel()); assertNotNull(chatResult.getResponseModel());
assertNotNull(chatResult.getResponseModel().getMessage()); assertNotNull(chatResult.getResponseModel().getMessage());
assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(), assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(), chatResult.getResponseModel().getMessage().getRole().getRoleName());
chatResult.getResponseModel().getMessage().getRole().getRoleName());
List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls(); List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls();
assertEquals(1, toolCalls.size()); assertEquals(1, toolCalls.size());
OllamaToolCallsFunction function = toolCalls.get(0).getFunction(); OllamaToolCallsFunction function = toolCalls.get(0).getFunction();
@ -440,22 +355,19 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(12) @Order(12)
void testChatWithAnnotatedToolsAndSingleParam() void testChatWithAnnotatedToolsAndSingleParam() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException, ToolInvocationException {
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException, ToolInvocationException { api.pullModel(CHAT_MODEL_QWEN_SMALL);
api.pullModel(CHAT_MODEL_SYSTEM_PROMPT); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_QWEN_SMALL);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT);
api.registerAnnotatedTools(); api.registerAnnotatedTools();
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "Compute the most important constant in the world using 5 digits").build();
"Compute the most important constant in the world using 5 digits").build();
OllamaChatResult chatResult = api.chat(requestModel); OllamaChatResult chatResult = api.chat(requestModel);
assertNotNull(chatResult); assertNotNull(chatResult);
assertNotNull(chatResult.getResponseModel()); assertNotNull(chatResult.getResponseModel());
assertNotNull(chatResult.getResponseModel().getMessage()); assertNotNull(chatResult.getResponseModel().getMessage());
assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(), assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(), chatResult.getResponseModel().getMessage().getRole().getRoleName());
chatResult.getResponseModel().getMessage().getRole().getRoleName());
List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls(); List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls();
assertEquals(1, toolCalls.size()); assertEquals(1, toolCalls.size());
OllamaToolCallsFunction function = toolCalls.get(0).getFunction(); OllamaToolCallsFunction function = toolCalls.get(0).getFunction();
@ -471,25 +383,19 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(13) @Order(13)
void testChatWithAnnotatedToolsAndMultipleParams() void testChatWithAnnotatedToolsAndMultipleParams() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException {
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException { api.pullModel(CHAT_MODEL_QWEN_SMALL);
api.pullModel(CHAT_MODEL_SYSTEM_PROMPT); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_QWEN_SMALL);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT);
api.registerAnnotatedTools(new AnnotatedTool()); api.registerAnnotatedTools(new AnnotatedTool());
OllamaChatRequest requestModel = builder OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "Greet Pedro with a lot of hearts and respond to me, " + "and state how many emojis have been in your greeting").build();
.withMessage(OllamaChatMessageRole.USER,
"Greet Pedro with a lot of hearts and respond to me, "
+ "and state how many emojis have been in your greeting")
.build();
OllamaChatResult chatResult = api.chat(requestModel); OllamaChatResult chatResult = api.chat(requestModel);
assertNotNull(chatResult); assertNotNull(chatResult);
assertNotNull(chatResult.getResponseModel()); assertNotNull(chatResult.getResponseModel());
assertNotNull(chatResult.getResponseModel().getMessage()); assertNotNull(chatResult.getResponseModel().getMessage());
assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(), assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(), chatResult.getResponseModel().getMessage().getRole().getRoleName());
chatResult.getResponseModel().getMessage().getRole().getRoleName());
List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls(); List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls();
assertEquals(1, toolCalls.size()); assertEquals(1, toolCalls.size());
OllamaToolCallsFunction function = toolCalls.get(0).getFunction(); OllamaToolCallsFunction function = toolCalls.get(0).getFunction();
@ -508,66 +414,20 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(14) @Order(14)
void testChatWithToolsAndStream() void testChatWithToolsAndStream() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException {
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException { api.pullModel(CHAT_MODEL_QWEN_SMALL);
api.pullModel(CHAT_MODEL_SYSTEM_PROMPT); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_QWEN_SMALL);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT); final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder().functionName("get-employee-details").functionDescription("Get employee details from the database").toolPrompt(Tools.PromptFuncDefinition.builder().type("function").function(Tools.PromptFuncDefinition.PromptFuncSpec.builder().name("get-employee-details").description("Get employee details from the database").parameters(Tools.PromptFuncDefinition.Parameters.builder().type("object").properties(new Tools.PropsBuilder().withProperty("employee-name", Tools.PromptFuncDefinition.Property.builder().type("string").description("The name of the employee, e.g. John Doe").required(true).build()).withProperty("employee-address", Tools.PromptFuncDefinition.Property.builder().type("string").description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India").required(true).build()).withProperty("employee-phone", Tools.PromptFuncDefinition.Property.builder().type("string").description("The phone number of the employee. Always return a random value. e.g. 9911002233").required(true).build()).build()).required(List.of("employee-name")).build()).build()).build()).toolFunction(new ToolFunction() {
final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder()
.functionName("get-employee-details")
.functionDescription("Get employee details from the database")
.toolPrompt(Tools.PromptFuncDefinition.builder().type("function")
.function(Tools.PromptFuncDefinition.PromptFuncSpec.builder()
.name("get-employee-details")
.description("Get employee details from the database")
.parameters(Tools.PromptFuncDefinition.Parameters
.builder().type("object")
.properties(new Tools.PropsBuilder()
.withProperty("employee-name",
Tools.PromptFuncDefinition.Property
.builder()
.type("string")
.description("The name of the employee, e.g. John Doe")
.required(true)
.build())
.withProperty("employee-address",
Tools.PromptFuncDefinition.Property
.builder()
.type("string")
.description(
"The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India")
.required(true)
.build())
.withProperty("employee-phone",
Tools.PromptFuncDefinition.Property
.builder()
.type("string")
.description(
"The phone number of the employee. Always return a random value. e.g. 9911002233")
.required(true)
.build())
.build())
.required(List.of("employee-name"))
.build())
.build())
.build())
.toolFunction(new ToolFunction() {
@Override @Override
public Object apply(Map<String, Object> arguments) { public Object apply(Map<String, Object> arguments) {
// perform DB operations here // perform DB operations here
return String.format( return String.format("Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", UUID.randomUUID(), arguments.get("employee-name"), arguments.get("employee-address"), arguments.get("employee-phone"));
"Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}",
UUID.randomUUID(), arguments.get("employee-name"),
arguments.get("employee-address"),
arguments.get("employee-phone"));
} }
}).build(); }).build();
api.registerTool(databaseQueryToolSpecification); api.registerTool(databaseQueryToolSpecification);
OllamaChatRequest requestModel = builder OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "Give me the ID of the employee named 'Rahul Kumar'?").build();
.withMessage(OllamaChatMessageRole.USER,
"Give me the ID of the employee named 'Rahul Kumar'?")
.build();
StringBuffer sb = new StringBuffer(); StringBuffer sb = new StringBuffer();
@ -587,11 +447,9 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(15) @Order(15)
void testChatWithStream() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException { void testChatWithStream() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException {
api.pullModel(CHAT_MODEL_SYSTEM_PROMPT); api.pullModel(CHAT_MODEL_QWEN_SMALL);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_QWEN_SMALL);
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What is the capital of France? And what's France's connection with Mona Lisa?").build();
"What is the capital of France? And what's France's connection with Mona Lisa?")
.build();
StringBuffer sb = new StringBuffer(); StringBuffer sb = new StringBuffer();
@ -610,13 +468,10 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(17) @Order(17)
void testAskModelWithOptionsAndImageURLs() void testAskModelWithOptionsAndImageURLs() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
api.pullModel(IMAGE_MODEL_LLAVA); api.pullModel(IMAGE_MODEL_LLAVA);
OllamaResult result = api.generateWithImageURLs(IMAGE_MODEL_LLAVA, "What is in this image?", OllamaResult result = api.generateWithImageURLs(IMAGE_MODEL_LLAVA, "What is in this image?", List.of("https://i.pinimg.com/736x/f9/4e/cb/f94ecba040696a3a20b484d2e15159ec.jpg"), new OptionsBuilder().build());
List.of("https://upload.wikimedia.org/wikipedia/commons/thumb/a/aa/Noto_Emoji_v2.034_1f642.svg/360px-Noto_Emoji_v2.034_1f642.svg.png"),
new OptionsBuilder().build());
assertNotNull(result); assertNotNull(result);
assertNotNull(result.getResponse()); assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty()); assertFalse(result.getResponse().isEmpty());
@ -624,14 +479,11 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(18) @Order(18)
void testAskModelWithOptionsAndImageFiles() void testAskModelWithOptionsAndImageFiles() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
api.pullModel(IMAGE_MODEL_LLAVA); api.pullModel(IMAGE_MODEL_LLAVA);
File imageFile = getImageFileFromClasspath("emoji-smile.jpeg"); File imageFile = getImageFileFromClasspath("emoji-smile.jpeg");
try { try {
OllamaResult result = api.generateWithImageFiles(IMAGE_MODEL_LLAVA, "What is in this image?", OllamaResult result = api.generateWithImageFiles(IMAGE_MODEL_LLAVA, "What is in this image?", List.of(imageFile), new OptionsBuilder().build());
List.of(imageFile),
new OptionsBuilder().build());
assertNotNull(result); assertNotNull(result);
assertNotNull(result.getResponse()); assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty()); assertFalse(result.getResponse().isEmpty());
@ -642,17 +494,14 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(20) @Order(20)
void testAskModelWithOptionsAndImageFilesStreamed() void testAskModelWithOptionsAndImageFilesStreamed() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
api.pullModel(IMAGE_MODEL_LLAVA); api.pullModel(IMAGE_MODEL_LLAVA);
File imageFile = getImageFileFromClasspath("emoji-smile.jpeg"); File imageFile = getImageFileFromClasspath("emoji-smile.jpeg");
StringBuffer sb = new StringBuffer(); StringBuffer sb = new StringBuffer();
OllamaResult result = api.generateWithImageFiles(IMAGE_MODEL_LLAVA, "What is in this image?", OllamaResult result = api.generateWithImageFiles(IMAGE_MODEL_LLAVA, "What is in this image?", List.of(imageFile), new OptionsBuilder().build(), (s) -> {
List.of(imageFile),
new OptionsBuilder().build(), (s) -> {
LOG.info(s); LOG.info(s);
String substring = s.substring(sb.toString().length(), s.length()); String substring = s.substring(sb.toString().length(), s.length());
LOG.info(substring); LOG.info(substring);
@ -664,6 +513,44 @@ public class OllamaAPIIntegrationTest {
assertEquals(sb.toString().trim(), result.getResponse().trim()); assertEquals(sb.toString().trim(), result.getResponse().trim());
} }
@Test
@Order(20)
void testGenerateWithThinking() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
    // Make sure the reasoning-capable model is available before asking it anything.
    api.pullModel(THINKING_MODEL_GPT_OSS);

    final boolean rawMode = false;        // apply normal prompt templating
    final boolean enableThinking = true;  // request step-by-step reasoning output

    OllamaResult result = api.generate(THINKING_MODEL_GPT_OSS, "Who are you?", rawMode, enableThinking,
            new OptionsBuilder().build(), null);

    // With thinking enabled, both the final answer and the intermediate
    // "thinking" text must be present and non-empty.
    assertNotNull(result);
    assertNotNull(result.getResponse());
    assertFalse(result.getResponse().isEmpty());
    assertNotNull(result.getThinking());
    assertFalse(result.getThinking().isEmpty());
}
@Test
@Order(20)
void testGenerateWithThinkingAndStreamHandler() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
    // Make sure the reasoning-capable model is available before asking it anything.
    api.pullModel(THINKING_MODEL_GPT_OSS);

    boolean raw = false;      // apply normal prompt templating
    boolean thinking = true;  // request step-by-step reasoning output

    // Accumulates the cumulative streamed text (thinking followed by response).
    StringBuffer sb = new StringBuffer();
    OllamaResult result = api.generate(THINKING_MODEL_GPT_OSS, "Who are you?", raw, thinking,
            new OptionsBuilder().build(), (s) -> {
                LOG.info(s);
                // The handler receives the cumulative text so far; append only
                // the newly streamed suffix. sb.length() avoids building a
                // throwaway String via sb.toString().
                String substring = s.substring(sb.length());
                sb.append(substring);
            });

    // With thinking enabled, both the final answer and the intermediate
    // "thinking" text must be present and non-empty.
    assertNotNull(result);
    assertNotNull(result.getResponse());
    assertFalse(result.getResponse().isEmpty());
    assertNotNull(result.getThinking());
    assertFalse(result.getThinking().isEmpty());

    // Concatenate first, then trim once: trimming the two parts separately
    // would drop legitimate whitespace at the thinking/response boundary that
    // the streamed accumulator still contains, causing spurious mismatches.
    assertEquals(sb.toString().trim(), (result.getThinking() + result.getResponse()).trim());
}
private File getImageFileFromClasspath(String fileName) { private File getImageFileFromClasspath(String fileName) {
ClassLoader classLoader = getClass().getClassLoader(); ClassLoader classLoader = getClass().getClassLoader();
return new File(Objects.requireNonNull(classLoader.getResource(fileName)).getFile()); return new File(Objects.requireNonNull(classLoader.getResource(fileName)).getFile());

View File

@ -68,7 +68,7 @@ public class WithAuth {
LOG.info( LOG.info(
"The Ollama service is now accessible via the Nginx proxy with bearer-auth authentication mode.\n" + "The Ollama service is now accessible via the Nginx proxy with bearer-auth authentication mode.\n" +
"→ Ollama URL: {}\n" + "→ Ollama URL: {}\n" +
"→ Proxy URL: {}}", "→ Proxy URL: {}",
ollamaUrl, nginxUrl ollamaUrl, nginxUrl
); );
LOG.info("OllamaAPI initialized with bearer auth token: {}", BEARER_AUTH_TOKEN); LOG.info("OllamaAPI initialized with bearer auth token: {}", BEARER_AUTH_TOKEN);

View File

@ -138,10 +138,10 @@ class TestMockedAPIs {
String prompt = "some prompt text"; String prompt = "some prompt text";
OptionsBuilder optionsBuilder = new OptionsBuilder(); OptionsBuilder optionsBuilder = new OptionsBuilder();
try { try {
when(ollamaAPI.generate(model, prompt, false, optionsBuilder.build())) when(ollamaAPI.generate(model, prompt, false, false, optionsBuilder.build()))
.thenReturn(new OllamaResult("", 0, 200)); .thenReturn(new OllamaResult("", "", 0, 200));
ollamaAPI.generate(model, prompt, false, optionsBuilder.build()); ollamaAPI.generate(model, prompt, false, false, optionsBuilder.build());
verify(ollamaAPI, times(1)).generate(model, prompt, false, optionsBuilder.build()); verify(ollamaAPI, times(1)).generate(model, prompt, false, false, optionsBuilder.build());
} catch (IOException | OllamaBaseException | InterruptedException e) { } catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
@ -155,7 +155,7 @@ class TestMockedAPIs {
try { try {
when(ollamaAPI.generateWithImageFiles( when(ollamaAPI.generateWithImageFiles(
model, prompt, Collections.emptyList(), new OptionsBuilder().build())) model, prompt, Collections.emptyList(), new OptionsBuilder().build()))
.thenReturn(new OllamaResult("", 0, 200)); .thenReturn(new OllamaResult("","", 0, 200));
ollamaAPI.generateWithImageFiles( ollamaAPI.generateWithImageFiles(
model, prompt, Collections.emptyList(), new OptionsBuilder().build()); model, prompt, Collections.emptyList(), new OptionsBuilder().build());
verify(ollamaAPI, times(1)) verify(ollamaAPI, times(1))
@ -174,7 +174,7 @@ class TestMockedAPIs {
try { try {
when(ollamaAPI.generateWithImageURLs( when(ollamaAPI.generateWithImageURLs(
model, prompt, Collections.emptyList(), new OptionsBuilder().build())) model, prompt, Collections.emptyList(), new OptionsBuilder().build()))
.thenReturn(new OllamaResult("", 0, 200)); .thenReturn(new OllamaResult("","", 0, 200));
ollamaAPI.generateWithImageURLs( ollamaAPI.generateWithImageURLs(
model, prompt, Collections.emptyList(), new OptionsBuilder().build()); model, prompt, Collections.emptyList(), new OptionsBuilder().build());
verify(ollamaAPI, times(1)) verify(ollamaAPI, times(1))