Refactor token handler interfaces and improve streaming

Renamed and refactored the token handler interfaces for the chat and generate modules to improve clarity and separation:

- Updated related classes and method signatures to use the new handler types.
- Enhanced error handling and logging in the chat and generate request builders.
- Streamed model-pull responses line by line instead of buffering the whole response body.
- Updated tests and integration code to use the new handler classes and configuration properties.
- Suppressed verbose Docker and Testcontainers logs in the test logging configuration.
amithkoujalgi 2025-09-19 18:05:38 +05:30
parent d118958ac1
commit cb0f71ba63
21 changed files with 216 additions and 231 deletions
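
Reviewer note: a minimal usage sketch of the new handler wiring. Only the `generate`/`chat` signatures and handler types below appear in this diff; the `OllamaChatRequestBuilder.getInstance(...)`/`withMessage(...)` calls are assumed from the existing builder API, and the model name is hypothetical.

    OllamaAPI api = new OllamaAPI("http://localhost:11434");

    // generate(): thinking and response handlers are now bundled into one
    // OllamaGenerateStreamObserver argument instead of separate OllamaStreamHandler params.
    OllamaResult result =
            api.generate(
                    "llama3.2:1b",
                    "Why is the sky blue?",
                    false, // raw
                    false, // think
                    new OptionsBuilder().build(),
                    new OllamaGenerateStreamObserver(null, System.out::print));

    // chat(): now takes an OllamaChatTokenHandler; ConsoleOutputChatTokenHandler is the
    // new built-in that prints both thinking and response tokens to stdout.
    OllamaChatRequest request =
            OllamaChatRequestBuilder.getInstance("llama3.2:1b")
                    .withMessage(OllamaChatMessageRole.USER, "Who are you?")
                    .build();
    OllamaChatResult chatResult = api.chat(request, new ConsoleOutputChatTokenHandler());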

View File

@@ -15,11 +15,12 @@ import io.github.ollama4j.exceptions.RoleNotFoundException;
 import io.github.ollama4j.exceptions.ToolInvocationException;
 import io.github.ollama4j.exceptions.ToolNotFoundException;
 import io.github.ollama4j.models.chat.*;
+import io.github.ollama4j.models.chat.OllamaChatTokenHandler;
 import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel;
 import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel;
 import io.github.ollama4j.models.generate.OllamaGenerateRequest;
-import io.github.ollama4j.models.generate.OllamaStreamHandler;
-import io.github.ollama4j.models.generate.OllamaTokenHandler;
+import io.github.ollama4j.models.generate.OllamaGenerateStreamObserver;
+import io.github.ollama4j.models.generate.OllamaGenerateTokenHandler;
 import io.github.ollama4j.models.ps.ModelsProcessResponse;
 import io.github.ollama4j.models.request.*;
 import io.github.ollama4j.models.response.*;
@@ -118,7 +119,7 @@ public class OllamaAPI {
         } else {
             this.host = host;
         }
-        LOG.info("Ollama API initialized with host: {}", this.host);
+        LOG.info("Ollama4j client initialized. Connected to Ollama server at: {}", this.host);
     }
 
     /**
@@ -470,16 +471,26 @@
                         .POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8))
                         .build();
         HttpClient client = HttpClient.newHttpClient();
-        HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
+        HttpResponse<InputStream> response =
+                client.send(request, HttpResponse.BodyHandlers.ofInputStream());
         int statusCode = response.statusCode();
-        String responseString = response.body();
         if (statusCode != 200) {
-            throw new OllamaBaseException(statusCode + " - " + responseString);
+            String errorBody = new String(response.body().readAllBytes(), StandardCharsets.UTF_8);
+            throw new OllamaBaseException(statusCode + " - " + errorBody);
+        }
+        try (BufferedReader reader =
+                new BufferedReader(
+                        new InputStreamReader(response.body(), StandardCharsets.UTF_8))) {
+            String line;
+            while ((line = reader.readLine()) != null) {
+                ModelPullResponse res =
+                        Utils.getObjectMapper().readValue(line, ModelPullResponse.class);
+                LOG.debug(res.getStatus());
+                if (res.getError() != null) {
+                    throw new OllamaBaseException(res.getError());
+                }
+            }
         }
-        if (responseString.contains("error")) {
-            throw new OllamaBaseException(responseString);
-        }
-        LOG.debug(responseString);
     }
 
     /**
@@ -559,98 +570,32 @@
         }
     }
 
-    /**
-     * Generate response for a question to a model running on Ollama server. This is a sync/blocking
-     * call. This API does not support "thinking" models.
-     *
-     * @param model the ollama model to ask the question to
-     * @param prompt the prompt/question text
-     * @param raw if true no formatting will be applied to the prompt. You may choose to use the raw
-     *     parameter if you are specifying a full templated prompt in your request to the API
-     * @param options the Options object - <a href=
-     *     "https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values">More
-     *     details on the options</a>
-     * @param responseStreamHandler optional callback consumer that will be applied every time a
-     *     streamed response is received. If not set, the stream parameter of the request is set to
-     *     false.
-     * @return OllamaResult that includes response text and time taken for response
-     * @throws OllamaBaseException if the response indicates an error status
-     * @throws IOException if an I/O error occurs during the HTTP request
-     * @throws InterruptedException if the operation is interrupted
-     */
     public OllamaResult generate(
             String model,
             String prompt,
             boolean raw,
+            boolean think,
             Options options,
-            OllamaStreamHandler responseStreamHandler)
+            OllamaGenerateStreamObserver streamObserver)
             throws OllamaBaseException, IOException, InterruptedException {
+        // Create the OllamaGenerateRequest and configure common properties
         OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt);
         ollamaRequestModel.setRaw(raw);
-        ollamaRequestModel.setThink(false);
+        ollamaRequestModel.setThink(think);
         ollamaRequestModel.setOptions(options.getOptionsMap());
-        return generateSyncForOllamaRequestModel(ollamaRequestModel, null, responseStreamHandler);
-    }
-
-    /**
-     * Generate thinking and response tokens for a question to a thinking model running on Ollama
-     * server. This is a sync/blocking call.
-     *
-     * @param model the ollama model to ask the question to
-     * @param prompt the prompt/question text
-     * @param raw if true no formatting will be applied to the prompt. You may choose to use the raw
-     *     parameter if you are specifying a full templated prompt in your request to the API
-     * @param options the Options object - <a href=
-     *     "https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values">More
-     *     details on the options</a>
-     * @param responseStreamHandler optional callback consumer that will be applied every time a
-     *     streamed response is received. If not set, the stream parameter of the request is set to
-     *     false.
-     * @return OllamaResult that includes response text and time taken for response
-     * @throws OllamaBaseException if the response indicates an error status
-     * @throws IOException if an I/O error occurs during the HTTP request
-     * @throws InterruptedException if the operation is interrupted
-     */
-    public OllamaResult generate(
-            String model,
-            String prompt,
-            boolean raw,
-            Options options,
-            OllamaStreamHandler thinkingStreamHandler,
-            OllamaStreamHandler responseStreamHandler)
-            throws OllamaBaseException, IOException, InterruptedException {
-        OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt);
-        ollamaRequestModel.setRaw(raw);
-        ollamaRequestModel.setThink(true);
-        ollamaRequestModel.setOptions(options.getOptionsMap());
-        return generateSyncForOllamaRequestModel(
-                ollamaRequestModel, thinkingStreamHandler, responseStreamHandler);
-    }
-
-    /**
-     * Generates response using the specified AI model and prompt (in blocking mode).
-     *
-     * <p>Uses {@link #generate(String, String, boolean, Options, OllamaStreamHandler)}
-     *
-     * @param model The name or identifier of the AI model to use for generating the response.
-     * @param prompt The input text or prompt to provide to the AI model.
-     * @param raw In some cases, you may wish to bypass the templating system and provide a full
-     *     prompt. In this case, you can use the raw parameter to disable templating. Also note that
-     *     raw mode will not return a context.
-     * @param options Additional options or configurations to use when generating the response.
-     * @param think if true the model will "think" step-by-step before generating the final response
-     * @return {@link OllamaResult}
-     * @throws OllamaBaseException if the response indicates an error status
-     * @throws IOException if an I/O error occurs during the HTTP request
-     * @throws InterruptedException if the operation is interrupted
-     */
-    public OllamaResult generate(
-            String model, String prompt, boolean raw, boolean think, Options options)
-            throws OllamaBaseException, IOException, InterruptedException {
+
+        // Based on 'think' flag, choose the appropriate stream handler(s)
         if (think) {
-            return generate(model, prompt, raw, options, null, null);
+            return generateSyncForOllamaRequestModel(
+                    ollamaRequestModel,
+                    streamObserver.getThinkingStreamHandler(),
+                    streamObserver.getResponseStreamHandler());
        } else {
-            return generate(model, prompt, raw, options, null);
+            return generateSyncForOllamaRequestModel(
+                    ollamaRequestModel, null, streamObserver.getResponseStreamHandler());
        }
    }
@@ -668,7 +613,7 @@ public class OllamaAPI {
      * @throws InterruptedException if the operation is interrupted.
      */
     @SuppressWarnings("LoggingSimilarMessage")
-    public OllamaResult generate(String model, String prompt, Map<String, Object> format)
+    public OllamaResult generateWithFormat(String model, String prompt, Map<String, Object> format)
             throws OllamaBaseException, IOException, InterruptedException {
         URI uri = URI.create(this.host + "/api/generate");
@@ -767,7 +712,7 @@
      * @throws ToolInvocationException if a tool call fails to execute
      */
     public OllamaToolsResult generateWithTools(
-            String model, String prompt, Options options, OllamaStreamHandler streamHandler)
+            String model, String prompt, Options options, OllamaGenerateTokenHandler streamHandler)
             throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
         boolean raw = true;
         OllamaToolsResult toolResult = new OllamaToolsResult();
@@ -782,7 +727,14 @@
             prompt = promptBuilder.build();
         }
 
-        OllamaResult result = generate(model, prompt, raw, options, streamHandler);
+        OllamaResult result =
+                generate(
+                        model,
+                        prompt,
+                        raw,
+                        false,
+                        options,
+                        new OllamaGenerateStreamObserver(null, streamHandler));
         toolResult.setModelResult(result);
 
         String toolsResponse = result.getResponse();
@@ -898,7 +850,7 @@
             List<Object> images,
             Options options,
             Map<String, Object> format,
-            OllamaStreamHandler streamHandler)
+            OllamaGenerateTokenHandler streamHandler)
             throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
         List<String> encodedImages = new ArrayList<>();
         for (Object image : images) {
@@ -947,7 +899,7 @@
      * @throws IOException if an I/O error occurs during the HTTP request
      * @throws InterruptedException if the operation is interrupted
      */
-    public OllamaChatResult chat(OllamaChatRequest request, OllamaTokenHandler tokenHandler)
+    public OllamaChatResult chat(OllamaChatRequest request, OllamaChatTokenHandler tokenHandler)
             throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
         OllamaChatEndpointCaller requestCaller =
                 new OllamaChatEndpointCaller(host, auth, requestTimeoutSeconds);
@@ -1233,8 +1185,8 @@
      */
     private OllamaResult generateSyncForOllamaRequestModel(
             OllamaGenerateRequest ollamaRequestModel,
-            OllamaStreamHandler thinkingStreamHandler,
-            OllamaStreamHandler responseStreamHandler)
+            OllamaGenerateTokenHandler thinkingStreamHandler,
+            OllamaGenerateTokenHandler responseStreamHandler)
             throws OllamaBaseException, IOException, InterruptedException {
         OllamaGenerateEndpointCaller requestCaller =
                 new OllamaGenerateEndpointCaller(host, auth, requestTimeoutSeconds);

View File

@@ -0,0 +1,18 @@
+/*
+ * Ollama4j - Java library for interacting with Ollama server.
+ * Copyright (c) 2025 Amith Koujalgi and contributors.
+ *
+ * Licensed under the MIT License (the "License");
+ * you may not use this file except in compliance with the License.
+ *
+ */
+package io.github.ollama4j.impl;
+
+import io.github.ollama4j.models.chat.OllamaChatStreamObserver;
+
+public final class ConsoleOutputChatTokenHandler extends OllamaChatStreamObserver {
+    public ConsoleOutputChatTokenHandler() {
+        setThinkingStreamHandler(new ConsoleOutputGenerateTokenHandler());
+        setResponseStreamHandler(new ConsoleOutputGenerateTokenHandler());
+    }
+}
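
The integration tests below switch to this class; usage is a one-liner (sketch, assuming an already-built OllamaChatRequest):

    // Streams both thinking and response tokens of a chat call to stdout.
    OllamaChatResult chatResult = api.chat(requestModel, new ConsoleOutputChatTokenHandler());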

View File

@@ -8,15 +8,11 @@
  */
 package io.github.ollama4j.impl;
 
-import io.github.ollama4j.models.generate.OllamaStreamHandler;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
+import io.github.ollama4j.models.generate.OllamaGenerateTokenHandler;
 
-public class ConsoleOutputStreamHandler implements OllamaStreamHandler {
-    private static final Logger LOG = LoggerFactory.getLogger(ConsoleOutputStreamHandler.class);
-
+public class ConsoleOutputGenerateTokenHandler implements OllamaGenerateTokenHandler {
     @Override
     public void accept(String message) {
-        LOG.info(message);
+        System.out.print(message);
     }
 }

View File

@@ -21,7 +21,9 @@ import lombok.*;
 /**
  * Defines a single Message to be used inside a chat request against the ollama /api/chat endpoint.
  *
- * @see <a href="https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion">Generate chat completion</a>
+ * @see <a
+ *     href="https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion">Generate
+ *     chat completion</a>
  */
 @Data
 @AllArgsConstructor
@@ -32,7 +34,9 @@ public class OllamaChatMessage {
     @NonNull private OllamaChatMessageRole role;
 
-    @NonNull private String content;
+    @JsonProperty("content")
+    @NonNull
+    private String response;
 
     private String thinking;

View File

@@ -114,11 +114,18 @@ public class OllamaChatRequestBuilder {
                                     imageURLConnectTimeoutSeconds,
                                     imageURLReadTimeoutSeconds));
                 } catch (InterruptedException e) {
-                    LOG.error("Failed to load image from URL: {}. Cause: {}", imageUrl, e);
-                    throw e;
+                    LOG.error("Failed to load image from URL: '{}'. Cause: {}", imageUrl, e);
+                    Thread.currentThread().interrupt();
+                    throw new InterruptedException(
+                            "Interrupted while loading image from URL: " + imageUrl);
                 } catch (IOException e) {
-                    LOG.warn("Failed to load image from URL: {}. Cause: {}", imageUrl, e);
-                    throw e;
+                    LOG.error(
+                            "IOException occurred while loading image from URL '{}'. Cause: {}",
+                            imageUrl,
+                            e.getMessage(),
+                            e);
+                    throw new IOException(
+                            "IOException while loading image from URL: " + imageUrl, e);
                 }
             }
         }

View File

@@ -8,18 +8,18 @@
  */
 package io.github.ollama4j.models.chat;
 
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
 import com.fasterxml.jackson.annotation.JsonProperty;
 import java.util.List;
 import lombok.Data;
 
 @Data
+@JsonIgnoreProperties(ignoreUnknown = true)
 public class OllamaChatResponseModel {
     private String model;
     private @JsonProperty("created_at") String createdAt;
     private @JsonProperty("done_reason") String doneReason;
-    private OllamaChatMessage message;
     private boolean done;
-    private String error;
     private List<Integer> context;
     private @JsonProperty("total_duration") Long totalDuration;
     private @JsonProperty("load_duration") Long loadDuration;
@@ -27,4 +27,6 @@ public class OllamaChatResponseModel {
     private @JsonProperty("eval_duration") Long evalDuration;
     private @JsonProperty("prompt_eval_count") Integer promptEvalCount;
     private @JsonProperty("eval_count") Integer evalCount;
+    private String error;
+    private OllamaChatMessage message;
 }

View File

@@ -47,7 +47,7 @@ public class OllamaChatResult {
 
     @Deprecated
     public String getResponse() {
-        return responseModel != null ? responseModel.getMessage().getContent() : "";
+        return responseModel != null ? responseModel.getMessage().getResponse() : "";
     }
 
     @Deprecated

View File

@@ -8,16 +8,17 @@
  */
 package io.github.ollama4j.models.chat;
 
-import io.github.ollama4j.models.generate.OllamaStreamHandler;
-import io.github.ollama4j.models.generate.OllamaTokenHandler;
-import lombok.RequiredArgsConstructor;
+import io.github.ollama4j.models.generate.OllamaGenerateTokenHandler;
+import lombok.AllArgsConstructor;
+import lombok.NoArgsConstructor;
+import lombok.Setter;
 
-@RequiredArgsConstructor
-public class OllamaChatStreamObserver implements OllamaTokenHandler {
-    private final OllamaStreamHandler thinkingStreamHandler;
-    private final OllamaStreamHandler responseStreamHandler;
-
-    private String message = "";
+@Setter
+@NoArgsConstructor
+@AllArgsConstructor
+public class OllamaChatStreamObserver implements OllamaChatTokenHandler {
+    private OllamaGenerateTokenHandler thinkingStreamHandler;
+    private OllamaGenerateTokenHandler responseStreamHandler;
 
     @Override
     public void accept(OllamaChatResponseModel token) {
@@ -26,33 +27,19 @@ public class OllamaChatStreamObserver implements OllamaTokenHandler {
         }
 
         String thinking = token.getMessage().getThinking();
-        String content = token.getMessage().getContent();
+        String response = token.getMessage().getResponse();
 
         boolean hasThinking = thinking != null && !thinking.isEmpty();
-        boolean hasContent = !content.isEmpty();
-
-        //        if (hasThinking && !hasContent) {
-        ////            message += thinking;
-        //            message = thinking;
-        //        } else {
-        ////            message += content;
-        //            message = content;
-        //        }
-        //
-        //        responseStreamHandler.accept(message);
-
-        if (!hasContent && hasThinking && thinkingStreamHandler != null) {
-            // message = message + thinking;
+        boolean hasResponse = response != null && !response.isEmpty();
 
+        if (!hasResponse && hasThinking && thinkingStreamHandler != null) {
             // use only new tokens received, instead of appending the tokens to the previous
             // ones and sending the full string again
             thinkingStreamHandler.accept(thinking);
-        } else if (hasContent && responseStreamHandler != null) {
-            // message = message + response;
-
+        } else if (hasResponse) {
             // use only new tokens received, instead of appending the tokens to the previous
             // ones and sending the full string again
             responseStreamHandler.accept(response);
         }
     }
 }

View File

@@ -6,9 +6,8 @@
  * you may not use this file except in compliance with the License.
  *
  */
-package io.github.ollama4j.models.generate;
+package io.github.ollama4j.models.chat;
 
-import io.github.ollama4j.models.chat.OllamaChatResponseModel;
 import java.util.function.Consumer;
 
-public interface OllamaTokenHandler extends Consumer<OllamaChatResponseModel> {}
+public interface OllamaChatTokenHandler extends Consumer<OllamaChatResponseModel> {}

View File

@@ -18,15 +18,15 @@ import lombok.Data;
 public class OllamaGenerateResponseModel {
     private String model;
     private @JsonProperty("created_at") String createdAt;
-    private String response;
-    private String thinking;
-    private boolean done;
     private @JsonProperty("done_reason") String doneReason;
+    private boolean done;
     private List<Integer> context;
     private @JsonProperty("total_duration") Long totalDuration;
     private @JsonProperty("load_duration") Long loadDuration;
-    private @JsonProperty("prompt_eval_count") Integer promptEvalCount;
     private @JsonProperty("prompt_eval_duration") Long promptEvalDuration;
-    private @JsonProperty("eval_count") Integer evalCount;
     private @JsonProperty("eval_duration") Long evalDuration;
+    private @JsonProperty("prompt_eval_count") Integer promptEvalCount;
+    private @JsonProperty("eval_count") Integer evalCount;
+    private String response;
+    private String thinking;
 }

View File

@@ -10,18 +10,18 @@ package io.github.ollama4j.models.generate;
 
 import java.util.ArrayList;
 import java.util.List;
+import lombok.Getter;
 
+@Getter
 public class OllamaGenerateStreamObserver {
-
-    private final OllamaStreamHandler thinkingStreamHandler;
-    private final OllamaStreamHandler responseStreamHandler;
+    private final OllamaGenerateTokenHandler thinkingStreamHandler;
+    private final OllamaGenerateTokenHandler responseStreamHandler;
 
     private final List<OllamaGenerateResponseModel> responseParts = new ArrayList<>();
 
-    private String message = "";
-
     public OllamaGenerateStreamObserver(
-            OllamaStreamHandler thinkingStreamHandler, OllamaStreamHandler responseStreamHandler) {
+            OllamaGenerateTokenHandler thinkingStreamHandler,
+            OllamaGenerateTokenHandler responseStreamHandler) {
         this.responseStreamHandler = responseStreamHandler;
         this.thinkingStreamHandler = thinkingStreamHandler;
     }
@@ -39,14 +39,10 @@ public class OllamaGenerateStreamObserver {
         boolean hasThinking = thinking != null && !thinking.isEmpty();
 
         if (!hasResponse && hasThinking && thinkingStreamHandler != null) {
-            // message = message + thinking;
-
             // use only new tokens received, instead of appending the tokens to the previous
             // ones and sending the full string again
             thinkingStreamHandler.accept(thinking);
         } else if (hasResponse && responseStreamHandler != null) {
-            // message = message + response;
-
             // use only new tokens received, instead of appending the tokens to the previous
             // ones and sending the full string again
             responseStreamHandler.accept(response);

View File

@@ -10,6 +10,6 @@ package io.github.ollama4j.models.generate;
 
 import java.util.function.Consumer;
 
-public interface OllamaStreamHandler extends Consumer<String> {
+public interface OllamaGenerateTokenHandler extends Consumer<String> {
     void accept(String message);
 }
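
Since the renamed interface is still just a Consumer<String>, existing lambdas and method references keep working (a minimal sketch):

    OllamaGenerateTokenHandler print = System.out::print;
    OllamaGenerateTokenHandler shout = token -> System.out.print(token.toUpperCase());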

View File

@@ -12,7 +12,7 @@ import com.fasterxml.jackson.core.JsonProcessingException;
 import com.fasterxml.jackson.core.type.TypeReference;
 import io.github.ollama4j.exceptions.OllamaBaseException;
 import io.github.ollama4j.models.chat.*;
-import io.github.ollama4j.models.generate.OllamaTokenHandler;
+import io.github.ollama4j.models.chat.OllamaChatTokenHandler;
 import io.github.ollama4j.models.response.OllamaErrorResponse;
 import io.github.ollama4j.utils.Utils;
 import java.io.BufferedReader;
@@ -36,7 +36,7 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
     private static final Logger LOG = LoggerFactory.getLogger(OllamaChatEndpointCaller.class);
 
-    private OllamaTokenHandler tokenHandler;
+    private OllamaChatTokenHandler tokenHandler;
 
     public OllamaChatEndpointCaller(String host, Auth auth, long requestTimeoutSeconds) {
         super(host, auth, requestTimeoutSeconds);
@@ -73,7 +73,7 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
                 if (message.getThinking() != null) {
                     thinkingBuffer.append(message.getThinking());
                 } else {
-                    responseBuffer.append(message.getContent());
+                    responseBuffer.append(message.getResponse());
                 }
                 if (tokenHandler != null) {
                     tokenHandler.accept(ollamaResponseModel);
@@ -86,7 +86,7 @@
         }
     }
 
-    public OllamaChatResult call(OllamaChatRequest body, OllamaTokenHandler tokenHandler)
+    public OllamaChatResult call(OllamaChatRequest body, OllamaChatTokenHandler tokenHandler)
             throws OllamaBaseException, IOException, InterruptedException {
         this.tokenHandler = tokenHandler;
         return callSync(body);
@@ -127,7 +127,7 @@
                     wantedToolsForStream = ollamaChatResponseModel.getMessage().getToolCalls();
                 }
                 if (finished && body.stream) {
-                    ollamaChatResponseModel.getMessage().setContent(responseBuffer.toString());
+                    ollamaChatResponseModel.getMessage().setResponse(responseBuffer.toString());
                     ollamaChatResponseModel.getMessage().setThinking(thinkingBuffer.toString());
                     break;
                 }

View File

@@ -12,7 +12,7 @@ import com.fasterxml.jackson.core.JsonProcessingException;
 import io.github.ollama4j.exceptions.OllamaBaseException;
 import io.github.ollama4j.models.generate.OllamaGenerateResponseModel;
 import io.github.ollama4j.models.generate.OllamaGenerateStreamObserver;
-import io.github.ollama4j.models.generate.OllamaStreamHandler;
+import io.github.ollama4j.models.generate.OllamaGenerateTokenHandler;
 import io.github.ollama4j.models.response.OllamaErrorResponse;
 import io.github.ollama4j.models.response.OllamaResult;
 import io.github.ollama4j.utils.OllamaRequestBody;
@@ -69,8 +69,8 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
 
     public OllamaResult call(
             OllamaRequestBody body,
-            OllamaStreamHandler thinkingStreamHandler,
-            OllamaStreamHandler responseStreamHandler)
+            OllamaGenerateTokenHandler thinkingStreamHandler,
+            OllamaGenerateTokenHandler responseStreamHandler)
             throws OllamaBaseException, IOException, InterruptedException {
         responseStreamObserver =
                 new OllamaGenerateStreamObserver(thinkingStreamHandler, responseStreamHandler);

View File

@@ -13,9 +13,12 @@ import static org.junit.jupiter.api.Assertions.*;
 import io.github.ollama4j.OllamaAPI;
 import io.github.ollama4j.exceptions.OllamaBaseException;
 import io.github.ollama4j.exceptions.ToolInvocationException;
+import io.github.ollama4j.impl.ConsoleOutputChatTokenHandler;
+import io.github.ollama4j.impl.ConsoleOutputGenerateTokenHandler;
 import io.github.ollama4j.models.chat.*;
 import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel;
 import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel;
+import io.github.ollama4j.models.generate.OllamaGenerateStreamObserver;
 import io.github.ollama4j.models.response.Model;
 import io.github.ollama4j.models.response.ModelDetail;
 import io.github.ollama4j.models.response.OllamaResult;
@@ -56,10 +59,41 @@ class OllamaAPIIntegrationTest {
 
     @BeforeAll
     static void setUp() {
+        int requestTimeoutSeconds = 60;
+        int numberOfRetriesForModelPull = 5;
         try {
-            boolean useExternalOllamaHost =
-                    Boolean.parseBoolean(System.getenv("USE_EXTERNAL_OLLAMA_HOST"));
-            String ollamaHost = System.getenv("OLLAMA_HOST");
+            // Try to get from env vars first
+            String useExternalOllamaHostEnv = System.getenv("USE_EXTERNAL_OLLAMA_HOST");
+            String ollamaHostEnv = System.getenv("OLLAMA_HOST");
+
+            boolean useExternalOllamaHost;
+            String ollamaHost;
+
+            if (useExternalOllamaHostEnv == null && ollamaHostEnv == null) {
+                // Fallback to test-config.properties from classpath
+                Properties props = new Properties();
+                try {
+                    props.load(
+                            OllamaAPIIntegrationTest.class
+                                    .getClassLoader()
+                                    .getResourceAsStream("test-config.properties"));
+                } catch (Exception e) {
+                    throw new RuntimeException(
+                            "Could not load test-config.properties from classpath", e);
+                }
+                useExternalOllamaHost =
+                        Boolean.parseBoolean(
+                                props.getProperty("USE_EXTERNAL_OLLAMA_HOST", "false"));
+                ollamaHost = props.getProperty("OLLAMA_HOST");
+                requestTimeoutSeconds =
+                        Integer.parseInt(props.getProperty("REQUEST_TIMEOUT_SECONDS"));
+                numberOfRetriesForModelPull =
+                        Integer.parseInt(props.getProperty("NUMBER_RETRIES_FOR_MODEL_PULL"));
+            } else {
+                useExternalOllamaHost = Boolean.parseBoolean(useExternalOllamaHostEnv);
+                ollamaHost = ollamaHostEnv;
+            }
 
             if (useExternalOllamaHost) {
                 LOG.info("Using external Ollama host...");
@@ -90,8 +124,8 @@
                             + ":"
                             + ollama.getMappedPort(internalPort));
         }
-        api.setRequestTimeoutSeconds(120);
-        api.setNumberOfRetriesForModelPull(5);
+        api.setRequestTimeoutSeconds(requestTimeoutSeconds);
+        api.setNumberOfRetriesForModelPull(numberOfRetriesForModelPull);
     }
 
     @Test
@@ -187,7 +221,7 @@
                 });
         format.put("required", List.of("isNoon"));
 
-        OllamaResult result = api.generate(TOOLS_MODEL, prompt, format);
+        OllamaResult result = api.generateWithFormat(TOOLS_MODEL, prompt, format);
 
         assertNotNull(result);
         assertNotNull(result.getResponse());
@@ -210,7 +244,8 @@
                                 + " Lisa?",
                         raw,
                         thinking,
-                        new OptionsBuilder().build());
+                        new OptionsBuilder().build(),
+                        new OllamaGenerateStreamObserver(null, null));
         assertNotNull(result);
         assertNotNull(result.getResponse());
         assertFalse(result.getResponse().isEmpty());
@@ -228,8 +263,10 @@
                         "What is the capital of France? And what's France's connection with Mona"
                                 + " Lisa?",
                         raw,
+                        false,
                         new OptionsBuilder().build(),
-                        LOG::info);
+                        new OllamaGenerateStreamObserver(
+                                null, new ConsoleOutputGenerateTokenHandler()));
 
         assertNotNull(result);
         assertNotNull(result.getResponse());
@@ -263,7 +300,7 @@
 
         assertNotNull(chatResult);
         assertNotNull(chatResult.getResponseModel());
-        assertFalse(chatResult.getResponseModel().getMessage().getContent().isEmpty());
+        assertFalse(chatResult.getResponseModel().getMessage().getResponse().isEmpty());
     }
 
     @Test
@@ -296,9 +333,13 @@
         assertNotNull(chatResult);
         assertNotNull(chatResult.getResponseModel());
         assertNotNull(chatResult.getResponseModel().getMessage());
-        assertFalse(chatResult.getResponseModel().getMessage().getContent().isBlank());
+        assertFalse(chatResult.getResponseModel().getMessage().getResponse().isBlank());
         assertTrue(
-                chatResult.getResponseModel().getMessage().getContent().contains(expectedResponse));
+                chatResult
+                        .getResponseModel()
+                        .getMessage()
+                        .getResponse()
+                        .contains(expectedResponse));
         assertEquals(3, chatResult.getChatHistory().size());
     }
@@ -515,16 +556,7 @@
                         .withOptions(new OptionsBuilder().setTemperature(0.9f).build())
                         .build();
 
-        OllamaChatResult chatResult =
-                api.chat(
-                        requestModel,
-                        new OllamaChatStreamObserver(
-                                s -> {
-                                    LOG.info(s.toUpperCase());
-                                },
-                                s -> {
-                                    LOG.info(s.toLowerCase());
-                                }));
+        OllamaChatResult chatResult = api.chat(requestModel, new ConsoleOutputChatTokenHandler());
 
         assertNotNull(chatResult, "chatResult should not be null");
         assertNotNull(chatResult.getResponseModel(), "Response model should not be null");
@@ -670,20 +702,11 @@
                         .build();
         requestModel.setThink(false);
 
-        OllamaChatResult chatResult =
-                api.chat(
-                        requestModel,
-                        new OllamaChatStreamObserver(
-                                s -> {
-                                    LOG.info(s.toUpperCase());
-                                },
-                                s -> {
-                                    LOG.info(s.toLowerCase());
-                                }));
+        OllamaChatResult chatResult = api.chat(requestModel, new ConsoleOutputChatTokenHandler());
         assertNotNull(chatResult);
         assertNotNull(chatResult.getResponseModel());
         assertNotNull(chatResult.getResponseModel().getMessage());
-        assertNotNull(chatResult.getResponseModel().getMessage().getContent());
+        assertNotNull(chatResult.getResponseModel().getMessage().getResponse());
     }
 
     @Test
@@ -706,21 +729,12 @@
                         .withKeepAlive("0m")
                         .build();
 
-        OllamaChatResult chatResult =
-                api.chat(
-                        requestModel,
-                        new OllamaChatStreamObserver(
-                                s -> {
-                                    LOG.info(s.toUpperCase());
-                                },
-                                s -> {
-                                    LOG.info(s.toLowerCase());
-                                }));
+        OllamaChatResult chatResult = api.chat(requestModel, new ConsoleOutputChatTokenHandler());
 
         assertNotNull(chatResult);
        assertNotNull(chatResult.getResponseModel());
         assertNotNull(chatResult.getResponseModel().getMessage());
-        assertNotNull(chatResult.getResponseModel().getMessage().getContent());
+        assertNotNull(chatResult.getResponseModel().getMessage().getResponse());
     }
 
     @Test
@@ -859,7 +873,8 @@
                         "Who are you?",
                         raw,
                         think,
-                        new OptionsBuilder().build());
+                        new OptionsBuilder().build(),
+                        new OllamaGenerateStreamObserver(null, null));
 
         assertNotNull(result);
         assertNotNull(result.getResponse());
         assertNotNull(result.getThinking());
@@ -876,13 +891,15 @@
                         THINKING_TOOL_MODEL,
                         "Who are you?",
                         raw,
+                        true,
                         new OptionsBuilder().build(),
+                        new OllamaGenerateStreamObserver(
                                 thinkingToken -> {
                                     LOG.info(thinkingToken.toUpperCase());
                                 },
                                 resToken -> {
                                     LOG.info(resToken.toLowerCase());
-                                });
+                                }));
         assertNotNull(result);
         assertNotNull(result.getResponse());
         assertNotNull(result.getThinking());

View File

@@ -203,7 +203,7 @@ public class WithAuth {
                 });
         format.put("required", List.of("isNoon"));
 
-        OllamaResult result = api.generate(model, prompt, format);
+        OllamaResult result = api.generateWithFormat(model, prompt, format);
 
         assertNotNull(result);
         assertNotNull(result.getResponse());

View File

@@ -18,6 +18,7 @@ import io.github.ollama4j.exceptions.RoleNotFoundException;
 import io.github.ollama4j.models.chat.OllamaChatMessageRole;
 import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel;
 import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel;
+import io.github.ollama4j.models.generate.OllamaGenerateStreamObserver;
 import io.github.ollama4j.models.request.CustomModelRequest;
 import io.github.ollama4j.models.response.ModelDetail;
 import io.github.ollama4j.models.response.OllamaAsyncResultStreamer;
@@ -170,12 +171,13 @@ class TestMockedAPIs {
         String model = "llama2";
         String prompt = "some prompt text";
         OptionsBuilder optionsBuilder = new OptionsBuilder();
+        OllamaGenerateStreamObserver observer = new OllamaGenerateStreamObserver(null, null);
         try {
-            when(ollamaAPI.generate(model, prompt, false, false, optionsBuilder.build()))
+            when(ollamaAPI.generate(model, prompt, false, false, optionsBuilder.build(), observer))
                     .thenReturn(new OllamaResult("", "", 0, 200));
-            ollamaAPI.generate(model, prompt, false, false, optionsBuilder.build());
+            ollamaAPI.generate(model, prompt, false, false, optionsBuilder.build(), observer);
             verify(ollamaAPI, times(1))
-                    .generate(model, prompt, false, false, optionsBuilder.build());
+                    .generate(model, prompt, false, false, optionsBuilder.build(), observer);
         } catch (IOException | OllamaBaseException | InterruptedException e) {
             throw new RuntimeException(e);
         }

View File

@@ -59,6 +59,6 @@ class TestOllamaChatRequestBuilder {
         assertNotNull(req.getMessages());
         assert (!req.getMessages().isEmpty());
         OllamaChatMessage msg = req.getMessages().get(0);
-        assertNotNull(msg.getContent());
+        assertNotNull(msg.getResponse());
     }
 }

View File

@@ -67,12 +67,9 @@ class TestOptionsAndUtils {
     @Test
     void testOptionsBuilderRejectsUnsupportedCustomType() {
-        assertThrows(
-                IllegalArgumentException.class,
-                () -> {
-                    OptionsBuilder builder = new OptionsBuilder();
-                    builder.setCustomOption("bad", new Object());
-                });
+        OptionsBuilder builder = new OptionsBuilder();
+        assertThrows(
+                IllegalArgumentException.class, () -> builder.setCustomOption("bad", new Object()));
     }
 
     @Test

View File

@@ -10,6 +10,14 @@
         <appender-ref ref="STDOUT"/>
     </root>
 
+    <!-- Suppress logs from com.github.dockerjava package -->
+    <logger name="com.github.dockerjava" level="INFO"/>
+
+    <!-- Suppress logs from org.testcontainers package -->
+    <logger name="org.testcontainers" level="INFO"/>
+
+    <!-- Keep other loggers at WARN level -->
     <logger name="org.apache" level="WARN"/>
     <logger name="httpclient" level="WARN"/>
 </configuration>

View File

@@ -1,4 +1,4 @@
-ollama.url=http://localhost:11434
-ollama.model=llama3.2:1b
-ollama.model.image=llava:latest
-ollama.request-timeout-seconds=120
+USE_EXTERNAL_OLLAMA_HOST=true
+OLLAMA_HOST=http://192.168.29.229:11434/
+REQUEST_TIMEOUT_SECONDS=120
+NUMBER_RETRIES_FOR_MODEL_PULL=3