diff --git a/src/main/java/io/github/ollama4j/OllamaAPI.java b/src/main/java/io/github/ollama4j/OllamaAPI.java
index 94c897c..34c7257 100644
--- a/src/main/java/io/github/ollama4j/OllamaAPI.java
+++ b/src/main/java/io/github/ollama4j/OllamaAPI.java
@@ -15,11 +15,12 @@ import io.github.ollama4j.exceptions.RoleNotFoundException;
 import io.github.ollama4j.exceptions.ToolInvocationException;
 import io.github.ollama4j.exceptions.ToolNotFoundException;
 import io.github.ollama4j.models.chat.*;
+import io.github.ollama4j.models.chat.OllamaChatTokenHandler;
 import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel;
 import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel;
 import io.github.ollama4j.models.generate.OllamaGenerateRequest;
-import io.github.ollama4j.models.generate.OllamaStreamHandler;
-import io.github.ollama4j.models.generate.OllamaTokenHandler;
+import io.github.ollama4j.models.generate.OllamaGenerateStreamObserver;
+import io.github.ollama4j.models.generate.OllamaGenerateTokenHandler;
 import io.github.ollama4j.models.ps.ModelsProcessResponse;
 import io.github.ollama4j.models.request.*;
 import io.github.ollama4j.models.response.*;
@@ -118,7 +119,7 @@ public class OllamaAPI {
         } else {
             this.host = host;
         }
-        LOG.info("Ollama API initialized with host: {}", this.host);
+        LOG.info("Ollama4j client initialized. Connected to Ollama server at: {}", this.host);
     }
 
     /**
@@ -470,16 +471,26 @@ public class OllamaAPI {
                         .POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8))
                         .build();
         HttpClient client = HttpClient.newHttpClient();
-        HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
+        HttpResponse<InputStream> response =
+                client.send(request, HttpResponse.BodyHandlers.ofInputStream());
         int statusCode = response.statusCode();
-        String responseString = response.body();
         if (statusCode != 200) {
-            throw new OllamaBaseException(statusCode + " - " + responseString);
+            String errorBody = new String(response.body().readAllBytes(), StandardCharsets.UTF_8);
+            throw new OllamaBaseException(statusCode + " - " + errorBody);
         }
-        if (responseString.contains("error")) {
-            throw new OllamaBaseException(responseString);
+        try (BufferedReader reader =
+                new BufferedReader(
+                        new InputStreamReader(response.body(), StandardCharsets.UTF_8))) {
+            String line;
+            while ((line = reader.readLine()) != null) {
+                ModelPullResponse res =
+                        Utils.getObjectMapper().readValue(line, ModelPullResponse.class);
+                LOG.debug(res.getStatus());
+                if (res.getError() != null) {
+                    throw new OllamaBaseException(res.getError());
+                }
+            }
         }
-        LOG.debug(responseString);
     }
 
     /**
@@ -559,98 +570,32 @@ public class OllamaAPI {
         }
     }
 
-    /**
-     * Generate response for a question to a model running on Ollama server. This is a sync/blocking
-     * call. This API does not support "thinking" models.
-     *
-     * @param model the ollama model to ask the question to
-     * @param prompt the prompt/question text
-     * @param raw if true no formatting will be applied to the prompt. You may choose to use the raw
-     *     parameter if you are specifying a full templated prompt in your request to the API
-     * @param options the Options object - More details on the options
-     * @param responseStreamHandler optional callback consumer that will be applied every time a
-     *     streamed response is received. If not set, the stream parameter of the request is set to
-     *     false.
-     * @return OllamaResult that includes response text and time taken for response
-     * @throws OllamaBaseException if the response indicates an error status
-     * @throws IOException if an I/O error occurs during the HTTP request
-     * @throws InterruptedException if the operation is interrupted
-     */
     public OllamaResult generate(
             String model,
             String prompt,
             boolean raw,
+            boolean think,
             Options options,
-            OllamaStreamHandler responseStreamHandler)
+            OllamaGenerateStreamObserver streamObserver)
             throws OllamaBaseException, IOException, InterruptedException {
+
+        // Create the OllamaGenerateRequest and configure common properties
         OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt);
         ollamaRequestModel.setRaw(raw);
-        ollamaRequestModel.setThink(false);
+        ollamaRequestModel.setThink(think);
         ollamaRequestModel.setOptions(options.getOptionsMap());
-        return generateSyncForOllamaRequestModel(ollamaRequestModel, null, responseStreamHandler);
-    }
 
-    /**
-     * Generate thinking and response tokens for a question to a thinking model running on Ollama
-     * server. This is a sync/blocking call.
-     *
-     * @param model the ollama model to ask the question to
-     * @param prompt the prompt/question text
-     * @param raw if true no formatting will be applied to the prompt. You may choose to use the raw
-     *     parameter if you are specifying a full templated prompt in your request to the API
-     * @param options the Options object - More details on the options
-     * @param responseStreamHandler optional callback consumer that will be applied every time a
-     *     streamed response is received. If not set, the stream parameter of the request is set to
-     *     false.
-     * @return OllamaResult that includes response text and time taken for response
-     * @throws OllamaBaseException if the response indicates an error status
-     * @throws IOException if an I/O error occurs during the HTTP request
-     * @throws InterruptedException if the operation is interrupted
-     */
-    public OllamaResult generate(
-            String model,
-            String prompt,
-            boolean raw,
-            Options options,
-            OllamaStreamHandler thinkingStreamHandler,
-            OllamaStreamHandler responseStreamHandler)
-            throws OllamaBaseException, IOException, InterruptedException {
-        OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt);
-        ollamaRequestModel.setRaw(raw);
-        ollamaRequestModel.setThink(true);
-        ollamaRequestModel.setOptions(options.getOptionsMap());
-        return generateSyncForOllamaRequestModel(
-                ollamaRequestModel, thinkingStreamHandler, responseStreamHandler);
-    }
-
-    /**
-     * Generates response using the specified AI model and prompt (in blocking mode).
-     *
-     * <p>Uses {@link #generate(String, String, boolean, Options, OllamaStreamHandler)}
-     *
-     * @param model The name or identifier of the AI model to use for generating the response.
-     * @param prompt The input text or prompt to provide to the AI model.
-     * @param raw In some cases, you may wish to bypass the templating system and provide a full
-     *     prompt. In this case, you can use the raw parameter to disable templating. Also note that
-     *     raw mode will not return a context.
-     * @param options Additional options or configurations to use when generating the response.
-     * @param think if true the model will "think" step-by-step before generating the final response
-     * @return {@link OllamaResult}
-     * @throws OllamaBaseException if the response indicates an error status
-     * @throws IOException if an I/O error occurs during the HTTP request
-     * @throws InterruptedException if the operation is interrupted
-     */
-    public OllamaResult generate(
-            String model, String prompt, boolean raw, boolean think, Options options)
-            throws OllamaBaseException, IOException, InterruptedException {
+        // Based on 'think' flag, choose the appropriate stream handler(s)
         if (think) {
-            return generate(model, prompt, raw, options, null, null);
+            // Call with thinking
+            return generateSyncForOllamaRequestModel(
+                    ollamaRequestModel,
+                    streamObserver.getThinkingStreamHandler(),
+                    streamObserver.getResponseStreamHandler());
         } else {
-            return generate(model, prompt, raw, options, null);
+            // Call without thinking
+            return generateSyncForOllamaRequestModel(
+                    ollamaRequestModel, null, streamObserver.getResponseStreamHandler());
         }
     }
 
@@ -668,7 +613,7 @@
      * @throws InterruptedException if the operation is interrupted.
      */
     @SuppressWarnings("LoggingSimilarMessage")
-    public OllamaResult generate(String model, String prompt, Map<String, Object> format)
+    public OllamaResult generateWithFormat(String model, String prompt, Map<String, Object> format)
             throws OllamaBaseException, IOException, InterruptedException {
         URI uri = URI.create(this.host + "/api/generate");
@@ -767,7 +712,7 @@
      * @throws ToolInvocationException if a tool call fails to execute
      */
     public OllamaToolsResult generateWithTools(
-            String model, String prompt, Options options, OllamaStreamHandler streamHandler)
+            String model, String prompt, Options options, OllamaGenerateTokenHandler streamHandler)
             throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
         boolean raw = true;
         OllamaToolsResult toolResult = new OllamaToolsResult();
@@ -782,7 +727,14 @@
             prompt = promptBuilder.build();
         }
 
-        OllamaResult result = generate(model, prompt, raw, options, streamHandler);
+        OllamaResult result =
+                generate(
+                        model,
+                        prompt,
+                        raw,
+                        false,
+                        options,
+                        new OllamaGenerateStreamObserver(null, streamHandler));
         toolResult.setModelResult(result);
 
         String toolsResponse = result.getResponse();
@@ -898,7 +850,7 @@
             List<Object> images,
             Options options,
             Map<String, Object> format,
-            OllamaStreamHandler streamHandler)
+            OllamaGenerateTokenHandler streamHandler)
             throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
         List<String> encodedImages = new ArrayList<>();
         for (Object image : images) {
@@ -947,7 +899,7 @@
      * @throws IOException if an I/O error occurs during the HTTP request
      * @throws InterruptedException if the operation is interrupted
      */
-    public OllamaChatResult chat(OllamaChatRequest request, OllamaTokenHandler tokenHandler)
+    public OllamaChatResult chat(OllamaChatRequest request, OllamaChatTokenHandler tokenHandler)
             throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
         OllamaChatEndpointCaller requestCaller =
                 new OllamaChatEndpointCaller(host, auth, requestTimeoutSeconds);
@@ -1233,8 +1185,8 @@
      */
     private OllamaResult generateSyncForOllamaRequestModel(
             OllamaGenerateRequest ollamaRequestModel,
-            OllamaStreamHandler thinkingStreamHandler,
-            OllamaStreamHandler responseStreamHandler)
+            OllamaGenerateTokenHandler thinkingStreamHandler,
+            OllamaGenerateTokenHandler responseStreamHandler)
             throws OllamaBaseException, IOException, InterruptedException {
         OllamaGenerateEndpointCaller requestCaller =
                 new OllamaGenerateEndpointCaller(host, auth, requestTimeoutSeconds);
diff --git a/src/main/java/io/github/ollama4j/impl/ConsoleOutputChatTokenHandler.java b/src/main/java/io/github/ollama4j/impl/ConsoleOutputChatTokenHandler.java
new file mode 100644
index 0000000..ea0f728
--- /dev/null
+++ b/src/main/java/io/github/ollama4j/impl/ConsoleOutputChatTokenHandler.java
@@ -0,0 +1,18 @@
+/*
+ * Ollama4j - Java library for interacting with Ollama server.
+ * Copyright (c) 2025 Amith Koujalgi and contributors.
+ *
+ * Licensed under the MIT License (the "License");
+ * you may not use this file except in compliance with the License.
+ *
+*/
+package io.github.ollama4j.impl;
+
+import io.github.ollama4j.models.chat.OllamaChatStreamObserver;
+
+public final class ConsoleOutputChatTokenHandler extends OllamaChatStreamObserver {
+    public ConsoleOutputChatTokenHandler() {
+        setThinkingStreamHandler(new ConsoleOutputGenerateTokenHandler());
+        setResponseStreamHandler(new ConsoleOutputGenerateTokenHandler());
+    }
+}
diff --git a/src/main/java/io/github/ollama4j/impl/ConsoleOutputStreamHandler.java b/src/main/java/io/github/ollama4j/impl/ConsoleOutputGenerateTokenHandler.java
similarity index 52%
rename from src/main/java/io/github/ollama4j/impl/ConsoleOutputStreamHandler.java
rename to src/main/java/io/github/ollama4j/impl/ConsoleOutputGenerateTokenHandler.java
index a5a9ef4..b303315 100644
--- a/src/main/java/io/github/ollama4j/impl/ConsoleOutputStreamHandler.java
+++ b/src/main/java/io/github/ollama4j/impl/ConsoleOutputGenerateTokenHandler.java
@@ -8,15 +8,11 @@
  */
 package io.github.ollama4j.impl;
 
-import io.github.ollama4j.models.generate.OllamaStreamHandler;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-public class ConsoleOutputStreamHandler implements OllamaStreamHandler {
-    private static final Logger LOG = LoggerFactory.getLogger(ConsoleOutputStreamHandler.class);
+import io.github.ollama4j.models.generate.OllamaGenerateTokenHandler;
 
+public class ConsoleOutputGenerateTokenHandler implements OllamaGenerateTokenHandler {
     @Override
     public void accept(String message) {
-        LOG.info(message);
+        System.out.print(message);
     }
 }
diff --git a/src/main/java/io/github/ollama4j/models/chat/OllamaChatMessage.java b/src/main/java/io/github/ollama4j/models/chat/OllamaChatMessage.java
index 2b18c73..f969599 100644
--- a/src/main/java/io/github/ollama4j/models/chat/OllamaChatMessage.java
+++ b/src/main/java/io/github/ollama4j/models/chat/OllamaChatMessage.java
@@ -21,7 +21,9 @@ import lombok.*;
 
 /**
  * Defines a single Message to be used inside a chat request against the ollama /api/chat endpoint.
  *
- * @see Generate chat completion
+ * @see Generate
+ *     chat completion
  */
 @Data
 @AllArgsConstructor
@@ -32,7 +34,9 @@ public class OllamaChatMessage {
 
     @NonNull private OllamaChatMessageRole role;
 
-    @NonNull private String content;
+    @JsonProperty("content")
+    @NonNull
+    private String response;
 
     private String thinking;
 
diff --git a/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequestBuilder.java b/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequestBuilder.java
index 1130da4..88b470a 100644
--- a/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequestBuilder.java
+++ b/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequestBuilder.java
@@ -114,11 +114,18 @@ public class OllamaChatRequestBuilder {
                                         imageURLConnectTimeoutSeconds,
                                         imageURLReadTimeoutSeconds));
                     } catch (InterruptedException e) {
-                        LOG.error("Failed to load image from URL: {}. Cause: {}", imageUrl, e);
-                        throw e;
+                        LOG.error("Failed to load image from URL: '{}'. Cause: {}", imageUrl, e);
+                        Thread.currentThread().interrupt();
+                        throw new InterruptedException(
+                                "Interrupted while loading image from URL: " + imageUrl);
                     } catch (IOException e) {
-                        LOG.warn("Failed to load image from URL: {}. Cause: {}", imageUrl, e);
-                        throw e;
+                        LOG.error(
+                                "IOException occurred while loading image from URL '{}'. Cause: {}",
+                                imageUrl,
+                                e.getMessage(),
+                                e);
+                        throw new IOException(
+                                "IOException while loading image from URL: " + imageUrl, e);
                     }
                 }
             }
diff --git a/src/main/java/io/github/ollama4j/models/chat/OllamaChatResponseModel.java b/src/main/java/io/github/ollama4j/models/chat/OllamaChatResponseModel.java
index 1705604..5c05a94 100644
--- a/src/main/java/io/github/ollama4j/models/chat/OllamaChatResponseModel.java
+++ b/src/main/java/io/github/ollama4j/models/chat/OllamaChatResponseModel.java
@@ -8,18 +8,18 @@
  */
 package io.github.ollama4j.models.chat;
 
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
 import com.fasterxml.jackson.annotation.JsonProperty;
 import java.util.List;
 import lombok.Data;
 
 @Data
+@JsonIgnoreProperties(ignoreUnknown = true)
 public class OllamaChatResponseModel {
     private String model;
     private @JsonProperty("created_at") String createdAt;
     private @JsonProperty("done_reason") String doneReason;
-    private OllamaChatMessage message;
     private boolean done;
-    private String error;
     private List<Integer> context;
     private @JsonProperty("total_duration") Long totalDuration;
     private @JsonProperty("load_duration") Long loadDuration;
@@ -27,4 +27,6 @@ public class OllamaChatResponseModel {
     private @JsonProperty("eval_duration") Long evalDuration;
     private @JsonProperty("prompt_eval_count") Integer promptEvalCount;
     private @JsonProperty("eval_count") Integer evalCount;
+    private String error;
+    private OllamaChatMessage message;
 }
diff --git a/src/main/java/io/github/ollama4j/models/chat/OllamaChatResult.java b/src/main/java/io/github/ollama4j/models/chat/OllamaChatResult.java
index 1495eef..e77f4fe 100644
--- a/src/main/java/io/github/ollama4j/models/chat/OllamaChatResult.java
+++ b/src/main/java/io/github/ollama4j/models/chat/OllamaChatResult.java
@@ -47,7 +47,7 @@ public class OllamaChatResult {
 
     @Deprecated
     public String getResponse() {
-        return responseModel != null ? responseModel.getMessage().getContent() : "";
+        return responseModel != null ? responseModel.getMessage().getResponse() : "";
     }
 
     @Deprecated
diff --git a/src/main/java/io/github/ollama4j/models/chat/OllamaChatStreamObserver.java b/src/main/java/io/github/ollama4j/models/chat/OllamaChatStreamObserver.java
index 2c38d61..b2bf91b 100644
--- a/src/main/java/io/github/ollama4j/models/chat/OllamaChatStreamObserver.java
+++ b/src/main/java/io/github/ollama4j/models/chat/OllamaChatStreamObserver.java
@@ -8,16 +8,17 @@
  */
 package io.github.ollama4j.models.chat;
 
-import io.github.ollama4j.models.generate.OllamaStreamHandler;
-import io.github.ollama4j.models.generate.OllamaTokenHandler;
-import lombok.RequiredArgsConstructor;
+import io.github.ollama4j.models.generate.OllamaGenerateTokenHandler;
+import lombok.AllArgsConstructor;
+import lombok.NoArgsConstructor;
+import lombok.Setter;
 
-@RequiredArgsConstructor
-public class OllamaChatStreamObserver implements OllamaTokenHandler {
-    private final OllamaStreamHandler thinkingStreamHandler;
-    private final OllamaStreamHandler responseStreamHandler;
-
-    private String message = "";
+@Setter
+@NoArgsConstructor
+@AllArgsConstructor
+public class OllamaChatStreamObserver implements OllamaChatTokenHandler {
+    private OllamaGenerateTokenHandler thinkingStreamHandler;
+    private OllamaGenerateTokenHandler responseStreamHandler;
 
     @Override
     public void accept(OllamaChatResponseModel token) {
@@ -26,33 +27,19 @@ public class OllamaChatStreamObserver implements OllamaTokenHandler {
         }
 
         String thinking = token.getMessage().getThinking();
-        String content = token.getMessage().getContent();
+        String response = token.getMessage().getResponse();
 
         boolean hasThinking = thinking != null && !thinking.isEmpty();
-        boolean hasContent = !content.isEmpty();
-
-        // if (hasThinking && !hasContent) {
-        //// message += thinking;
-        // message = thinking;
-        // } else {
-        //// message += content;
-        // message = content;
-        // }
-        //
-        // responseStreamHandler.accept(message);
-
-        if (!hasContent && hasThinking && thinkingStreamHandler != null) {
-            // message = message + thinking;
+        boolean hasResponse = response != null && !response.isEmpty();
+
+        if (!hasResponse && hasThinking && thinkingStreamHandler != null) {
             // use only new tokens received, instead of appending the tokens to the previous
             // ones and sending the full string again
             thinkingStreamHandler.accept(thinking);
-        } else if (hasContent && responseStreamHandler != null) {
-            // message = message + response;
-
+        } else if (hasResponse && responseStreamHandler != null) {
             // use only new tokens received, instead of appending the tokens to the previous
             // ones and sending the full string again
-            responseStreamHandler.accept(content);
+            responseStreamHandler.accept(response);
         }
     }
 }
diff --git a/src/main/java/io/github/ollama4j/models/generate/OllamaTokenHandler.java b/src/main/java/io/github/ollama4j/models/chat/OllamaChatTokenHandler.java
similarity index 60%
rename from src/main/java/io/github/ollama4j/models/generate/OllamaTokenHandler.java
rename to src/main/java/io/github/ollama4j/models/chat/OllamaChatTokenHandler.java
index 78b325b..fba39df 100644
--- a/src/main/java/io/github/ollama4j/models/generate/OllamaTokenHandler.java
+++ b/src/main/java/io/github/ollama4j/models/chat/OllamaChatTokenHandler.java
@@ -6,9 +6,8 @@
  * you may not use this file except in compliance with the License.
  *
  */
-package io.github.ollama4j.models.generate;
+package io.github.ollama4j.models.chat;
 
-import io.github.ollama4j.models.chat.OllamaChatResponseModel;
 import java.util.function.Consumer;
 
-public interface OllamaTokenHandler extends Consumer<OllamaChatResponseModel> {}
+public interface OllamaChatTokenHandler extends Consumer<OllamaChatResponseModel> {}
diff --git a/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateResponseModel.java b/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateResponseModel.java
index 091738d..bf33133 100644
--- a/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateResponseModel.java
+++ b/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateResponseModel.java
@@ -18,15 +18,15 @@ import lombok.Data;
 public class OllamaGenerateResponseModel {
     private String model;
     private @JsonProperty("created_at") String createdAt;
-    private String response;
-    private String thinking;
-    private boolean done;
     private @JsonProperty("done_reason") String doneReason;
+    private boolean done;
     private List<Integer> context;
     private @JsonProperty("total_duration") Long totalDuration;
     private @JsonProperty("load_duration") Long loadDuration;
-    private @JsonProperty("prompt_eval_count") Integer promptEvalCount;
     private @JsonProperty("prompt_eval_duration") Long promptEvalDuration;
-    private @JsonProperty("eval_count") Integer evalCount;
     private @JsonProperty("eval_duration") Long evalDuration;
+    private @JsonProperty("prompt_eval_count") Integer promptEvalCount;
+    private @JsonProperty("eval_count") Integer evalCount;
+    private String response;
+    private String thinking;
 }
diff --git a/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateStreamObserver.java b/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateStreamObserver.java
index 8a0164a..441da71 100644
--- a/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateStreamObserver.java
+++ b/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateStreamObserver.java
@@ -10,18 +10,18 @@ package io.github.ollama4j.models.generate;
 
 import java.util.ArrayList;
 import java.util.List;
+import lombok.Getter;
 
+@Getter
 public class OllamaGenerateStreamObserver {
-
-    private final OllamaStreamHandler thinkingStreamHandler;
-    private final OllamaStreamHandler responseStreamHandler;
+    private final OllamaGenerateTokenHandler thinkingStreamHandler;
+    private final OllamaGenerateTokenHandler responseStreamHandler;
 
     private final List<OllamaGenerateResponseModel> responseParts = new ArrayList<>();
 
-    private String message = "";
-
     public OllamaGenerateStreamObserver(
-            OllamaStreamHandler thinkingStreamHandler, OllamaStreamHandler responseStreamHandler) {
+            OllamaGenerateTokenHandler thinkingStreamHandler,
+            OllamaGenerateTokenHandler responseStreamHandler) {
         this.responseStreamHandler = responseStreamHandler;
         this.thinkingStreamHandler = thinkingStreamHandler;
     }
@@ -39,14 +39,10 @@ public class OllamaGenerateStreamObserver {
         boolean hasThinking = thinking != null && !thinking.isEmpty();
 
         if (!hasResponse && hasThinking && thinkingStreamHandler != null) {
-            // message = message + thinking;
-
             // use only new tokens received, instead of appending the tokens to the previous
             // ones and sending the full string again
             thinkingStreamHandler.accept(thinking);
         } else if (hasResponse && responseStreamHandler != null) {
-            // message = message + response;
-
             // use only new tokens received, instead of appending the tokens to the previous
             // ones and sending the full string again
             responseStreamHandler.accept(response);
diff --git a/src/main/java/io/github/ollama4j/models/generate/OllamaStreamHandler.java b/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateTokenHandler.java
similarity index 83%
rename from src/main/java/io/github/ollama4j/models/generate/OllamaStreamHandler.java
rename to src/main/java/io/github/ollama4j/models/generate/OllamaGenerateTokenHandler.java
index 810985b..d8d9d01 100644
--- a/src/main/java/io/github/ollama4j/models/generate/OllamaStreamHandler.java
+++ b/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateTokenHandler.java
@@ -10,6 +10,6 @@ package io.github.ollama4j.models.generate;
 
 import java.util.function.Consumer;
 
-public interface OllamaStreamHandler extends Consumer<String> {
+public interface OllamaGenerateTokenHandler extends Consumer<String> {
     void accept(String message);
 }
diff --git a/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java b/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java
index c48d21e..a5fdfb0 100644
--- a/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java
+++ b/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java
@@ -12,7 +12,7 @@ import com.fasterxml.jackson.core.JsonProcessingException;
 import com.fasterxml.jackson.core.type.TypeReference;
 import io.github.ollama4j.exceptions.OllamaBaseException;
 import io.github.ollama4j.models.chat.*;
-import io.github.ollama4j.models.generate.OllamaTokenHandler;
+import io.github.ollama4j.models.chat.OllamaChatTokenHandler;
 import io.github.ollama4j.models.response.OllamaErrorResponse;
 import io.github.ollama4j.utils.Utils;
 import java.io.BufferedReader;
@@ -36,7 +36,7 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
 
     private static final Logger LOG = LoggerFactory.getLogger(OllamaChatEndpointCaller.class);
 
-    private OllamaTokenHandler tokenHandler;
+    private OllamaChatTokenHandler tokenHandler;
 
     public OllamaChatEndpointCaller(String host, Auth auth, long requestTimeoutSeconds) {
         super(host, auth, requestTimeoutSeconds);
@@ -73,7 +73,7 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
             if (message.getThinking() != null) {
                 thinkingBuffer.append(message.getThinking());
             } else {
-                responseBuffer.append(message.getContent());
+                responseBuffer.append(message.getResponse());
             }
             if (tokenHandler != null) {
                 tokenHandler.accept(ollamaResponseModel);
@@ -86,7 +86,7 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
         }
     }
 
-    public OllamaChatResult call(OllamaChatRequest body, OllamaTokenHandler tokenHandler)
+    public OllamaChatResult call(OllamaChatRequest body, OllamaChatTokenHandler tokenHandler)
             throws OllamaBaseException, IOException, InterruptedException {
         this.tokenHandler = tokenHandler;
         return callSync(body);
@@ -127,7 +127,7 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
                     wantedToolsForStream = ollamaChatResponseModel.getMessage().getToolCalls();
                 }
                 if (finished && body.stream) {
-                    ollamaChatResponseModel.getMessage().setContent(responseBuffer.toString());
+                    ollamaChatResponseModel.getMessage().setResponse(responseBuffer.toString());
                     ollamaChatResponseModel.getMessage().setThinking(thinkingBuffer.toString());
                     break;
                 }
diff --git a/src/main/java/io/github/ollama4j/models/request/OllamaGenerateEndpointCaller.java b/src/main/java/io/github/ollama4j/models/request/OllamaGenerateEndpointCaller.java
index 3100f38..9c3387a 100644
--- a/src/main/java/io/github/ollama4j/models/request/OllamaGenerateEndpointCaller.java
+++ b/src/main/java/io/github/ollama4j/models/request/OllamaGenerateEndpointCaller.java
@@ -12,7 +12,7 @@ import com.fasterxml.jackson.core.JsonProcessingException;
 import io.github.ollama4j.exceptions.OllamaBaseException;
 import io.github.ollama4j.models.generate.OllamaGenerateResponseModel;
 import io.github.ollama4j.models.generate.OllamaGenerateStreamObserver;
-import io.github.ollama4j.models.generate.OllamaStreamHandler;
+import io.github.ollama4j.models.generate.OllamaGenerateTokenHandler;
 import io.github.ollama4j.models.response.OllamaErrorResponse;
 import io.github.ollama4j.models.response.OllamaResult;
 import io.github.ollama4j.utils.OllamaRequestBody;
@@ -69,8 +69,8 @@
     public OllamaResult call(
             OllamaRequestBody body,
-            OllamaStreamHandler thinkingStreamHandler,
-            OllamaStreamHandler responseStreamHandler)
+            OllamaGenerateTokenHandler thinkingStreamHandler,
+            OllamaGenerateTokenHandler responseStreamHandler)
             throws OllamaBaseException, IOException, InterruptedException {
         responseStreamObserver =
                 new OllamaGenerateStreamObserver(thinkingStreamHandler, responseStreamHandler);
diff --git a/src/test/java/io/github/ollama4j/integrationtests/OllamaAPIIntegrationTest.java b/src/test/java/io/github/ollama4j/integrationtests/OllamaAPIIntegrationTest.java
index 2c0dd14..1ed8797 100644
--- a/src/test/java/io/github/ollama4j/integrationtests/OllamaAPIIntegrationTest.java
+++ b/src/test/java/io/github/ollama4j/integrationtests/OllamaAPIIntegrationTest.java
@@ -13,9 +13,12 @@ import static org.junit.jupiter.api.Assertions.*;
 import io.github.ollama4j.OllamaAPI;
 import io.github.ollama4j.exceptions.OllamaBaseException;
 import io.github.ollama4j.exceptions.ToolInvocationException;
+import io.github.ollama4j.impl.ConsoleOutputChatTokenHandler;
+import io.github.ollama4j.impl.ConsoleOutputGenerateTokenHandler;
 import io.github.ollama4j.models.chat.*;
 import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel;
 import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel;
+import io.github.ollama4j.models.generate.OllamaGenerateStreamObserver;
 import io.github.ollama4j.models.response.Model;
 import io.github.ollama4j.models.response.ModelDetail;
 import io.github.ollama4j.models.response.OllamaResult;
@@ -56,10 +59,41 @@ class OllamaAPIIntegrationTest {
 
     @BeforeAll
     static void setUp() {
+        int requestTimeoutSeconds = 60;
+        int numberOfRetriesForModelPull = 5;
+
         try {
-            boolean useExternalOllamaHost =
-                    Boolean.parseBoolean(System.getenv("USE_EXTERNAL_OLLAMA_HOST"));
-            String ollamaHost = System.getenv("OLLAMA_HOST");
+            // Try to get from env vars first
+            String useExternalOllamaHostEnv = System.getenv("USE_EXTERNAL_OLLAMA_HOST");
+            String ollamaHostEnv = System.getenv("OLLAMA_HOST");
+
+            boolean useExternalOllamaHost;
+            String ollamaHost;
+
+            if (useExternalOllamaHostEnv == null && ollamaHostEnv == null) {
+                // Fallback to test-config.properties from classpath
+                Properties props = new Properties();
+                try {
+                    props.load(
+                            OllamaAPIIntegrationTest.class
+                                    .getClassLoader()
+                                    .getResourceAsStream("test-config.properties"));
+                } catch (Exception e) {
+                    throw new RuntimeException(
+                            "Could not load test-config.properties from classpath", e);
+                }
+                useExternalOllamaHost =
+                        Boolean.parseBoolean(
+                                props.getProperty("USE_EXTERNAL_OLLAMA_HOST", "false"));
+                ollamaHost = props.getProperty("OLLAMA_HOST");
+                requestTimeoutSeconds =
+                        Integer.parseInt(props.getProperty("REQUEST_TIMEOUT_SECONDS"));
+                numberOfRetriesForModelPull =
+                        Integer.parseInt(props.getProperty("NUMBER_RETRIES_FOR_MODEL_PULL"));
+            } else {
+                useExternalOllamaHost = Boolean.parseBoolean(useExternalOllamaHostEnv);
+                ollamaHost = ollamaHostEnv;
+            }
 
             if (useExternalOllamaHost) {
                 LOG.info("Using external Ollama host...");
@@ -90,8 +124,8 @@ class OllamaAPIIntegrationTest {
                             + ":"
                             + ollama.getMappedPort(internalPort));
         }
-        api.setRequestTimeoutSeconds(120);
-        api.setNumberOfRetriesForModelPull(5);
+        api.setRequestTimeoutSeconds(requestTimeoutSeconds);
+        api.setNumberOfRetriesForModelPull(numberOfRetriesForModelPull);
     }
 
     @Test
@@ -187,7 +221,7 @@ class OllamaAPIIntegrationTest {
                 });
         format.put("required", List.of("isNoon"));
 
-        OllamaResult result = api.generate(TOOLS_MODEL, prompt, format);
+        OllamaResult result = api.generateWithFormat(TOOLS_MODEL, prompt, format);
 
         assertNotNull(result);
         assertNotNull(result.getResponse());
@@ -210,7 +244,8 @@
                                 + " Lisa?",
                         raw,
                         thinking,
-                        new OptionsBuilder().build());
+                        new OptionsBuilder().build(),
+                        new OllamaGenerateStreamObserver(null, null));
         assertNotNull(result);
         assertNotNull(result.getResponse());
         assertFalse(result.getResponse().isEmpty());
@@ -228,8 +263,10 @@
                         "What is the capital of France? And what's France's connection with Mona"
                                 + " Lisa?",
                         raw,
+                        false,
                         new OptionsBuilder().build(),
-                        LOG::info);
+                        new OllamaGenerateStreamObserver(
+                                null, new ConsoleOutputGenerateTokenHandler()));
 
         assertNotNull(result);
         assertNotNull(result.getResponse());
@@ -263,7 +300,7 @@
 
         assertNotNull(chatResult);
         assertNotNull(chatResult.getResponseModel());
-        assertFalse(chatResult.getResponseModel().getMessage().getContent().isEmpty());
+        assertFalse(chatResult.getResponseModel().getMessage().getResponse().isEmpty());
     }
 
     @Test
@@ -296,9 +333,13 @@
         assertNotNull(chatResult);
         assertNotNull(chatResult.getResponseModel());
         assertNotNull(chatResult.getResponseModel().getMessage());
-        assertFalse(chatResult.getResponseModel().getMessage().getContent().isBlank());
+        assertFalse(chatResult.getResponseModel().getMessage().getResponse().isBlank());
         assertTrue(
-                chatResult.getResponseModel().getMessage().getContent().contains(expectedResponse));
+                chatResult
+                        .getResponseModel()
+                        .getMessage()
+                        .getResponse()
+                        .contains(expectedResponse));
         assertEquals(3, chatResult.getChatHistory().size());
     }
 
@@ -515,16 +556,7 @@
                         .withOptions(new OptionsBuilder().setTemperature(0.9f).build())
                         .build();
 
-        OllamaChatResult chatResult =
-                api.chat(
-                        requestModel,
-                        new OllamaChatStreamObserver(
-                                s -> {
-                                    LOG.info(s.toUpperCase());
-                                },
-                                s -> {
-                                    LOG.info(s.toLowerCase());
-                                }));
+        OllamaChatResult chatResult = api.chat(requestModel, new ConsoleOutputChatTokenHandler());
 
         assertNotNull(chatResult, "chatResult should not be null");
         assertNotNull(chatResult.getResponseModel(), "Response model should not be null");
@@ -670,20 +702,11 @@
                         .build();
         requestModel.setThink(false);
 
-        OllamaChatResult chatResult =
-                api.chat(
-                        requestModel,
-                        new OllamaChatStreamObserver(
-                                s -> {
-                                    LOG.info(s.toUpperCase());
-                                },
-                                s -> {
-                                    LOG.info(s.toLowerCase());
-                                }));
+        OllamaChatResult chatResult = api.chat(requestModel, new ConsoleOutputChatTokenHandler());
         assertNotNull(chatResult);
         assertNotNull(chatResult.getResponseModel());
         assertNotNull(chatResult.getResponseModel().getMessage());
-        assertNotNull(chatResult.getResponseModel().getMessage().getContent());
+        assertNotNull(chatResult.getResponseModel().getMessage().getResponse());
     }
 
     @Test
@@ -706,21 +729,12 @@
                         .withKeepAlive("0m")
                         .build();
 
-        OllamaChatResult chatResult =
-                api.chat(
-                        requestModel,
-                        new OllamaChatStreamObserver(
-                                s -> {
-                                    LOG.info(s.toUpperCase());
-                                },
-                                s -> {
-                                    LOG.info(s.toLowerCase());
-                                }));
+        OllamaChatResult chatResult = api.chat(requestModel, new ConsoleOutputChatTokenHandler());
 
         assertNotNull(chatResult);
         assertNotNull(chatResult.getResponseModel());
         assertNotNull(chatResult.getResponseModel().getMessage());
-        assertNotNull(chatResult.getResponseModel().getMessage().getContent());
+        assertNotNull(chatResult.getResponseModel().getMessage().getResponse());
     }
 
     @Test
@@ -859,7 +873,8 @@
                         "Who are you?",
                         raw,
                         think,
-                        new OptionsBuilder().build());
+                        new OptionsBuilder().build(),
+                        new OllamaGenerateStreamObserver(null, null));
         assertNotNull(result);
         assertNotNull(result.getResponse());
         assertNotNull(result.getThinking());
@@ -876,13 +891,15 @@
                         THINKING_TOOL_MODEL,
                         "Who are you?",
                         raw,
+                        true,
                         new OptionsBuilder().build(),
-                        thinkingToken -> {
-                            LOG.info(thinkingToken.toUpperCase());
-                        },
-                        resToken -> {
-                            LOG.info(resToken.toLowerCase());
-                        });
+                        new OllamaGenerateStreamObserver(
+                                thinkingToken -> {
+                                    LOG.info(thinkingToken.toUpperCase());
+                                },
+                                resToken -> {
+                                    LOG.info(resToken.toLowerCase());
+                                }));
         assertNotNull(result);
         assertNotNull(result.getResponse());
         assertNotNull(result.getThinking());
diff --git a/src/test/java/io/github/ollama4j/integrationtests/WithAuth.java b/src/test/java/io/github/ollama4j/integrationtests/WithAuth.java
index 59433b4..312b1fb 100644
--- a/src/test/java/io/github/ollama4j/integrationtests/WithAuth.java
+++ b/src/test/java/io/github/ollama4j/integrationtests/WithAuth.java
@@ -203,7 +203,7 @@ public class WithAuth {
                 });
         format.put("required", List.of("isNoon"));
 
-        OllamaResult result = api.generate(model, prompt, format);
+        OllamaResult result = api.generateWithFormat(model, prompt, format);
 
         assertNotNull(result);
         assertNotNull(result.getResponse());
diff --git a/src/test/java/io/github/ollama4j/unittests/TestMockedAPIs.java b/src/test/java/io/github/ollama4j/unittests/TestMockedAPIs.java
index 6ecc78d..4fa2a39 100644
--- a/src/test/java/io/github/ollama4j/unittests/TestMockedAPIs.java
+++ b/src/test/java/io/github/ollama4j/unittests/TestMockedAPIs.java
@@ -18,6 +18,7 @@ import io.github.ollama4j.exceptions.RoleNotFoundException;
 import io.github.ollama4j.models.chat.OllamaChatMessageRole;
 import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel;
 import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel;
+import io.github.ollama4j.models.generate.OllamaGenerateStreamObserver;
 import io.github.ollama4j.models.request.CustomModelRequest;
 import io.github.ollama4j.models.response.ModelDetail;
 import io.github.ollama4j.models.response.OllamaAsyncResultStreamer;
@@ -170,12 +171,13 @@ class TestMockedAPIs {
         String model = "llama2";
         String prompt = "some prompt text";
         OptionsBuilder optionsBuilder = new OptionsBuilder();
+        OllamaGenerateStreamObserver observer = new OllamaGenerateStreamObserver(null, null);
         try {
-            when(ollamaAPI.generate(model, prompt, false, false, optionsBuilder.build()))
+            when(ollamaAPI.generate(model, prompt, false, false, optionsBuilder.build(), observer))
                     .thenReturn(new OllamaResult("", "", 0, 200));
-            ollamaAPI.generate(model, prompt, false, false, optionsBuilder.build());
+            ollamaAPI.generate(model, prompt, false, false, optionsBuilder.build(), observer);
             verify(ollamaAPI, times(1))
-                    .generate(model, prompt, false, false, optionsBuilder.build());
+                    .generate(model, prompt, false, false, optionsBuilder.build(), observer);
         } catch (IOException | OllamaBaseException | InterruptedException e) {
             throw new RuntimeException(e);
         }
diff --git a/src/test/java/io/github/ollama4j/unittests/TestOllamaChatRequestBuilder.java b/src/test/java/io/github/ollama4j/unittests/TestOllamaChatRequestBuilder.java
index 636c266..356504d 100644
--- a/src/test/java/io/github/ollama4j/unittests/TestOllamaChatRequestBuilder.java
+++ b/src/test/java/io/github/ollama4j/unittests/TestOllamaChatRequestBuilder.java
@@ -59,6 +59,6 @@ class TestOllamaChatRequestBuilder {
         assertNotNull(req.getMessages());
         assert (!req.getMessages().isEmpty());
         OllamaChatMessage msg = req.getMessages().get(0);
-        assertNotNull(msg.getContent());
+        assertNotNull(msg.getResponse());
     }
 }
diff --git a/src/test/java/io/github/ollama4j/unittests/TestOptionsAndUtils.java b/src/test/java/io/github/ollama4j/unittests/TestOptionsAndUtils.java
index 3973a08..409237c 100644
--- a/src/test/java/io/github/ollama4j/unittests/TestOptionsAndUtils.java
+++ b/src/test/java/io/github/ollama4j/unittests/TestOptionsAndUtils.java
@@ -67,12 +67,9 @@ class TestOptionsAndUtils {
 
     @Test
     void testOptionsBuilderRejectsUnsupportedCustomType() {
+        OptionsBuilder builder = new OptionsBuilder();
         assertThrows(
-                IllegalArgumentException.class,
-                () -> {
-                    OptionsBuilder builder = new OptionsBuilder();
-                    builder.setCustomOption("bad", new Object());
-                });
+                IllegalArgumentException.class, () -> builder.setCustomOption("bad", new Object()));
     }
 
     @Test
diff --git a/src/test/resources/logback.xml b/src/test/resources/logback.xml
index bd21aa4..4100fc8 100644
--- a/src/test/resources/logback.xml
+++ b/src/test/resources/logback.xml
@@ -10,6 +10,14 @@
 [eight added XML lines (appender/logger configuration) were lost in extraction and are not reproduced here]
diff --git a/src/test/resources/test-config.properties b/src/test/resources/test-config.properties
index bfa0251..62f46dd 100644
--- a/src/test/resources/test-config.properties
+++ b/src/test/resources/test-config.properties
@@ -1,4 +1,4 @@
-ollama.url=http://localhost:11434
-ollama.model=llama3.2:1b
-ollama.model.image=llava:latest
-ollama.request-timeout-seconds=120
\ No newline at end of file
+USE_EXTERNAL_OLLAMA_HOST=true
+OLLAMA_HOST=http://192.168.29.229:11434/
+REQUEST_TIMEOUT_SECONDS=120
+NUMBER_RETRIES_FOR_MODEL_PULL=3
\ No newline at end of file
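
Appendix (editor's sketch, not part of the patch): the snippet below illustrates how the refactored streaming surface is called after this change. The generate(...) signature, OllamaGenerateStreamObserver, ConsoleOutputGenerateTokenHandler, and ConsoleOutputChatTokenHandler are taken directly from the diff above; the host URL, the model name, and the OllamaChatRequestBuilder.getInstance(...)/withMessage(...) builder calls are assumptions about the surrounding library and may differ.

import io.github.ollama4j.OllamaAPI;
import io.github.ollama4j.impl.ConsoleOutputChatTokenHandler;
import io.github.ollama4j.impl.ConsoleOutputGenerateTokenHandler;
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
import io.github.ollama4j.models.chat.OllamaChatRequest;
import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
import io.github.ollama4j.models.chat.OllamaChatResult;
import io.github.ollama4j.models.generate.OllamaGenerateStreamObserver;
import io.github.ollama4j.models.response.OllamaResult;
import io.github.ollama4j.utils.OptionsBuilder;

public class StreamingRefactorExample {
    public static void main(String[] args) throws Exception {
        // Host and model name are placeholders.
        OllamaAPI api = new OllamaAPI("http://localhost:11434");

        // generate(...) now takes an explicit 'think' flag plus a single
        // OllamaGenerateStreamObserver that bundles the thinking- and
        // response-token handlers (each an OllamaGenerateTokenHandler,
        // i.e. a Consumer<String>), replacing the old overloads that took
        // one or two OllamaStreamHandler parameters.
        OllamaResult result =
                api.generate(
                        "llama3.2:1b",
                        "Why is the sky blue?",
                        false, // raw
                        false, // think
                        new OptionsBuilder().build(),
                        new OllamaGenerateStreamObserver(
                                null, new ConsoleOutputGenerateTokenHandler()));
        System.out.println(result.getResponse());

        // chat(...) now takes an OllamaChatTokenHandler, which consumes whole
        // OllamaChatResponseModel tokens; ConsoleOutputChatTokenHandler wires
        // console printing for both the thinking and response streams.
        OllamaChatRequest request =
                OllamaChatRequestBuilder.getInstance("llama3.2:1b")
                        .withMessage(OllamaChatMessageRole.USER, "Who are you?")
                        .build();
        OllamaChatResult chatResult = api.chat(request, new ConsoleOutputChatTokenHandler());
        System.out.println(chatResult.getResponseModel().getMessage().getResponse());
    }
}

Merging the raw/think generate overloads into the single generate(model, prompt, raw, think, options, observer) signature moves the thinking-versus-response split out of the parameter list and into the observer object, mirroring how chat(request, tokenHandler) already delegates token routing to OllamaChatTokenHandler.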