diff --git a/src/main/java/io/github/ollama4j/OllamaAPI.java b/src/main/java/io/github/ollama4j/OllamaAPI.java index 5689faa..be91603 100644 --- a/src/main/java/io/github/ollama4j/OllamaAPI.java +++ b/src/main/java/io/github/ollama4j/OllamaAPI.java @@ -749,6 +749,8 @@ public class OllamaAPI { * * @param model the ollama model to ask the question to * @param prompt the prompt/question text + * @param raw if true no formatting will be applied to the prompt. You may choose to use the raw parameter if you are specifying a full templated prompt in your request to the API + * @param think if true the model will "think" step-by-step before generating the final response * @param options the Options object - More @@ -761,14 +763,42 @@ public class OllamaAPI { * @throws IOException if an I/O error occurs during the HTTP request * @throws InterruptedException if the operation is interrupted */ - public OllamaResult generate(String model, String prompt, boolean raw, Options options, + public OllamaResult generate(String model, String prompt, boolean raw, boolean think, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt); ollamaRequestModel.setRaw(raw); + ollamaRequestModel.setThink(think); ollamaRequestModel.setOptions(options.getOptionsMap()); return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler); } + /** + * Generates response using the specified AI model and prompt (in blocking + * mode). + *

+ * Uses {@link #generate(String, String, boolean, boolean, Options, OllamaStreamHandler)} + * + * @param model The name or identifier of the AI model to use for generating + * the response. + * @param prompt The input text or prompt to provide to the AI model. + * @param raw In some cases, you may wish to bypass the templating system + * and provide a full prompt. In this case, you can use the raw + * parameter to disable templating. Also note that raw mode will + * not return a context. + * @param think If set to true, the model will "think" step-by-step before + * generating the final response. + * @param options Additional options or configurations to use when generating + * the response. + * @return {@link OllamaResult} + * @throws OllamaBaseException if the response indicates an error status + * @throws IOException if an I/O error occurs during the HTTP request + * @throws InterruptedException if the operation is interrupted + */ + public OllamaResult generate(String model, String prompt, boolean raw, boolean think, Options options) + throws OllamaBaseException, IOException, InterruptedException { + return generate(model, prompt, raw, think, options, null); + } + /** * Generates structured output from the specified AI model and prompt. * @@ -809,7 +839,7 @@ public class OllamaAPI { if (statusCode == 200) { OllamaStructuredResult structuredResult = Utils.getObjectMapper().readValue(responseBody, OllamaStructuredResult.class); - OllamaResult ollamaResult = new OllamaResult(structuredResult.getResponse(), + OllamaResult ollamaResult = new OllamaResult(structuredResult.getResponse(), structuredResult.getThinking(), structuredResult.getResponseTime(), statusCode); return ollamaResult; } else { @@ -817,31 +847,6 @@ public class OllamaAPI { } } - /** - * Generates response using the specified AI model and prompt (in blocking - * mode). - *

- * Uses {@link #generate(String, String, boolean, Options, OllamaStreamHandler)} - * - * @param model The name or identifier of the AI model to use for generating - * the response. - * @param prompt The input text or prompt to provide to the AI model. - * @param raw In some cases, you may wish to bypass the templating system - * and provide a full prompt. In this case, you can use the raw - * parameter to disable templating. Also note that raw mode will - * not return a context. - * @param options Additional options or configurations to use when generating - * the response. - * @return {@link OllamaResult} - * @throws OllamaBaseException if the response indicates an error status - * @throws IOException if an I/O error occurs during the HTTP request - * @throws InterruptedException if the operation is interrupted - */ - public OllamaResult generate(String model, String prompt, boolean raw, Options options) - throws OllamaBaseException, IOException, InterruptedException { - return generate(model, prompt, raw, options, null); - } - /** * Generates response using the specified AI model and prompt (in blocking * mode), and then invokes a set of tools @@ -850,6 +855,8 @@ public class OllamaAPI { * @param model The name or identifier of the AI model to use for generating * the response. * @param prompt The input text or prompt to provide to the AI model. + * @param think If set to true, the model will "think" step-by-step before + * generating the final response. * @param options Additional options or configurations to use when generating * the response. * @return {@link OllamaToolsResult} An OllamaToolsResult object containing the @@ -859,7 +866,7 @@ public class OllamaAPI { * @throws IOException if an I/O error occurs during the HTTP request * @throws InterruptedException if the operation is interrupted */ - public OllamaToolsResult generateWithTools(String model, String prompt, Options options) + public OllamaToolsResult generateWithTools(String model, String prompt, boolean think, Options options) throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException { boolean raw = true; OllamaToolsResult toolResult = new OllamaToolsResult(); @@ -874,7 +881,7 @@ public class OllamaAPI { prompt = promptBuilder.build(); } - OllamaResult result = generate(model, prompt, raw, options, null); + OllamaResult result = generate(model, prompt, raw, think, options, null); toolResult.setModelResult(result); String toolsResponse = result.getResponse(); @@ -1023,19 +1030,25 @@ public class OllamaAPI { /** * Synchronously generates a response using a list of image byte arrays. *

- * This method encodes the provided byte arrays into Base64 and sends them to the Ollama server. + * This method encodes the provided byte arrays into Base64 and sends them to + * the Ollama server. * * @param model the Ollama model to use for generating the response * @param prompt the prompt or question text to send to the model * @param images the list of image data as byte arrays - * @param options the Options object - More details on the options - * @param streamHandler optional callback that will be invoked with each streamed response; if null, streaming is disabled - * @return OllamaResult containing the response text and the time taken for the response + * @param options the Options object - More + * details on the options + * @param streamHandler optional callback that will be invoked with each + * streamed response; if null, streaming is disabled + * @return OllamaResult containing the response text and the time taken for the + * response * @throws OllamaBaseException if the response indicates an error status * @throws IOException if an I/O error occurs during the HTTP request * @throws InterruptedException if the operation is interrupted */ - public OllamaResult generateWithImages(String model, String prompt, List images, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { + public OllamaResult generateWithImages(String model, String prompt, List images, Options options, + OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { List encodedImages = new ArrayList<>(); for (byte[] image : images) { encodedImages.add(encodeByteArrayToBase64(image)); @@ -1046,15 +1059,18 @@ public class OllamaAPI { } /** - * Convenience method to call the Ollama API using image byte arrays without streaming responses. + * Convenience method to call the Ollama API using image byte arrays without + * streaming responses. *

- * Uses {@link #generateWithImages(String, String, List, Options, OllamaStreamHandler)} + * Uses + * {@link #generateWithImages(String, String, List, Options, OllamaStreamHandler)} * * @throws OllamaBaseException if the response indicates an error status * @throws IOException if an I/O error occurs during the HTTP request * @throws InterruptedException if the operation is interrupted */ - public OllamaResult generateWithImages(String model, String prompt, List images, Options options) throws OllamaBaseException, IOException, InterruptedException { + public OllamaResult generateWithImages(String model, String prompt, List images, Options options) + throws OllamaBaseException, IOException, InterruptedException { return generateWithImages(model, prompt, images, options, null); } @@ -1069,10 +1085,12 @@ public class OllamaAPI { * history including the newly acquired assistant response. * @throws OllamaBaseException any response code than 200 has been returned * @throws IOException in case the responseStream can not be read - * @throws InterruptedException in case the server is not reachable or network + * @throws InterruptedException in case the server is not reachable or + * network * issues happen * @throws OllamaBaseException if the response indicates an error status - * @throws IOException if an I/O error occurs during the HTTP request + * @throws IOException if an I/O error occurs during the HTTP + * request * @throws InterruptedException if the operation is interrupted * @throws ToolInvocationException if the tool invocation fails */ @@ -1092,10 +1110,12 @@ public class OllamaAPI { * @return {@link OllamaChatResult} * @throws OllamaBaseException any response code than 200 has been returned * @throws IOException in case the responseStream can not be read - * @throws InterruptedException in case the server is not reachable or network + * @throws InterruptedException in case the server is not reachable or + * network * issues happen * @throws OllamaBaseException if the response indicates an error status - * @throws IOException if an I/O error occurs during the HTTP request + * @throws IOException if an I/O error occurs during the HTTP + * request * @throws InterruptedException if the operation is interrupted * @throws ToolInvocationException if the tool invocation fails */ @@ -1117,10 +1137,12 @@ public class OllamaAPI { * @return {@link OllamaChatResult} * @throws OllamaBaseException any response code than 200 has been returned * @throws IOException in case the responseStream can not be read - * @throws InterruptedException in case the server is not reachable or network + * @throws InterruptedException in case the server is not reachable or + * network * issues happen * @throws OllamaBaseException if the response indicates an error status - * @throws IOException if an I/O error occurs during the HTTP request + * @throws IOException if an I/O error occurs during the HTTP + * request * @throws InterruptedException if the operation is interrupted * @throws ToolInvocationException if the tool invocation fails */ diff --git a/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequestBuilder.java b/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequestBuilder.java index 9094546..47d6eb5 100644 --- a/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequestBuilder.java +++ b/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequestBuilder.java @@ -39,11 +39,17 @@ public class OllamaChatRequestBuilder { request = new OllamaChatRequest(request.getModel(), new ArrayList<>()); } - public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content){ - return withMessage(role,content, Collections.emptyList()); + public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content) { + return withMessage(role, content, Collections.emptyList()); } - public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List toolCalls,List images) { + public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List toolCalls) { + List messages = this.request.getMessages(); + messages.add(new OllamaChatMessage(role, content, toolCalls, null)); + return this; + } + + public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List toolCalls, List images) { List messages = this.request.getMessages(); List binaryImages = images.stream().map(file -> { @@ -55,11 +61,11 @@ public class OllamaChatRequestBuilder { } }).collect(Collectors.toList()); - messages.add(new OllamaChatMessage(role, content,toolCalls, binaryImages)); + messages.add(new OllamaChatMessage(role, content, toolCalls, binaryImages)); return this; } - public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content,List toolCalls, String... imageUrls) { + public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List toolCalls, String... imageUrls) { List messages = this.request.getMessages(); List binaryImages = null; if (imageUrls.length > 0) { @@ -75,7 +81,7 @@ public class OllamaChatRequestBuilder { } } - messages.add(new OllamaChatMessage(role, content,toolCalls, binaryImages)); + messages.add(new OllamaChatMessage(role, content, toolCalls, binaryImages)); return this; } diff --git a/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateRequest.java b/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateRequest.java index de767dc..bb37a4c 100644 --- a/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateRequest.java +++ b/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateRequest.java @@ -19,6 +19,7 @@ public class OllamaGenerateRequest extends OllamaCommonRequest implements Ollama private String system; private String context; private boolean raw; + private boolean think; public OllamaGenerateRequest() { } diff --git a/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateResponseModel.java b/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateResponseModel.java index 9fb975e..0d4c749 100644 --- a/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateResponseModel.java +++ b/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateResponseModel.java @@ -12,6 +12,7 @@ public class OllamaGenerateResponseModel { private String model; private @JsonProperty("created_at") String createdAt; private String response; + private String thinking; private boolean done; private List context; private @JsonProperty("total_duration") Long totalDuration; diff --git a/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateStreamObserver.java b/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateStreamObserver.java index bc47fa0..a13a0a0 100644 --- a/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateStreamObserver.java +++ b/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateStreamObserver.java @@ -21,9 +21,17 @@ public class OllamaGenerateStreamObserver { } protected void handleCurrentResponsePart(OllamaGenerateResponseModel currentResponsePart) { - message = message + currentResponsePart.getResponse(); + String response = currentResponsePart.getResponse(); + String thinking = currentResponsePart.getThinking(); + + boolean hasResponse = response != null && !response.trim().isEmpty(); + boolean hasThinking = thinking != null && !thinking.trim().isEmpty(); + + if (!hasResponse && hasThinking) { + message = message + thinking; + } else if (hasResponse) { + message = message + response; + } streamHandler.accept(message); } - - } diff --git a/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java b/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java index 09a3870..94db829 100644 --- a/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java +++ b/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java @@ -46,18 +46,18 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller { * in case the JSON Object cannot be parsed to a {@link OllamaChatResponseModel}. Thus, the ResponseModel should * never be null. * - * @param line streamed line of ollama stream response + * @param line streamed line of ollama stream response * @param responseBuffer Stringbuffer to add latest response message part to * @return TRUE, if ollama-Response has 'done' state */ @Override - protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) { + protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer, StringBuilder thinkingBuffer) { try { OllamaChatResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class); // it seems that under heavy load ollama responds with an empty chat message part in the streamed response // thus, we null check the message and hope that the next streamed response has some message content again OllamaChatMessage message = ollamaResponseModel.getMessage(); - if(message != null) { + if (message != null) { responseBuffer.append(message.getContent()); if (tokenHandler != null) { tokenHandler.accept(ollamaResponseModel); @@ -92,6 +92,7 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller { int statusCode = response.statusCode(); InputStream responseBodyStream = response.body(); StringBuilder responseBuffer = new StringBuilder(); + StringBuilder thinkingBuffer = new StringBuilder(); OllamaChatResponseModel ollamaChatResponseModel = null; List wantedToolsForStream = null; try (BufferedReader reader = @@ -115,10 +116,15 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller { OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class); responseBuffer.append(ollamaResponseModel.getError()); + } else if (statusCode == 500) { + LOG.warn("Status code: 500 (Internal Server Error)"); + OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line, + OllamaErrorResponse.class); + responseBuffer.append(ollamaResponseModel.getError()); } else { - boolean finished = parseResponseAndAddToBuffer(line, responseBuffer); - ollamaChatResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class); - if(body.stream && ollamaChatResponseModel.getMessage().getToolCalls() != null){ + boolean finished = parseResponseAndAddToBuffer(line, responseBuffer, thinkingBuffer); + ollamaChatResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class); + if (body.stream && ollamaChatResponseModel.getMessage().getToolCalls() != null) { wantedToolsForStream = ollamaChatResponseModel.getMessage().getToolCalls(); } if (finished && body.stream) { @@ -132,11 +138,11 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller { LOG.error("Status code " + statusCode); throw new OllamaBaseException(responseBuffer.toString()); } else { - if(wantedToolsForStream != null) { + if (wantedToolsForStream != null) { ollamaChatResponseModel.getMessage().setToolCalls(wantedToolsForStream); } OllamaChatResult ollamaResult = - new OllamaChatResult(ollamaChatResponseModel,body.getMessages()); + new OllamaChatResult(ollamaChatResponseModel, body.getMessages()); if (isVerbose()) LOG.info("Model response: " + ollamaResult); return ollamaResult; } diff --git a/src/main/java/io/github/ollama4j/models/request/OllamaEndpointCaller.java b/src/main/java/io/github/ollama4j/models/request/OllamaEndpointCaller.java index 1f42ef8..04d7fd9 100644 --- a/src/main/java/io/github/ollama4j/models/request/OllamaEndpointCaller.java +++ b/src/main/java/io/github/ollama4j/models/request/OllamaEndpointCaller.java @@ -32,7 +32,7 @@ public abstract class OllamaEndpointCaller { protected abstract String getEndpointSuffix(); - protected abstract boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer); + protected abstract boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer, StringBuilder thinkingBuffer); /** diff --git a/src/main/java/io/github/ollama4j/models/request/OllamaGenerateEndpointCaller.java b/src/main/java/io/github/ollama4j/models/request/OllamaGenerateEndpointCaller.java index 461ec75..5e7c1f4 100644 --- a/src/main/java/io/github/ollama4j/models/request/OllamaGenerateEndpointCaller.java +++ b/src/main/java/io/github/ollama4j/models/request/OllamaGenerateEndpointCaller.java @@ -38,10 +38,15 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller { } @Override - protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) { + protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer, StringBuilder thinkingBuffer) { try { OllamaGenerateResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class); - responseBuffer.append(ollamaResponseModel.getResponse()); + if (ollamaResponseModel.getResponse() != null) { + responseBuffer.append(ollamaResponseModel.getResponse()); + } + if (ollamaResponseModel.getThinking() != null) { + thinkingBuffer.append(ollamaResponseModel.getThinking()); + } if (streamObserver != null) { streamObserver.notify(ollamaResponseModel); } @@ -84,6 +89,7 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller { int statusCode = response.statusCode(); InputStream responseBodyStream = response.body(); StringBuilder responseBuffer = new StringBuilder(); + StringBuilder thinkingBuffer = new StringBuilder(); try (BufferedReader reader = new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { String line; @@ -105,7 +111,7 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller { OllamaErrorResponse.class); responseBuffer.append(ollamaResponseModel.getError()); } else { - boolean finished = parseResponseAndAddToBuffer(line, responseBuffer); + boolean finished = parseResponseAndAddToBuffer(line, responseBuffer, thinkingBuffer); if (finished) { break; } @@ -119,7 +125,7 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller { } else { long endTime = System.currentTimeMillis(); OllamaResult ollamaResult = - new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode); + new OllamaResult(responseBuffer.toString().trim(), thinkingBuffer.toString().trim(), endTime - startTime, statusCode); if (isVerbose()) LOG.info("Model response: " + ollamaResult); return ollamaResult; } diff --git a/src/main/java/io/github/ollama4j/models/response/OllamaResult.java b/src/main/java/io/github/ollama4j/models/response/OllamaResult.java index 4b538f9..fcf7442 100644 --- a/src/main/java/io/github/ollama4j/models/response/OllamaResult.java +++ b/src/main/java/io/github/ollama4j/models/response/OllamaResult.java @@ -12,107 +12,112 @@ import static io.github.ollama4j.utils.Utils.getObjectMapper; import java.util.HashMap; import java.util.Map; -/** The type Ollama result. */ +/** + * The type Ollama result. + */ @Getter @SuppressWarnings("unused") @Data @JsonIgnoreProperties(ignoreUnknown = true) public class OllamaResult { - /** - * -- GETTER -- - * Get the completion/response text - * - * @return String completion/response text - */ - private final String response; + /** + * -- GETTER -- + * Get the completion/response text + * + * @return String completion/response text + */ + private final String response; + private final String thinking; - /** - * -- GETTER -- - * Get the response status code. - * - * @return int - response status code - */ - private int httpStatusCode; + /** + * -- GETTER -- + * Get the response status code. + * + * @return int - response status code + */ + private int httpStatusCode; - /** - * -- GETTER -- - * Get the response time in milliseconds. - * - * @return long - response time in milliseconds - */ - private long responseTime = 0; + /** + * -- GETTER -- + * Get the response time in milliseconds. + * + * @return long - response time in milliseconds + */ + private long responseTime = 0; - public OllamaResult(String response, long responseTime, int httpStatusCode) { - this.response = response; - this.responseTime = responseTime; - this.httpStatusCode = httpStatusCode; - } - - @Override - public String toString() { - try { - Map responseMap = new HashMap<>(); - responseMap.put("response", this.response); - responseMap.put("httpStatusCode", this.httpStatusCode); - responseMap.put("responseTime", this.responseTime); - return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(responseMap); - } catch (JsonProcessingException e) { - throw new RuntimeException(e); - } - } - - /** - * Get the structured response if the response is a JSON object. - * - * @return Map - structured response - * @throws IllegalArgumentException if the response is not a valid JSON object - */ - public Map getStructuredResponse() { - String responseStr = this.getResponse(); - if (responseStr == null || responseStr.trim().isEmpty()) { - throw new IllegalArgumentException("Response is empty or null"); + public OllamaResult(String response, String thinking, long responseTime, int httpStatusCode) { + this.response = response; + this.thinking = thinking; + this.responseTime = responseTime; + this.httpStatusCode = httpStatusCode; } - try { - // Check if the response is a valid JSON - if ((!responseStr.trim().startsWith("{") && !responseStr.trim().startsWith("[")) || - (!responseStr.trim().endsWith("}") && !responseStr.trim().endsWith("]"))) { - throw new IllegalArgumentException("Response is not a valid JSON object"); - } - - Map response = getObjectMapper().readValue(responseStr, - new TypeReference>() { - }); - return response; - } catch (JsonProcessingException e) { - throw new IllegalArgumentException("Failed to parse response as JSON: " + e.getMessage(), e); - } - } - - /** - * Get the structured response mapped to a specific class type. - * - * @param The type of class to map the response to - * @param clazz The class to map the response to - * @return An instance of the specified class with the response data - * @throws IllegalArgumentException if the response is not a valid JSON or is empty - * @throws RuntimeException if there is an error mapping the response - */ - public T as(Class clazz) { - String responseStr = this.getResponse(); - if (responseStr == null || responseStr.trim().isEmpty()) { - throw new IllegalArgumentException("Response is empty or null"); + @Override + public String toString() { + try { + Map responseMap = new HashMap<>(); + responseMap.put("response", this.response); + responseMap.put("thinking", this.thinking); + responseMap.put("httpStatusCode", this.httpStatusCode); + responseMap.put("responseTime", this.responseTime); + return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(responseMap); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } } - try { - // Check if the response is a valid JSON - if ((!responseStr.trim().startsWith("{") && !responseStr.trim().startsWith("[")) || - (!responseStr.trim().endsWith("}") && !responseStr.trim().endsWith("]"))) { - throw new IllegalArgumentException("Response is not a valid JSON object"); - } - return getObjectMapper().readValue(responseStr, clazz); - } catch (JsonProcessingException e) { - throw new IllegalArgumentException("Failed to parse response as JSON: " + e.getMessage(), e); + /** + * Get the structured response if the response is a JSON object. + * + * @return Map - structured response + * @throws IllegalArgumentException if the response is not a valid JSON object + */ + public Map getStructuredResponse() { + String responseStr = this.getResponse(); + if (responseStr == null || responseStr.trim().isEmpty()) { + throw new IllegalArgumentException("Response is empty or null"); + } + + try { + // Check if the response is a valid JSON + if ((!responseStr.trim().startsWith("{") && !responseStr.trim().startsWith("[")) || + (!responseStr.trim().endsWith("}") && !responseStr.trim().endsWith("]"))) { + throw new IllegalArgumentException("Response is not a valid JSON object"); + } + + Map response = getObjectMapper().readValue(responseStr, + new TypeReference>() { + }); + return response; + } catch (JsonProcessingException e) { + throw new IllegalArgumentException("Failed to parse response as JSON: " + e.getMessage(), e); + } + } + + /** + * Get the structured response mapped to a specific class type. + * + * @param The type of class to map the response to + * @param clazz The class to map the response to + * @return An instance of the specified class with the response data + * @throws IllegalArgumentException if the response is not a valid JSON or is empty + * @throws RuntimeException if there is an error mapping the response + */ + public T as(Class clazz) { + String responseStr = this.getResponse(); + if (responseStr == null || responseStr.trim().isEmpty()) { + throw new IllegalArgumentException("Response is empty or null"); + } + + try { + // Check if the response is a valid JSON + if ((!responseStr.trim().startsWith("{") && !responseStr.trim().startsWith("[")) || + (!responseStr.trim().endsWith("}") && !responseStr.trim().endsWith("]"))) { + throw new IllegalArgumentException("Response is not a valid JSON object"); + } + return getObjectMapper().readValue(responseStr, clazz); + } catch (JsonProcessingException e) { + throw new IllegalArgumentException("Failed to parse response as JSON: " + e.getMessage(), e); + } } - } } diff --git a/src/main/java/io/github/ollama4j/models/response/OllamaStructuredResult.java b/src/main/java/io/github/ollama4j/models/response/OllamaStructuredResult.java index 9ae3e71..aaa98d3 100644 --- a/src/main/java/io/github/ollama4j/models/response/OllamaStructuredResult.java +++ b/src/main/java/io/github/ollama4j/models/response/OllamaStructuredResult.java @@ -21,6 +21,7 @@ import lombok.NoArgsConstructor; @JsonIgnoreProperties(ignoreUnknown = true) public class OllamaStructuredResult { private String response; + private String thinking; private int httpStatusCode; diff --git a/src/test/java/io/github/ollama4j/integrationtests/OllamaAPIIntegrationTest.java b/src/test/java/io/github/ollama4j/integrationtests/OllamaAPIIntegrationTest.java index abe388c..f81b45b 100644 --- a/src/test/java/io/github/ollama4j/integrationtests/OllamaAPIIntegrationTest.java +++ b/src/test/java/io/github/ollama4j/integrationtests/OllamaAPIIntegrationTest.java @@ -52,18 +52,19 @@ public class OllamaAPIIntegrationTest { private static final String CHAT_MODEL_SYSTEM_PROMPT = "llama3.2:1b"; private static final String CHAT_MODEL_LLAMA3 = "llama3"; private static final String IMAGE_MODEL_LLAVA = "llava"; + private static final String THINKING_MODEL_GPT_OSS = "gpt-oss:20b"; @BeforeAll public static void setUp() { try { boolean useExternalOllamaHost = Boolean.parseBoolean(System.getenv("USE_EXTERNAL_OLLAMA_HOST")); String ollamaHost = System.getenv("OLLAMA_HOST"); + if (useExternalOllamaHost) { LOG.info("Using external Ollama host..."); api = new OllamaAPI(ollamaHost); } else { - throw new RuntimeException( - "USE_EXTERNAL_OLLAMA_HOST is not set so, we will be using Testcontainers Ollama host for the tests now. If you would like to use an external host, please set the env var to USE_EXTERNAL_OLLAMA_HOST=true and set the env var OLLAMA_HOST=http://localhost:11435 or a different host/port."); + throw new RuntimeException("USE_EXTERNAL_OLLAMA_HOST is not set so, we will be using Testcontainers Ollama host for the tests now. If you would like to use an external host, please set the env var to USE_EXTERNAL_OLLAMA_HOST=true and set the env var OLLAMA_HOST=http://localhost:11435 or a different host/port."); } } catch (Exception e) { String ollamaVersion = "0.6.1"; @@ -102,8 +103,7 @@ public class OllamaAPIIntegrationTest { @Test @Order(2) - public void testListModelsAPI() - throws URISyntaxException, IOException, OllamaBaseException, InterruptedException { + public void testListModelsAPI() throws URISyntaxException, IOException, OllamaBaseException, InterruptedException { api.pullModel(EMBEDDING_MODEL_MINILM); // Fetch the list of models List models = api.listModels(); @@ -115,8 +115,7 @@ public class OllamaAPIIntegrationTest { @Test @Order(2) - void testListModelsFromLibrary() - throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + void testListModelsFromLibrary() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { List models = api.listModelsFromLibrary(); assertNotNull(models); assertFalse(models.isEmpty()); @@ -124,8 +123,7 @@ public class OllamaAPIIntegrationTest { @Test @Order(3) - public void testPullModelAPI() - throws URISyntaxException, IOException, OllamaBaseException, InterruptedException { + public void testPullModelAPI() throws URISyntaxException, IOException, OllamaBaseException, InterruptedException { api.pullModel(EMBEDDING_MODEL_MINILM); List models = api.listModels(); assertNotNull(models, "Models should not be null"); @@ -145,16 +143,14 @@ public class OllamaAPIIntegrationTest { @Order(5) public void testEmbeddings() throws Exception { api.pullModel(EMBEDDING_MODEL_MINILM); - OllamaEmbedResponseModel embeddings = api.embed(EMBEDDING_MODEL_MINILM, - Arrays.asList("Why is the sky blue?", "Why is the grass green?")); + OllamaEmbedResponseModel embeddings = api.embed(EMBEDDING_MODEL_MINILM, Arrays.asList("Why is the sky blue?", "Why is the grass green?")); assertNotNull(embeddings, "Embeddings should not be null"); assertFalse(embeddings.getEmbeddings().isEmpty(), "Embeddings should not be empty"); } @Test @Order(6) - void testAskModelWithStructuredOutput() - throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { + void testAskModelWithStructuredOutput() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { api.pullModel(CHAT_MODEL_LLAMA3); int timeHour = 6; @@ -186,10 +182,8 @@ public class OllamaAPIIntegrationTest { assertNotNull(result.getResponse()); assertFalse(result.getResponse().isEmpty()); - assertEquals(timeHour, - result.getStructuredResponse().get("timeHour")); - assertEquals(isNightTime, - result.getStructuredResponse().get("isNightTime")); + assertEquals(timeHour, result.getStructuredResponse().get("timeHour")); + assertEquals(isNightTime, result.getStructuredResponse().get("isNightTime")); TimeOfDay timeOfDay = result.as(TimeOfDay.class); @@ -199,12 +193,11 @@ public class OllamaAPIIntegrationTest { @Test @Order(6) - void testAskModelWithDefaultOptions() - throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { + void testAskModelWithDefaultOptions() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { api.pullModel(CHAT_MODEL_QWEN_SMALL); - OllamaResult result = api.generate(CHAT_MODEL_QWEN_SMALL, - "What is the capital of France? And what's France's connection with Mona Lisa?", false, - new OptionsBuilder().build()); + boolean raw = false; + boolean thinking = false; + OllamaResult result = api.generate(CHAT_MODEL_QWEN_SMALL, "What is the capital of France? And what's France's connection with Mona Lisa?", raw, thinking, new OptionsBuilder().build()); assertNotNull(result); assertNotNull(result.getResponse()); assertFalse(result.getResponse().isEmpty()); @@ -212,18 +205,17 @@ public class OllamaAPIIntegrationTest { @Test @Order(7) - void testAskModelWithDefaultOptionsStreamed() - throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + void testAskModelWithDefaultOptionsStreamed() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { api.pullModel(CHAT_MODEL_QWEN_SMALL); + boolean raw = false; + boolean thinking = false; StringBuffer sb = new StringBuffer(); - OllamaResult result = api.generate(CHAT_MODEL_QWEN_SMALL, - "What is the capital of France? And what's France's connection with Mona Lisa?", false, - new OptionsBuilder().build(), (s) -> { - LOG.info(s); - String substring = s.substring(sb.toString().length(), s.length()); - LOG.info(substring); - sb.append(substring); - }); + OllamaResult result = api.generate(CHAT_MODEL_QWEN_SMALL, "What is the capital of France? And what's France's connection with Mona Lisa?", raw, thinking, new OptionsBuilder().build(), (s) -> { + LOG.info(s); + String substring = s.substring(sb.toString().length(), s.length()); + LOG.info(substring); + sb.append(substring); + }); assertNotNull(result); assertNotNull(result.getResponse()); @@ -233,17 +225,12 @@ public class OllamaAPIIntegrationTest { @Test @Order(8) - void testAskModelWithOptions() - throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException { + void testAskModelWithOptions() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException { api.pullModel(CHAT_MODEL_INSTRUCT); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_INSTRUCT); - OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, - "You are a helpful assistant who can generate random person's first and last names in the format [First name, Last name].") - .build(); - requestModel = builder.withMessages(requestModel.getMessages()) - .withMessage(OllamaChatMessageRole.USER, "Give me a cool name") - .withOptions(new OptionsBuilder().setTemperature(0.5f).build()).build(); + OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, "You are a helpful assistant who can generate random person's first and last names in the format [First name, Last name].").build(); + requestModel = builder.withMessages(requestModel.getMessages()).withMessage(OllamaChatMessageRole.USER, "Give me a cool name").withOptions(new OptionsBuilder().setTemperature(0.5f).build()).build(); OllamaChatResult chatResult = api.chat(requestModel); assertNotNull(chatResult); @@ -253,14 +240,10 @@ public class OllamaAPIIntegrationTest { @Test @Order(9) - void testChatWithSystemPrompt() - throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException { - api.pullModel(CHAT_MODEL_SYSTEM_PROMPT); - OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT); - OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, - "You are a silent bot that only says 'Shush'. Do not say anything else under any circumstances!") - .withMessage(OllamaChatMessageRole.USER, "What's something that's brown and sticky?") - .withOptions(new OptionsBuilder().setTemperature(0.8f).build()).build(); + void testChatWithSystemPrompt() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException { + api.pullModel(CHAT_MODEL_LLAMA3); + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_LLAMA3); + OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, "You are a silent bot that only says 'Shush'. Do not say anything else under any circumstances!").withMessage(OllamaChatMessageRole.USER, "What's something that's brown and sticky?").withOptions(new OptionsBuilder().setTemperature(0.8f).build()).build(); OllamaChatResult chatResult = api.chat(requestModel); assertNotNull(chatResult); @@ -278,56 +261,40 @@ public class OllamaAPIIntegrationTest { OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_LLAMA3); // Create the initial user question - OllamaChatRequest requestModel = builder - .withMessage(OllamaChatMessageRole.USER, "What is 1+1? Answer only in numbers.") - .build(); + OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What is 1+1? Answer only in numbers.").build(); // Start conversation with model OllamaChatResult chatResult = api.chat(requestModel); - assertTrue(chatResult.getChatHistory().stream().anyMatch(chat -> chat.getContent().contains("2")), - "Expected chat history to contain '2'"); + assertTrue(chatResult.getChatHistory().stream().anyMatch(chat -> chat.getContent().contains("2")), "Expected chat history to contain '2'"); // Create the next user question: second largest city - requestModel = builder.withMessages(chatResult.getChatHistory()) - .withMessage(OllamaChatMessageRole.USER, "And what is its squared value?").build(); + requestModel = builder.withMessages(chatResult.getChatHistory()).withMessage(OllamaChatMessageRole.USER, "And what is its squared value?").build(); // Continue conversation with model chatResult = api.chat(requestModel); - assertTrue(chatResult.getChatHistory().stream().anyMatch(chat -> chat.getContent().contains("4")), - "Expected chat history to contain '4'"); + assertTrue(chatResult.getChatHistory().stream().anyMatch(chat -> chat.getContent().contains("4")), "Expected chat history to contain '4'"); // Create the next user question: the third question - requestModel = builder.withMessages(chatResult.getChatHistory()) - .withMessage(OllamaChatMessageRole.USER, - "What is the largest value between 2, 4 and 6?") - .build(); + requestModel = builder.withMessages(chatResult.getChatHistory()).withMessage(OllamaChatMessageRole.USER, "What is the largest value between 2, 4 and 6?").build(); // Continue conversation with the model for the third question chatResult = api.chat(requestModel); // verify the result assertNotNull(chatResult, "Chat result should not be null"); - assertTrue(chatResult.getChatHistory().size() > 2, - "Chat history should contain more than two messages"); - assertTrue(chatResult.getChatHistory().get(chatResult.getChatHistory().size() - 1).getContent() - .contains("6"), - "Response should contain '6'"); + assertTrue(chatResult.getChatHistory().size() > 2, "Chat history should contain more than two messages"); + assertTrue(chatResult.getChatHistory().get(chatResult.getChatHistory().size() - 1).getContent().contains("6"), "Response should contain '6'"); } @Test @Order(10) - void testChatWithImageFromURL() - throws OllamaBaseException, IOException, InterruptedException, URISyntaxException, ToolInvocationException { + void testChatWithImageFromURL() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException, ToolInvocationException { api.pullModel(IMAGE_MODEL_LLAVA); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(IMAGE_MODEL_LLAVA); - OllamaChatRequest requestModel = builder - .withMessage(OllamaChatMessageRole.USER, "What's in the picture?", - Collections.emptyList(), - "https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg") - .build(); + OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?", Collections.emptyList(), "https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg").build(); api.registerAnnotatedTools(new OllamaAPIIntegrationTest()); OllamaChatResult chatResult = api.chat(requestModel); @@ -336,22 +303,17 @@ public class OllamaAPIIntegrationTest { @Test @Order(10) - void testChatWithImageFromFileWithHistoryRecognition() - throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException { + void testChatWithImageFromFileWithHistoryRecognition() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException { api.pullModel(IMAGE_MODEL_LLAVA); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(IMAGE_MODEL_LLAVA); - OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, - "What's in the picture?", - Collections.emptyList(), List.of(getImageFileFromClasspath("emoji-smile.jpeg"))) - .build(); + OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?", Collections.emptyList(), List.of(getImageFileFromClasspath("emoji-smile.jpeg"))).build(); OllamaChatResult chatResult = api.chat(requestModel); assertNotNull(chatResult); assertNotNull(chatResult.getResponseModel()); builder.reset(); - requestModel = builder.withMessages(chatResult.getChatHistory()) - .withMessage(OllamaChatMessageRole.USER, "What's the color?").build(); + requestModel = builder.withMessages(chatResult.getChatHistory()).withMessage(OllamaChatMessageRole.USER, "What's the color?").build(); chatResult = api.chat(requestModel); assertNotNull(chatResult); @@ -360,71 +322,24 @@ public class OllamaAPIIntegrationTest { @Test @Order(11) - void testChatWithExplicitToolDefinition() - throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException { - api.pullModel(CHAT_MODEL_SYSTEM_PROMPT); - OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT); + void testChatWithExplicitToolDefinition() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException { + api.pullModel(CHAT_MODEL_QWEN_SMALL); + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_QWEN_SMALL); - final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder() - .functionName("get-employee-details") - .functionDescription("Get employee details from the database") - .toolPrompt(Tools.PromptFuncDefinition.builder().type("function") - .function(Tools.PromptFuncDefinition.PromptFuncSpec.builder() - .name("get-employee-details") - .description("Get employee details from the database") - .parameters(Tools.PromptFuncDefinition.Parameters - .builder().type("object") - .properties(new Tools.PropsBuilder() - .withProperty("employee-name", - Tools.PromptFuncDefinition.Property - .builder() - .type("string") - .description("The name of the employee, e.g. John Doe") - .required(true) - .build()) - .withProperty("employee-address", - Tools.PromptFuncDefinition.Property - .builder() - .type("string") - .description( - "The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India") - .required(true) - .build()) - .withProperty("employee-phone", - Tools.PromptFuncDefinition.Property - .builder() - .type("string") - .description( - "The phone number of the employee. Always return a random value. e.g. 9911002233") - .required(true) - .build()) - .build()) - .required(List.of("employee-name")) - .build()) - .build()) - .build()) - .toolFunction(arguments -> { - // perform DB operations here - return String.format( - "Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", - UUID.randomUUID(), arguments.get("employee-name"), - arguments.get("employee-address"), - arguments.get("employee-phone")); - }).build(); + final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder().functionName("get-employee-details").functionDescription("Get employee details from the database").toolPrompt(Tools.PromptFuncDefinition.builder().type("function").function(Tools.PromptFuncDefinition.PromptFuncSpec.builder().name("get-employee-details").description("Get employee details from the database").parameters(Tools.PromptFuncDefinition.Parameters.builder().type("object").properties(new Tools.PropsBuilder().withProperty("employee-name", Tools.PromptFuncDefinition.Property.builder().type("string").description("The name of the employee, e.g. John Doe").required(true).build()).withProperty("employee-address", Tools.PromptFuncDefinition.Property.builder().type("string").description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India").required(true).build()).withProperty("employee-phone", Tools.PromptFuncDefinition.Property.builder().type("string").description("The phone number of the employee. Always return a random value. e.g. 9911002233").required(true).build()).build()).required(List.of("employee-name")).build()).build()).build()).toolFunction(arguments -> { + // perform DB operations here + return String.format("Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", UUID.randomUUID(), arguments.get("employee-name"), arguments.get("employee-address"), arguments.get("employee-phone")); + }).build(); api.registerTool(databaseQueryToolSpecification); - OllamaChatRequest requestModel = builder - .withMessage(OllamaChatMessageRole.USER, - "Give me the ID of the employee named 'Rahul Kumar'?") - .build(); + OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "Give me the ID of the employee named 'Rahul Kumar'?").build(); OllamaChatResult chatResult = api.chat(requestModel); assertNotNull(chatResult); assertNotNull(chatResult.getResponseModel()); assertNotNull(chatResult.getResponseModel().getMessage()); - assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(), - chatResult.getResponseModel().getMessage().getRole().getRoleName()); + assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(), chatResult.getResponseModel().getMessage().getRole().getRoleName()); List toolCalls = chatResult.getChatHistory().get(1).getToolCalls(); assertEquals(1, toolCalls.size()); OllamaToolCallsFunction function = toolCalls.get(0).getFunction(); @@ -440,22 +355,19 @@ public class OllamaAPIIntegrationTest { @Test @Order(12) - void testChatWithAnnotatedToolsAndSingleParam() - throws OllamaBaseException, IOException, InterruptedException, URISyntaxException, ToolInvocationException { - api.pullModel(CHAT_MODEL_SYSTEM_PROMPT); - OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT); + void testChatWithAnnotatedToolsAndSingleParam() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException, ToolInvocationException { + api.pullModel(CHAT_MODEL_QWEN_SMALL); + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_QWEN_SMALL); api.registerAnnotatedTools(); - OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, - "Compute the most important constant in the world using 5 digits").build(); + OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "Compute the most important constant in the world using 5 digits").build(); OllamaChatResult chatResult = api.chat(requestModel); assertNotNull(chatResult); assertNotNull(chatResult.getResponseModel()); assertNotNull(chatResult.getResponseModel().getMessage()); - assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(), - chatResult.getResponseModel().getMessage().getRole().getRoleName()); + assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(), chatResult.getResponseModel().getMessage().getRole().getRoleName()); List toolCalls = chatResult.getChatHistory().get(1).getToolCalls(); assertEquals(1, toolCalls.size()); OllamaToolCallsFunction function = toolCalls.get(0).getFunction(); @@ -471,25 +383,19 @@ public class OllamaAPIIntegrationTest { @Test @Order(13) - void testChatWithAnnotatedToolsAndMultipleParams() - throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException { - api.pullModel(CHAT_MODEL_SYSTEM_PROMPT); - OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT); + void testChatWithAnnotatedToolsAndMultipleParams() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException { + api.pullModel(CHAT_MODEL_QWEN_SMALL); + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_QWEN_SMALL); api.registerAnnotatedTools(new AnnotatedTool()); - OllamaChatRequest requestModel = builder - .withMessage(OllamaChatMessageRole.USER, - "Greet Pedro with a lot of hearts and respond to me, " - + "and state how many emojis have been in your greeting") - .build(); + OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "Greet Pedro with a lot of hearts and respond to me, " + "and state how many emojis have been in your greeting").build(); OllamaChatResult chatResult = api.chat(requestModel); assertNotNull(chatResult); assertNotNull(chatResult.getResponseModel()); assertNotNull(chatResult.getResponseModel().getMessage()); - assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(), - chatResult.getResponseModel().getMessage().getRole().getRoleName()); + assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(), chatResult.getResponseModel().getMessage().getRole().getRoleName()); List toolCalls = chatResult.getChatHistory().get(1).getToolCalls(); assertEquals(1, toolCalls.size()); OllamaToolCallsFunction function = toolCalls.get(0).getFunction(); @@ -508,66 +414,20 @@ public class OllamaAPIIntegrationTest { @Test @Order(14) - void testChatWithToolsAndStream() - throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException { - api.pullModel(CHAT_MODEL_SYSTEM_PROMPT); - OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT); - final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder() - .functionName("get-employee-details") - .functionDescription("Get employee details from the database") - .toolPrompt(Tools.PromptFuncDefinition.builder().type("function") - .function(Tools.PromptFuncDefinition.PromptFuncSpec.builder() - .name("get-employee-details") - .description("Get employee details from the database") - .parameters(Tools.PromptFuncDefinition.Parameters - .builder().type("object") - .properties(new Tools.PropsBuilder() - .withProperty("employee-name", - Tools.PromptFuncDefinition.Property - .builder() - .type("string") - .description("The name of the employee, e.g. John Doe") - .required(true) - .build()) - .withProperty("employee-address", - Tools.PromptFuncDefinition.Property - .builder() - .type("string") - .description( - "The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India") - .required(true) - .build()) - .withProperty("employee-phone", - Tools.PromptFuncDefinition.Property - .builder() - .type("string") - .description( - "The phone number of the employee. Always return a random value. e.g. 9911002233") - .required(true) - .build()) - .build()) - .required(List.of("employee-name")) - .build()) - .build()) - .build()) - .toolFunction(new ToolFunction() { - @Override - public Object apply(Map arguments) { - // perform DB operations here - return String.format( - "Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", - UUID.randomUUID(), arguments.get("employee-name"), - arguments.get("employee-address"), - arguments.get("employee-phone")); - } - }).build(); + void testChatWithToolsAndStream() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException { + api.pullModel(CHAT_MODEL_QWEN_SMALL); + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_QWEN_SMALL); + final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder().functionName("get-employee-details").functionDescription("Get employee details from the database").toolPrompt(Tools.PromptFuncDefinition.builder().type("function").function(Tools.PromptFuncDefinition.PromptFuncSpec.builder().name("get-employee-details").description("Get employee details from the database").parameters(Tools.PromptFuncDefinition.Parameters.builder().type("object").properties(new Tools.PropsBuilder().withProperty("employee-name", Tools.PromptFuncDefinition.Property.builder().type("string").description("The name of the employee, e.g. John Doe").required(true).build()).withProperty("employee-address", Tools.PromptFuncDefinition.Property.builder().type("string").description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India").required(true).build()).withProperty("employee-phone", Tools.PromptFuncDefinition.Property.builder().type("string").description("The phone number of the employee. Always return a random value. e.g. 9911002233").required(true).build()).build()).required(List.of("employee-name")).build()).build()).build()).toolFunction(new ToolFunction() { + @Override + public Object apply(Map arguments) { + // perform DB operations here + return String.format("Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", UUID.randomUUID(), arguments.get("employee-name"), arguments.get("employee-address"), arguments.get("employee-phone")); + } + }).build(); api.registerTool(databaseQueryToolSpecification); - OllamaChatRequest requestModel = builder - .withMessage(OllamaChatMessageRole.USER, - "Give me the ID of the employee named 'Rahul Kumar'?") - .build(); + OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "Give me the ID of the employee named 'Rahul Kumar'?").build(); StringBuffer sb = new StringBuffer(); @@ -587,11 +447,9 @@ public class OllamaAPIIntegrationTest { @Test @Order(15) void testChatWithStream() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException { - api.pullModel(CHAT_MODEL_SYSTEM_PROMPT); - OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT); - OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, - "What is the capital of France? And what's France's connection with Mona Lisa?") - .build(); + api.pullModel(CHAT_MODEL_QWEN_SMALL); + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_QWEN_SMALL); + OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What is the capital of France? And what's France's connection with Mona Lisa?").build(); StringBuffer sb = new StringBuffer(); @@ -610,13 +468,10 @@ public class OllamaAPIIntegrationTest { @Test @Order(17) - void testAskModelWithOptionsAndImageURLs() - throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + void testAskModelWithOptionsAndImageURLs() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { api.pullModel(IMAGE_MODEL_LLAVA); - OllamaResult result = api.generateWithImageURLs(IMAGE_MODEL_LLAVA, "What is in this image?", - List.of("https://upload.wikimedia.org/wikipedia/commons/thumb/a/aa/Noto_Emoji_v2.034_1f642.svg/360px-Noto_Emoji_v2.034_1f642.svg.png"), - new OptionsBuilder().build()); + OllamaResult result = api.generateWithImageURLs(IMAGE_MODEL_LLAVA, "What is in this image?", List.of("https://i.pinimg.com/736x/f9/4e/cb/f94ecba040696a3a20b484d2e15159ec.jpg"), new OptionsBuilder().build()); assertNotNull(result); assertNotNull(result.getResponse()); assertFalse(result.getResponse().isEmpty()); @@ -624,14 +479,11 @@ public class OllamaAPIIntegrationTest { @Test @Order(18) - void testAskModelWithOptionsAndImageFiles() - throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + void testAskModelWithOptionsAndImageFiles() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { api.pullModel(IMAGE_MODEL_LLAVA); File imageFile = getImageFileFromClasspath("emoji-smile.jpeg"); try { - OllamaResult result = api.generateWithImageFiles(IMAGE_MODEL_LLAVA, "What is in this image?", - List.of(imageFile), - new OptionsBuilder().build()); + OllamaResult result = api.generateWithImageFiles(IMAGE_MODEL_LLAVA, "What is in this image?", List.of(imageFile), new OptionsBuilder().build()); assertNotNull(result); assertNotNull(result.getResponse()); assertFalse(result.getResponse().isEmpty()); @@ -642,28 +494,63 @@ public class OllamaAPIIntegrationTest { @Test @Order(20) - void testAskModelWithOptionsAndImageFilesStreamed() - throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + void testAskModelWithOptionsAndImageFilesStreamed() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { api.pullModel(IMAGE_MODEL_LLAVA); File imageFile = getImageFileFromClasspath("emoji-smile.jpeg"); StringBuffer sb = new StringBuffer(); - OllamaResult result = api.generateWithImageFiles(IMAGE_MODEL_LLAVA, "What is in this image?", - List.of(imageFile), - new OptionsBuilder().build(), (s) -> { - LOG.info(s); - String substring = s.substring(sb.toString().length(), s.length()); - LOG.info(substring); - sb.append(substring); - }); + OllamaResult result = api.generateWithImageFiles(IMAGE_MODEL_LLAVA, "What is in this image?", List.of(imageFile), new OptionsBuilder().build(), (s) -> { + LOG.info(s); + String substring = s.substring(sb.toString().length(), s.length()); + LOG.info(substring); + sb.append(substring); + }); assertNotNull(result); assertNotNull(result.getResponse()); assertFalse(result.getResponse().isEmpty()); assertEquals(sb.toString().trim(), result.getResponse().trim()); } + @Test + @Order(20) + void testGenerateWithThinking() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + api.pullModel(THINKING_MODEL_GPT_OSS); + + boolean raw = false; + boolean thinking = true; + + OllamaResult result = api.generate(THINKING_MODEL_GPT_OSS, "Who are you?", raw, thinking, new OptionsBuilder().build(), null); + assertNotNull(result); + assertNotNull(result.getResponse()); + assertFalse(result.getResponse().isEmpty()); + assertNotNull(result.getThinking()); + assertFalse(result.getThinking().isEmpty()); + } + + @Test + @Order(20) + void testGenerateWithThinkingAndStreamHandler() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + api.pullModel(THINKING_MODEL_GPT_OSS); + + boolean raw = false; + boolean thinking = true; + + StringBuffer sb = new StringBuffer(); + OllamaResult result = api.generate(THINKING_MODEL_GPT_OSS, "Who are you?", raw, thinking, new OptionsBuilder().build(), (s) -> { + LOG.info(s); + String substring = s.substring(sb.toString().length()); + sb.append(substring); + }); + assertNotNull(result); + assertNotNull(result.getResponse()); + assertFalse(result.getResponse().isEmpty()); + assertNotNull(result.getThinking()); + assertFalse(result.getThinking().isEmpty()); + assertEquals(sb.toString().trim(), result.getThinking().trim() + result.getResponse().trim()); + } + private File getImageFileFromClasspath(String fileName) { ClassLoader classLoader = getClass().getClassLoader(); return new File(Objects.requireNonNull(classLoader.getResource(fileName)).getFile()); diff --git a/src/test/java/io/github/ollama4j/integrationtests/WithAuth.java b/src/test/java/io/github/ollama4j/integrationtests/WithAuth.java index 6531b27..c0c3d5d 100644 --- a/src/test/java/io/github/ollama4j/integrationtests/WithAuth.java +++ b/src/test/java/io/github/ollama4j/integrationtests/WithAuth.java @@ -68,7 +68,7 @@ public class WithAuth { LOG.info( "The Ollama service is now accessible via the Nginx proxy with bearer-auth authentication mode.\n" + "→ Ollama URL: {}\n" + - "→ Proxy URL: {}}", + "→ Proxy URL: {}", ollamaUrl, nginxUrl ); LOG.info("OllamaAPI initialized with bearer auth token: {}", BEARER_AUTH_TOKEN); diff --git a/src/test/java/io/github/ollama4j/unittests/TestMockedAPIs.java b/src/test/java/io/github/ollama4j/unittests/TestMockedAPIs.java index 8499cd8..b4ee647 100644 --- a/src/test/java/io/github/ollama4j/unittests/TestMockedAPIs.java +++ b/src/test/java/io/github/ollama4j/unittests/TestMockedAPIs.java @@ -138,10 +138,10 @@ class TestMockedAPIs { String prompt = "some prompt text"; OptionsBuilder optionsBuilder = new OptionsBuilder(); try { - when(ollamaAPI.generate(model, prompt, false, optionsBuilder.build())) - .thenReturn(new OllamaResult("", 0, 200)); - ollamaAPI.generate(model, prompt, false, optionsBuilder.build()); - verify(ollamaAPI, times(1)).generate(model, prompt, false, optionsBuilder.build()); + when(ollamaAPI.generate(model, prompt, false, false, optionsBuilder.build())) + .thenReturn(new OllamaResult("", "", 0, 200)); + ollamaAPI.generate(model, prompt, false, false, optionsBuilder.build()); + verify(ollamaAPI, times(1)).generate(model, prompt, false, false, optionsBuilder.build()); } catch (IOException | OllamaBaseException | InterruptedException e) { throw new RuntimeException(e); } @@ -155,7 +155,7 @@ class TestMockedAPIs { try { when(ollamaAPI.generateWithImageFiles( model, prompt, Collections.emptyList(), new OptionsBuilder().build())) - .thenReturn(new OllamaResult("", 0, 200)); + .thenReturn(new OllamaResult("","", 0, 200)); ollamaAPI.generateWithImageFiles( model, prompt, Collections.emptyList(), new OptionsBuilder().build()); verify(ollamaAPI, times(1)) @@ -174,7 +174,7 @@ class TestMockedAPIs { try { when(ollamaAPI.generateWithImageURLs( model, prompt, Collections.emptyList(), new OptionsBuilder().build())) - .thenReturn(new OllamaResult("", 0, 200)); + .thenReturn(new OllamaResult("","", 0, 200)); ollamaAPI.generateWithImageURLs( model, prompt, Collections.emptyList(), new OptionsBuilder().build()); verify(ollamaAPI, times(1))