+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-compiler-plugin</artifactId>
+ <configuration>
+ <annotationProcessorPaths>
+ <path>
+ <groupId>org.projectlombok</groupId>
+ <artifactId>lombok</artifactId>
+ <version>${lombok.version}</version>
+ </path>
+ </annotationProcessorPaths>
+ </configuration>
+ </plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-source-plugin</artifactId>
@@ -146,7 +160,7 @@
yyyy-MM-dd'T'HH:mm:ss'Z'
- Etc/UTC
+ Etc/UTC
@@ -412,4 +426,4 @@
-</project>
+</project>
\ No newline at end of file
diff --git a/src/main/java/io/github/ollama4j/OllamaAPI.java b/src/main/java/io/github/ollama4j/OllamaAPI.java
index 5689faa..65831e1 100644
--- a/src/main/java/io/github/ollama4j/OllamaAPI.java
+++ b/src/main/java/io/github/ollama4j/OllamaAPI.java
@@ -22,6 +22,7 @@ import io.github.ollama4j.tools.*;
import io.github.ollama4j.tools.annotations.OllamaToolService;
import io.github.ollama4j.tools.annotations.ToolProperty;
import io.github.ollama4j.tools.annotations.ToolSpec;
+import io.github.ollama4j.utils.Constants;
import io.github.ollama4j.utils.Options;
import io.github.ollama4j.utils.Utils;
import lombok.Setter;
@@ -55,33 +56,54 @@ import java.util.stream.Collectors;
public class OllamaAPI {
private static final Logger logger = LoggerFactory.getLogger(OllamaAPI.class);
+
private final String host;
+ private Auth auth;
+ private final ToolRegistry toolRegistry = new ToolRegistry();
+
/**
- * -- SETTER --
- * Set request timeout in seconds. Default is 3 seconds.
+ * The request timeout in seconds for API calls.
+ *
+ * Default is 10 seconds. This value determines how long the client will wait
+ * for a response
+ * from the Ollama server before timing out.
*/
@Setter
private long requestTimeoutSeconds = 10;
+
/**
- * -- SETTER --
- * Set/unset logging of responses
+ * Enables or disables verbose logging of responses.
+ *
+ * If set to {@code true}, the API will log detailed information about requests
+ * and responses.
+ * Default is {@code true}.
*/
@Setter
private boolean verbose = true;
+ /**
+ * The maximum number of retries for tool calls during chat interactions.
+ *
+ * This value controls how many times the API will attempt to call a tool in the
+ * event of a failure.
+ * Default is 3.
+ */
@Setter
private int maxChatToolCallRetries = 3;
- private Auth auth;
-
+ /**
+ * The number of retries to attempt when pulling a model from the Ollama server.
+ *
+ * If set to 0, no retries will be performed. If greater than 0, the API will
+ * retry pulling the model
+ * up to the specified number of times in case of failure.
+ *
+ * Default is 0 (no retries).
+ */
+ @Setter
+ @SuppressWarnings({"FieldMayBeFinal", "FieldCanBeLocal"})
private int numberOfRetriesForModelPull = 0;
- public void setNumberOfRetriesForModelPull(int numberOfRetriesForModelPull) {
- this.numberOfRetriesForModelPull = numberOfRetriesForModelPull;
- }
-
- private final ToolRegistry toolRegistry = new ToolRegistry();
-
/**
* Instantiates the Ollama API with default Ollama host:
* http://localhost:11434
@@ -102,7 +124,7 @@ public class OllamaAPI {
this.host = host;
}
if (this.verbose) {
- logger.info("Ollama API initialized with host: " + this.host);
+ logger.info("Ollama API initialized with host: {}", this.host);
}
}
@@ -135,14 +157,17 @@ public class OllamaAPI {
public boolean ping() {
String url = this.host + "/api/tags";
HttpClient httpClient = HttpClient.newHttpClient();
- HttpRequest httpRequest = null;
+ HttpRequest httpRequest;
try {
- httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json")
- .header("Content-type", "application/json").GET().build();
+ httpRequest = getRequestBuilderDefault(new URI(url))
+ .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
+ .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON)
+ .GET()
+ .build();
} catch (URISyntaxException e) {
throw new RuntimeException(e);
}
- HttpResponse<String> response = null;
+ HttpResponse<String> response;
try {
response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
} catch (HttpConnectTimeoutException e) {
@@ -168,8 +193,10 @@ public class OllamaAPI {
HttpClient httpClient = HttpClient.newHttpClient();
HttpRequest httpRequest = null;
try {
- httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json")
- .header("Content-type", "application/json").GET().build();
+ httpRequest = getRequestBuilderDefault(new URI(url))
+ .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
+ .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON)
+ .GET().build();
} catch (URISyntaxException e) {
throw new RuntimeException(e);
}
@@ -196,8 +223,10 @@ public class OllamaAPI {
public List<Model> listModels() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
String url = this.host + "/api/tags";
HttpClient httpClient = HttpClient.newHttpClient();
- HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json")
- .header("Content-type", "application/json").GET().build();
+ HttpRequest httpRequest = getRequestBuilderDefault(new URI(url))
+ .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
+ .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON).GET()
+ .build();
HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode();
String responseString = response.body();
@@ -229,8 +258,10 @@ public class OllamaAPI {
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
String url = "https://ollama.com/library";
HttpClient httpClient = HttpClient.newHttpClient();
- HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json")
- .header("Content-type", "application/json").GET().build();
+ HttpRequest httpRequest = getRequestBuilderDefault(new URI(url))
+ .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
+ .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON).GET()
+ .build();
HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode();
String responseString = response.body();
@@ -296,8 +327,10 @@ public class OllamaAPI {
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
String url = String.format("https://ollama.com/library/%s/tags", libraryModel.getName());
HttpClient httpClient = HttpClient.newHttpClient();
- HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json")
- .header("Content-type", "application/json").GET().build();
+ HttpRequest httpRequest = getRequestBuilderDefault(new URI(url))
+ .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
+ .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON).GET()
+ .build();
HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode();
String responseString = response.body();
@@ -338,6 +371,14 @@ public class OllamaAPI {
/**
* Finds a specific model using model name and tag from Ollama library.
*
+ * Deprecated: This method relies on the HTML structure of the Ollama
+ * website,
+ * which is subject to change at any time. As a result, it is difficult to keep
+ * this API
+ * method consistently updated and reliable. Therefore, this method is
+ * deprecated and
+ * may be removed in future releases.
+ *
* This method retrieves the model from the Ollama library by its name, then
* fetches its tags.
* It searches through the tags of the model to find one that matches the
@@ -355,7 +396,11 @@ public class OllamaAPI {
* @throws URISyntaxException If there is an error with the URI syntax.
* @throws InterruptedException If the operation is interrupted.
* @throws NoSuchElementException If the model or the tag is not found.
+ * @deprecated This method relies on the HTML structure of the Ollama website,
+ * which can change at any time and break this API. It is deprecated
+ * and may be removed in the future.
*/
+ @Deprecated
public LibraryModelTag findModelTagFromLibrary(String modelName, String tag)
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
List<LibraryModel> libraryModels = this.listModelsFromLibrary();
@@ -363,40 +408,71 @@ public class OllamaAPI {
.findFirst().orElseThrow(
() -> new NoSuchElementException(String.format("Model by name '%s' not found", modelName)));
LibraryModelDetail libraryModelDetail = this.getLibraryModelDetails(libraryModel);
- LibraryModelTag libraryModelTag = libraryModelDetail.getTags().stream()
- .filter(tagName -> tagName.getTag().equals(tag)).findFirst()
+ return libraryModelDetail.getTags().stream().filter(tagName -> tagName.getTag().equals(tag)).findFirst()
.orElseThrow(() -> new NoSuchElementException(
String.format("Tag '%s' for model '%s' not found", tag, modelName)));
- return libraryModelTag;
}
/**
* Pull a model on the Ollama server from the list of available models.
+ *
+ * If {@code numberOfRetriesForModelPull} is greater than 0, this method will
+ * retry pulling the model
+ * up to the specified number of times if an {@link OllamaBaseException} occurs,
+ * using exponential backoff
+ * between retries (delay doubles after each failed attempt, starting at 1
+ * second).
+ *
+ * The backoff is only applied between retries, not after the final attempt.
*
* @param modelName the name of the model
- * @throws OllamaBaseException if the response indicates an error status
+ * @throws OllamaBaseException if the response indicates an error status or all
+ * retries fail
* @throws IOException if an I/O error occurs during the HTTP request
- * @throws InterruptedException if the operation is interrupted
+ * @throws InterruptedException if the operation is interrupted or the thread is
+ * interrupted during backoff
* @throws URISyntaxException if the URI for the request is malformed
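+ *
+ * Example (illustrative; assumes the model exists in the Ollama library):
+ *
+ * {@code
+ * ollamaAPI.setNumberOfRetriesForModelPull(3);
+ * ollamaAPI.pullModel("llama3"); // on repeated failures, waits ~1s, 2s, 4s between attempts
+ * }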
*/
public void pullModel(String modelName)
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
if (numberOfRetriesForModelPull == 0) {
this.doPullModel(modelName);
- } else {
- int numberOfRetries = 0;
- while (numberOfRetries < numberOfRetriesForModelPull) {
- try {
- this.doPullModel(modelName);
- return;
- } catch (OllamaBaseException e) {
- logger.error("Failed to pull model " + modelName + ", retrying...");
- numberOfRetries++;
- }
+ return;
+ }
+ int numberOfRetries = 0;
long baseDelayMillis = 1000L; // 1-second base delay, doubled after each failed attempt
+ while (numberOfRetries < numberOfRetriesForModelPull) {
+ try {
+ this.doPullModel(modelName);
+ return;
+ } catch (OllamaBaseException e) {
+ handlePullRetry(modelName, numberOfRetries, numberOfRetriesForModelPull, baseDelayMillis);
+ numberOfRetries++;
}
- throw new OllamaBaseException(
- "Failed to pull model " + modelName + " after " + numberOfRetriesForModelPull + " retries");
+ }
+ throw new OllamaBaseException(
+ "Failed to pull model " + modelName + " after " + numberOfRetriesForModelPull + " retries");
+ }
+
+ /**
+ * Handles retry backoff for pullModel.
+ */
+ private void handlePullRetry(String modelName, int currentRetry, int maxRetries, long baseDelayMillis)
+ throws InterruptedException {
+ int attempt = currentRetry + 1;
+ if (attempt < maxRetries) {
+ long backoffMillis = baseDelayMillis * (1L << currentRetry);
+ logger.error("Failed to pull model {}, retrying in {}s... (attempt {}/{})",
+ modelName, backoffMillis / 1000, attempt, maxRetries);
+ try {
+ Thread.sleep(backoffMillis);
+ } catch (InterruptedException ie) {
+ Thread.currentThread().interrupt();
+ throw ie;
+ }
+ } else {
+ logger.error("Failed to pull model {} after {} attempts, no more retries.", modelName, maxRetries);
}
}
@@ -404,10 +480,9 @@ public class OllamaAPI {
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
String url = this.host + "/api/pull";
String jsonData = new ModelRequest(modelName).toString();
- HttpRequest request = getRequestBuilderDefault(new URI(url))
- .POST(HttpRequest.BodyPublishers.ofString(jsonData))
- .header("Accept", "application/json")
- .header("Content-type", "application/json")
+ HttpRequest request = getRequestBuilderDefault(new URI(url)).POST(HttpRequest.BodyPublishers.ofString(jsonData))
+ .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
+ .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON)
.build();
HttpClient client = HttpClient.newHttpClient();
HttpResponse<InputStream> response = client.send(request, HttpResponse.BodyHandlers.ofInputStream());
@@ -428,7 +503,7 @@ public class OllamaAPI {
if (modelPullResponse.getStatus() != null) {
if (verbose) {
- logger.info(modelName + ": " + modelPullResponse.getStatus());
+ logger.info("{}: {}", modelName, modelPullResponse.getStatus());
}
// Check if status is "success" and set success flag to true.
if ("success".equalsIgnoreCase(modelPullResponse.getStatus())) {
@@ -452,8 +527,10 @@ public class OllamaAPI {
public String getVersion() throws URISyntaxException, IOException, InterruptedException, OllamaBaseException {
String url = this.host + "/api/version";
HttpClient httpClient = HttpClient.newHttpClient();
- HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json")
- .header("Content-type", "application/json").GET().build();
+ HttpRequest httpRequest = getRequestBuilderDefault(new URI(url))
+ .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
+ .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON).GET()
+ .build();
HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode();
String responseString = response.body();
@@ -498,8 +575,10 @@ public class OllamaAPI {
throws IOException, OllamaBaseException, InterruptedException, URISyntaxException {
String url = this.host + "/api/show";
String jsonData = new ModelRequest(modelName).toString();
- HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json")
- .header("Content-type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
+ HttpRequest request = getRequestBuilderDefault(new URI(url))
+ .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
+ .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON)
+ .POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
HttpClient client = HttpClient.newHttpClient();
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode();
@@ -529,8 +608,9 @@ public class OllamaAPI {
throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
String url = this.host + "/api/create";
String jsonData = new CustomModelFilePathRequest(modelName, modelFilePath).toString();
- HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json")
- .header("Content-Type", "application/json")
+ HttpRequest request = getRequestBuilderDefault(new URI(url))
+ .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
+ .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON)
.POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build();
HttpClient client = HttpClient.newHttpClient();
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
@@ -569,8 +649,9 @@ public class OllamaAPI {
throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
String url = this.host + "/api/create";
String jsonData = new CustomModelFileContentsRequest(modelName, modelFileContents).toString();
- HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json")
- .header("Content-Type", "application/json")
+ HttpRequest request = getRequestBuilderDefault(new URI(url))
+ .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
+ .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON)
.POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build();
HttpClient client = HttpClient.newHttpClient();
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
@@ -602,8 +683,9 @@ public class OllamaAPI {
throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
String url = this.host + "/api/create";
String jsonData = customModelRequest.toString();
- HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json")
- .header("Content-Type", "application/json")
+ HttpRequest request = getRequestBuilderDefault(new URI(url))
+ .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
+ .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON)
.POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build();
HttpClient client = HttpClient.newHttpClient();
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
@@ -637,7 +719,9 @@ public class OllamaAPI {
String jsonData = new ModelRequest(modelName).toString();
HttpRequest request = getRequestBuilderDefault(new URI(url))
.method("DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8))
- .header("Accept", "application/json").header("Content-type", "application/json").build();
+ .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
+ .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON)
+ .build();
HttpClient client = HttpClient.newHttpClient();
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode();
@@ -683,7 +767,8 @@ public class OllamaAPI {
URI uri = URI.create(this.host + "/api/embeddings");
String jsonData = modelRequest.toString();
HttpClient httpClient = HttpClient.newHttpClient();
- HttpRequest.Builder requestBuilder = getRequestBuilderDefault(uri).header("Accept", "application/json")
+ HttpRequest.Builder requestBuilder = getRequestBuilderDefault(uri)
+ .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
.POST(HttpRequest.BodyPublishers.ofString(jsonData));
HttpRequest request = requestBuilder.build();
HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
@@ -728,7 +813,8 @@ public class OllamaAPI {
String jsonData = Utils.getObjectMapper().writeValueAsString(modelRequest);
HttpClient httpClient = HttpClient.newHttpClient();
- HttpRequest request = HttpRequest.newBuilder(uri).header("Accept", "application/json")
+ HttpRequest request = HttpRequest.newBuilder(uri)
+ .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
.POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
@@ -744,33 +830,112 @@ public class OllamaAPI {
/**
* Generate response for a question to a model running on Ollama server. This is
- * a sync/blocking
- * call.
+ * a sync/blocking call. This API does not support "thinking" models.
*
- * @param model the ollama model to ask the question to
- * @param prompt the prompt/question text
- * @param options the Options object - More
- * details on the options
- * @param streamHandler optional callback consumer that will be applied every
- * time a streamed response is received. If not set, the
- * stream parameter of the request is set to false.
+ * @param model the ollama model to ask the question to
+ * @param prompt the prompt/question text
+ * @param raw if true no formatting will be applied to the
+ * prompt. You
+ * may choose to use the raw parameter if you are
+ * specifying a full templated prompt in your
+ * request to
+ * the API
+ * @param options the Options object - More
+ * details on the options
+ * @param responseStreamHandler optional callback consumer that will be applied
+ * every
+ * time a streamed response is received. If not
+ * set, the
+ * stream parameter of the request is set to false.
* @return OllamaResult that includes response text and time taken for response
* @throws OllamaBaseException if the response indicates an error status
* @throws IOException if an I/O error occurs during the HTTP request
* @throws InterruptedException if the operation is interrupted
*/
public OllamaResult generate(String model, String prompt, boolean raw, Options options,
- OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
+ OllamaStreamHandler responseStreamHandler) throws OllamaBaseException, IOException, InterruptedException {
OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt);
ollamaRequestModel.setRaw(raw);
+ ollamaRequestModel.setThink(false);
ollamaRequestModel.setOptions(options.getOptionsMap());
- return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler);
+ return generateSyncForOllamaRequestModel(ollamaRequestModel, null, responseStreamHandler);
+ }
+
+ /**
+ * Generate thinking and response tokens for a question to a thinking model
+ * running on Ollama server. This is
+ * a sync/blocking call.
+ *
+ * @param model the ollama model to ask the question to
+ * @param prompt the prompt/question text
+ * @param raw if true no formatting will be applied to the
+ * prompt. You
+ * may choose to use the raw parameter if you are
+ * specifying a full templated prompt in your
+ * request to
+ * the API
+ * @param options the Options object - More
+ * details on the options
+ * @param thinkingStreamHandler optional callback consumer that will be applied
+ *                              every time a streamed thinking token is received
+ * @param responseStreamHandler optional callback consumer that will be applied
+ *                              every time a streamed response is received. If
+ *                              not set, the stream parameter of the request is
+ *                              set to false.
+ * @return OllamaResult that includes response text and time taken for response
+ * @throws OllamaBaseException if the response indicates an error status
+ * @throws IOException if an I/O error occurs during the HTTP request
+ * @throws InterruptedException if the operation is interrupted
+ */
+ public OllamaResult generate(String model, String prompt, boolean raw, Options options,
+ OllamaStreamHandler thinkingStreamHandler, OllamaStreamHandler responseStreamHandler)
+ throws OllamaBaseException, IOException, InterruptedException {
+ OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt);
+ ollamaRequestModel.setRaw(raw);
+ ollamaRequestModel.setThink(true);
+ ollamaRequestModel.setOptions(options.getOptionsMap());
+ return generateSyncForOllamaRequestModel(ollamaRequestModel, thinkingStreamHandler, responseStreamHandler);
+ }
+
+ /**
+ * Generates response using the specified AI model and prompt (in blocking
+ * mode).
+ *
+ * Uses
+ * {@link #generate(String, String, boolean, Options, OllamaStreamHandler)}
+ *
+ * @param model The name or identifier of the AI model to use for generating
+ * the response.
+ * @param prompt The input text or prompt to provide to the AI model.
+ * @param raw In some cases, you may wish to bypass the templating system
+ * and provide a full prompt. In this case, you can use the raw
+ * parameter to disable templating. Also note that raw mode will
+ * not return a context.
+ * @param options Additional options or configurations to use when generating
+ * the response.
+ * @param think if true the model will "think" step-by-step before
+ * generating the final response
+ * @return {@link OllamaResult}
+ * @throws OllamaBaseException if the response indicates an error status
+ * @throws IOException if an I/O error occurs during the HTTP request
+ * @throws InterruptedException if the operation is interrupted
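+ *
+ * A minimal sketch (the model name is illustrative; any locally available
+ * thinking-capable model works):
+ *
+ * {@code
+ * OllamaResult result = ollamaAPI.generate("deepseek-r1:1.5b",
+ *         "Why is the sky blue?", false, true, new OptionsBuilder().build());
+ * System.out.println(result.getResponse());
+ * }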
+ */
+ public OllamaResult generate(String model, String prompt, boolean raw, boolean think, Options options)
+ throws OllamaBaseException, IOException, InterruptedException {
+ if (think) {
+ return generate(model, prompt, raw, options, null, null);
+ } else {
+ return generate(model, prompt, raw, options, null);
+ }
}
/**
* Generates structured output from the specified AI model and prompt.
+ *
+ * Note: When formatting is specified, the 'think' parameter is not allowed.
*
* @param model The name or identifier of the AI model to use for generating
* the response.
@@ -783,6 +948,7 @@ public class OllamaAPI {
* @throws IOException if an I/O error occurs during the HTTP request.
* @throws InterruptedException if the operation is interrupted.
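+ *
+ * A minimal sketch of a JSON-schema-style format map (the field names are
+ * illustrative, not part of the library API):
+ *
+ * {@code
+ * Map<String, Object> format = new HashMap<>();
+ * format.put("type", "object");
+ * format.put("properties", Map.of("capital", Map.of("type", "string")));
+ * OllamaResult result = ollamaAPI.generate("llama3", "What is the capital of France?", format);
+ * }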
*/
+ @SuppressWarnings("LoggingSimilarMessage")
public OllamaResult generate(String model, String prompt, Map<String, Object> format)
throws OllamaBaseException, IOException, InterruptedException {
URI uri = URI.create(this.host + "/api/generate");
@@ -797,51 +963,52 @@ public class OllamaAPI {
HttpClient httpClient = HttpClient.newHttpClient();
HttpRequest request = getRequestBuilderDefault(uri)
- .header("Accept", "application/json")
- .header("Content-type", "application/json")
- .POST(HttpRequest.BodyPublishers.ofString(jsonData))
- .build();
+ .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
+ .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON)
+ .POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
+ if (verbose) {
+ try {
+ String prettyJson = Utils.getObjectMapper().writerWithDefaultPrettyPrinter()
+ .writeValueAsString(Utils.getObjectMapper().readValue(jsonData, Object.class));
+ logger.info("Asking model:\n{}", prettyJson);
+ } catch (Exception e) {
+ logger.info("Asking model: {}", jsonData);
+ }
+ }
HttpResponse response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode();
String responseBody = response.body();
-
if (statusCode == 200) {
OllamaStructuredResult structuredResult = Utils.getObjectMapper().readValue(responseBody,
OllamaStructuredResult.class);
- OllamaResult ollamaResult = new OllamaResult(structuredResult.getResponse(),
+ OllamaResult ollamaResult = new OllamaResult(structuredResult.getResponse(), structuredResult.getThinking(),
structuredResult.getResponseTime(), statusCode);
+
+ ollamaResult.setModel(structuredResult.getModel());
+ ollamaResult.setCreatedAt(structuredResult.getCreatedAt());
+ ollamaResult.setDone(structuredResult.isDone());
+ ollamaResult.setDoneReason(structuredResult.getDoneReason());
+ ollamaResult.setContext(structuredResult.getContext());
+ ollamaResult.setTotalDuration(structuredResult.getTotalDuration());
+ ollamaResult.setLoadDuration(structuredResult.getLoadDuration());
+ ollamaResult.setPromptEvalCount(structuredResult.getPromptEvalCount());
+ ollamaResult.setPromptEvalDuration(structuredResult.getPromptEvalDuration());
+ ollamaResult.setEvalCount(structuredResult.getEvalCount());
+ ollamaResult.setEvalDuration(structuredResult.getEvalDuration());
+ if (verbose) {
+ logger.info("Model response:\n{}", ollamaResult);
+ }
return ollamaResult;
} else {
+ if (verbose) {
+ logger.info("Model response:\n{}", responseBody);
+ }
throw new OllamaBaseException(statusCode + " - " + responseBody);
}
}
- /**
- * Generates response using the specified AI model and prompt (in blocking
- * mode).
- *
- * Uses {@link #generate(String, String, boolean, Options, OllamaStreamHandler)}
- *
- * @param model The name or identifier of the AI model to use for generating
- * the response.
- * @param prompt The input text or prompt to provide to the AI model.
- * @param raw In some cases, you may wish to bypass the templating system
- * and provide a full prompt. In this case, you can use the raw
- * parameter to disable templating. Also note that raw mode will
- * not return a context.
- * @param options Additional options or configurations to use when generating
- * the response.
- * @return {@link OllamaResult}
- * @throws OllamaBaseException if the response indicates an error status
- * @throws IOException if an I/O error occurs during the HTTP request
- * @throws InterruptedException if the operation is interrupted
- */
- public OllamaResult generate(String model, String prompt, boolean raw, Options options)
- throws OllamaBaseException, IOException, InterruptedException {
- return generate(model, prompt, raw, options, null);
- }
-
/**
* Generates response using the specified AI model and prompt (in blocking
* mode), and then invokes a set of tools
@@ -893,8 +1060,7 @@ public class OllamaAPI {
logger.warn("Response from model does not contain any tool calls. Returning the response as is.");
return toolResult;
}
- toolFunctionCallSpecs = objectMapper.readValue(
- toolsResponse,
+ toolFunctionCallSpecs = objectMapper.readValue(toolsResponse,
objectMapper.getTypeFactory().constructCollectionType(List.class, ToolFunctionCallSpec.class));
}
for (ToolFunctionCallSpec toolFunctionCallSpec : toolFunctionCallSpecs) {
@@ -905,19 +1071,47 @@ public class OllamaAPI {
}
/**
- * Generate response for a question to a model running on Ollama server and get
- * a callback handle
- * that can be used to check for status and get the response from the model
- * later. This would be
- * an async/non-blocking call.
+ * Asynchronously generates a response for a prompt using a model running on the
+ * Ollama server.
+ *
+ * This method returns an {@link OllamaAsyncResultStreamer} handle that can be
+ * used to poll for
+ * status and retrieve streamed "thinking" and response tokens from the model.
+ * The call is non-blocking.
+ *
*
- * @param model the ollama model to ask the question to
- * @param prompt the prompt/question text
- * @return the ollama async result callback handle
+ *
+ * Example usage:
+ *
+ *
+ * <pre>{@code
+ * OllamaAsyncResultStreamer resultStreamer = ollamaAPI.generateAsync("gpt-oss:20b", "Who are you", false, true);
+ * int pollIntervalMilliseconds = 1000;
+ * while (true) {
+ * String thinkingTokens = resultStreamer.getThinkingResponseStream().poll();
+ * String responseTokens = resultStreamer.getResponseStream().poll();
+ * System.out.print(thinkingTokens != null ? thinkingTokens.toUpperCase() : "");
+ * System.out.print(responseTokens != null ? responseTokens.toLowerCase() : "");
+ * Thread.sleep(pollIntervalMilliseconds);
+ * if (!resultStreamer.isAlive())
+ * break;
+ * }
+ * System.out.println("Complete thinking response: " + resultStreamer.getCompleteThinkingResponse());
+ * System.out.println("Complete response: " + resultStreamer.getCompleteResponse());
+ * }</pre>
+ *
+ * @param model the Ollama model to use for generating the response
+ * @param prompt the prompt or question text to send to the model
+ * @param raw if {@code true}, returns the raw response from the model
+ * @param think if {@code true}, streams "thinking" tokens as well as response
+ * tokens
+ * @return an {@link OllamaAsyncResultStreamer} handle for polling and
+ * retrieving streamed results
*/
- public OllamaAsyncResultStreamer generateAsync(String model, String prompt, boolean raw) {
+ public OllamaAsyncResultStreamer generateAsync(String model, String prompt, boolean raw, boolean think) {
OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt);
ollamaRequestModel.setRaw(raw);
+ ollamaRequestModel.setThink(think);
URI uri = URI.create(this.host + "/api/generate");
OllamaAsyncResultStreamer ollamaAsyncResultStreamer = new OllamaAsyncResultStreamer(
getRequestBuilderDefault(uri), ollamaRequestModel, requestTimeoutSeconds);
@@ -953,7 +1147,7 @@ public class OllamaAPI {
}
OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt, images);
ollamaRequestModel.setOptions(options.getOptionsMap());
- return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler);
+ return generateSyncForOllamaRequestModel(ollamaRequestModel, null, streamHandler);
}
/**
@@ -1001,7 +1195,7 @@ public class OllamaAPI {
}
OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt, images);
ollamaRequestModel.setOptions(options.getOptionsMap());
- return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler);
+ return generateSyncForOllamaRequestModel(ollamaRequestModel, null, streamHandler);
}
/**
@@ -1023,38 +1217,47 @@ public class OllamaAPI {
/**
* Synchronously generates a response using a list of image byte arrays.
*
- * This method encodes the provided byte arrays into Base64 and sends them to the Ollama server.
+ * This method encodes the provided byte arrays into Base64 and sends them to
+ * the Ollama server.
*
* @param model the Ollama model to use for generating the response
* @param prompt the prompt or question text to send to the model
* @param images the list of image data as byte arrays
- * @param options the Options object - More details on the options
- * @param streamHandler optional callback that will be invoked with each streamed response; if null, streaming is disabled
- * @return OllamaResult containing the response text and the time taken for the response
+ * @param options the Options object - More
+ * details on the options
+ * @param streamHandler optional callback that will be invoked with each
+ * streamed response; if null, streaming is disabled
+ * @return OllamaResult containing the response text and the time taken for the
+ * response
* @throws OllamaBaseException if the response indicates an error status
* @throws IOException if an I/O error occurs during the HTTP request
* @throws InterruptedException if the operation is interrupted
*/
- public OllamaResult generateWithImages(String model, String prompt, List<byte[]> images, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
+ public OllamaResult generateWithImages(String model, String prompt, List<byte[]> images, Options options,
+ OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
List<String> encodedImages = new ArrayList<>();
for (byte[] image : images) {
encodedImages.add(encodeByteArrayToBase64(image));
}
OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt, encodedImages);
ollamaRequestModel.setOptions(options.getOptionsMap());
- return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler);
+ return generateSyncForOllamaRequestModel(ollamaRequestModel, null, streamHandler);
}
/**
- * Convenience method to call the Ollama API using image byte arrays without streaming responses.
+ * Convenience method to call the Ollama API using image byte arrays without
+ * streaming responses.
*
- * Uses {@link #generateWithImages(String, String, List, Options, OllamaStreamHandler)}
+ * Uses
+ * {@link #generateWithImages(String, String, List, Options, OllamaStreamHandler)}
*
* @throws OllamaBaseException if the response indicates an error status
* @throws IOException if an I/O error occurs during the HTTP request
* @throws InterruptedException if the operation is interrupted
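+ *
+ * A minimal sketch (assumes a vision-capable model such as "llava" and a
+ * readable local image file):
+ *
+ * {@code
+ * byte[] imageBytes = Files.readAllBytes(Paths.get("/path/to/image.jpg"));
+ * OllamaResult result = ollamaAPI.generateWithImages("llava",
+ *         "What is in this picture?", List.of(imageBytes), new OptionsBuilder().build());
+ * }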
*/
- public OllamaResult generateWithImages(String model, String prompt, List<byte[]> images, Options options) throws OllamaBaseException, IOException, InterruptedException {
+ public OllamaResult generateWithImages(String model, String prompt, List<byte[]> images, Options options)
+ throws OllamaBaseException, IOException, InterruptedException {
return generateWithImages(model, prompt, images, options, null);
}
@@ -1069,10 +1272,12 @@ public class OllamaAPI {
* history including the newly acquired assistant response.
* @throws OllamaBaseException any response code than 200 has been returned
* @throws IOException in case the responseStream can not be read
- * @throws InterruptedException in case the server is not reachable or network
+ * @throws InterruptedException in case the server is not reachable or
+ * network
* issues happen
* @throws OllamaBaseException if the response indicates an error status
- * @throws IOException if an I/O error occurs during the HTTP request
+ * @throws IOException if an I/O error occurs during the HTTP
+ * request
* @throws InterruptedException if the operation is interrupted
* @throws ToolInvocationException if the tool invocation fails
*/
@@ -1092,16 +1297,18 @@ public class OllamaAPI {
* @return {@link OllamaChatResult}
* @throws OllamaBaseException any response code than 200 has been returned
* @throws IOException in case the responseStream can not be read
- * @throws InterruptedException in case the server is not reachable or network
+ * @throws InterruptedException in case the server is not reachable or
+ * network
* issues happen
* @throws OllamaBaseException if the response indicates an error status
- * @throws IOException if an I/O error occurs during the HTTP request
+ * @throws IOException if an I/O error occurs during the HTTP
+ * request
* @throws InterruptedException if the operation is interrupted
* @throws ToolInvocationException if the tool invocation fails
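+ *
+ * A minimal sketch (the model name is illustrative):
+ *
+ * {@code
+ * OllamaChatRequest request = OllamaChatRequestBuilder.getInstance("llama3")
+ *         .withMessage(OllamaChatMessageRole.USER, "What is the capital of France?")
+ *         .build();
+ * OllamaChatResult chatResult = ollamaAPI.chat(request);
+ * }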
*/
public OllamaChatResult chat(OllamaChatRequest request)
throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
- return chat(request, null);
+ return chat(request, null, null);
}
/**
@@ -1110,23 +1317,27 @@ public class OllamaAPI {
*
* Hint: the OllamaChatRequestModel#getStream() property is not implemented.
*
- * @param request request object to be sent to the server
- * @param streamHandler callback handler to handle the last message from stream
- * (caution: all previous tokens from stream will be
- * concatenated)
+ * @param request request object to be sent to the server
+ * @param responseStreamHandler callback handler to handle the last message from
+ * stream
+ * @param thinkingStreamHandler callback handler to handle the last thinking
+ * message from stream
* @return {@link OllamaChatResult}
* @throws OllamaBaseException any response code than 200 has been returned
* @throws IOException in case the responseStream can not be read
- * @throws InterruptedException in case the server is not reachable or network
+ * @throws InterruptedException in case the server is not reachable or
+ * network
* issues happen
* @throws OllamaBaseException if the response indicates an error status
- * @throws IOException if an I/O error occurs during the HTTP request
+ * @throws IOException if an I/O error occurs during the HTTP
+ * request
* @throws InterruptedException if the operation is interrupted
* @throws ToolInvocationException if the tool invocation fails
*/
- public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler)
+ public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler thinkingStreamHandler,
+ OllamaStreamHandler responseStreamHandler)
throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
- return chatStreaming(request, new OllamaChatStreamObserver(streamHandler));
+ return chatStreaming(request, new OllamaChatStreamObserver(thinkingStreamHandler, responseStreamHandler));
}
/**
@@ -1177,8 +1388,11 @@ public class OllamaAPI {
}
Map<String, Object> arguments = toolCall.getFunction().getArguments();
Object res = toolFunction.apply(arguments);
+ String argumentKeys = arguments.keySet().stream()
+ .map(Object::toString)
+ .collect(Collectors.joining(", "));
request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL,
- "[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() + ") : " + res + "[/TOOL_RESULTS]"));
+ "[TOOL_RESULTS] " + toolName + "(" + argumentKeys + "): " + res + " [/TOOL_RESULTS]"));
}
if (tokenHandler != null) {
@@ -1224,6 +1438,17 @@ public class OllamaAPI {
}
}
+ /**
+ * Deregisters all tools from the tool registry.
+ * This method removes all registered tools, effectively clearing the registry.
+ */
+ public void deregisterTools() {
+ toolRegistry.clear();
+ if (this.verbose) {
+ logger.debug("All tools have been deregistered.");
+ }
+ }
+
/**
* Registers tools based on the annotations found on the methods of the caller's
* class and its providers.
@@ -1380,10 +1605,12 @@ public class OllamaAPI {
* the request will be streamed; otherwise, a regular synchronous request will
* be made.
*
- * @param ollamaRequestModel the request model containing necessary parameters
- * for the Ollama API request.
- * @param streamHandler the stream handler to process streaming responses,
- * or null for non-streaming requests.
+ * @param ollamaRequestModel the request model containing necessary
+ * parameters
+ * for the Ollama API request.
+ * @param responseStreamHandler the stream handler to process streaming
+ * responses,
+ * or null for non-streaming requests.
* @return the result of the Ollama API request.
* @throws OllamaBaseException if the request fails due to an issue with the
* Ollama API.
@@ -1392,13 +1619,14 @@ public class OllamaAPI {
* @throws InterruptedException if the thread is interrupted during the request.
*/
private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequest ollamaRequestModel,
- OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
+ OllamaStreamHandler thinkingStreamHandler, OllamaStreamHandler responseStreamHandler)
+ throws OllamaBaseException, IOException, InterruptedException {
OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, auth, requestTimeoutSeconds,
verbose);
OllamaResult result;
- if (streamHandler != null) {
+ if (responseStreamHandler != null) {
ollamaRequestModel.setStream(true);
- result = requestCaller.call(ollamaRequestModel, streamHandler);
+ result = requestCaller.call(ollamaRequestModel, thinkingStreamHandler, responseStreamHandler);
} else {
result = requestCaller.callSync(ollamaRequestModel);
}
@@ -1412,7 +1640,8 @@ public class OllamaAPI {
* @return HttpRequest.Builder
*/
private HttpRequest.Builder getRequestBuilderDefault(URI uri) {
- HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(uri).header("Content-Type", "application/json")
+ HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(uri)
+ .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON)
.timeout(Duration.ofSeconds(requestTimeoutSeconds));
if (isBasicAuthCredentialsSet()) {
requestBuilder.header("Authorization", auth.getAuthHeaderValue());
diff --git a/src/main/java/io/github/ollama4j/impl/ConsoleOutputStreamHandler.java b/src/main/java/io/github/ollama4j/impl/ConsoleOutputStreamHandler.java
index c9f8e36..d990006 100644
--- a/src/main/java/io/github/ollama4j/impl/ConsoleOutputStreamHandler.java
+++ b/src/main/java/io/github/ollama4j/impl/ConsoleOutputStreamHandler.java
@@ -3,12 +3,8 @@ package io.github.ollama4j.impl;
import io.github.ollama4j.models.generate.OllamaStreamHandler;
public class ConsoleOutputStreamHandler implements OllamaStreamHandler {
- private final StringBuffer response = new StringBuffer();
-
@Override
public void accept(String message) {
- String substr = message.substring(response.length());
- response.append(substr);
- System.out.print(substr);
+ System.out.print(message);
}
}
diff --git a/src/main/java/io/github/ollama4j/models/chat/OllamaChatMessage.java b/src/main/java/io/github/ollama4j/models/chat/OllamaChatMessage.java
index 86b7726..e3d7912 100644
--- a/src/main/java/io/github/ollama4j/models/chat/OllamaChatMessage.java
+++ b/src/main/java/io/github/ollama4j/models/chat/OllamaChatMessage.java
@@ -1,21 +1,15 @@
package io.github.ollama4j.models.chat;
-import static io.github.ollama4j.utils.Utils.getObjectMapper;
-
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
-import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
-
import io.github.ollama4j.utils.FileToBase64Serializer;
+import lombok.*;
import java.util.List;
-import lombok.AllArgsConstructor;
-import lombok.Data;
-import lombok.NoArgsConstructor;
-import lombok.NonNull;
-import lombok.RequiredArgsConstructor;
+import static io.github.ollama4j.utils.Utils.getObjectMapper;
/**
* Defines a single Message to be used inside a chat request against the ollama /api/chat endpoint.
@@ -35,6 +29,8 @@ public class OllamaChatMessage {
@NonNull
private String content;
+ private String thinking;
+
private @JsonProperty("tool_calls") List<OllamaChatToolCalls> toolCalls;
@JsonSerialize(using = FileToBase64Serializer.class)
diff --git a/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequest.java b/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequest.java
index 5d19703..7b19e02 100644
--- a/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequest.java
+++ b/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequest.java
@@ -1,43 +1,46 @@
package io.github.ollama4j.models.chat;
-import java.util.List;
-
import io.github.ollama4j.models.request.OllamaCommonRequest;
import io.github.ollama4j.tools.Tools;
import io.github.ollama4j.utils.OllamaRequestBody;
-
import lombok.Getter;
import lombok.Setter;
+import java.util.List;
+
/**
* Defines a Request to use against the ollama /api/chat endpoint.
*
* @see <a href=
- *      "https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion">Generate Chat Completion</a>
+ *      "https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion">Generate
+ *      Chat Completion</a>
*/
@Getter
@Setter
public class OllamaChatRequest extends OllamaCommonRequest implements OllamaRequestBody {
- private List<OllamaChatMessage> messages;
+ private List<OllamaChatMessage> messages;
- private List<Tools.PromptFuncDefinition> tools;
+ private List<Tools.PromptFuncDefinition> tools;
- public OllamaChatRequest() {}
+ private boolean think;
- public OllamaChatRequest(String model, List<OllamaChatMessage> messages) {
- this.model = model;
- this.messages = messages;
- }
-
- @Override
- public boolean equals(Object o) {
- if (!(o instanceof OllamaChatRequest)) {
- return false;
+ public OllamaChatRequest() {
}
- return this.toString().equals(o.toString());
- }
+ public OllamaChatRequest(String model, boolean think, List<OllamaChatMessage> messages) {
+ this.model = model;
+ this.messages = messages;
+ this.think = think;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (!(o instanceof OllamaChatRequest)) {
+ return false;
+ }
+
+ return this.toString().equals(o.toString());
+ }
}
diff --git a/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequestBuilder.java b/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequestBuilder.java
index 9094546..4a9caf9 100644
--- a/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequestBuilder.java
+++ b/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequestBuilder.java
@@ -22,7 +22,7 @@ public class OllamaChatRequestBuilder {
private static final Logger LOG = LoggerFactory.getLogger(OllamaChatRequestBuilder.class);
private OllamaChatRequestBuilder(String model, List<OllamaChatMessage> messages) {
- request = new OllamaChatRequest(model, messages);
+ request = new OllamaChatRequest(model, false, messages);
}
private OllamaChatRequest request;
@@ -36,14 +36,20 @@ public class OllamaChatRequestBuilder {
}
public void reset() {
- request = new OllamaChatRequest(request.getModel(), new ArrayList<>());
+ request = new OllamaChatRequest(request.getModel(), request.isThink(), new ArrayList<>());
}
- public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content){
- return withMessage(role,content, Collections.emptyList());
+ public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content) {
+ return withMessage(role, content, Collections.emptyList());
}
- public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List<OllamaChatToolCalls> toolCalls,List<File> images) {
+ public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List<OllamaChatToolCalls> toolCalls) {
+ List<OllamaChatMessage> messages = this.request.getMessages();
+ messages.add(new OllamaChatMessage(role, content, null, toolCalls, null));
+ return this;
+ }
+
+ public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List<OllamaChatToolCalls> toolCalls, List<File> images) {
List<OllamaChatMessage> messages = this.request.getMessages();
List<byte[]> binaryImages = images.stream().map(file -> {
@@ -55,11 +61,11 @@ public class OllamaChatRequestBuilder {
}
}).collect(Collectors.toList());
- messages.add(new OllamaChatMessage(role, content,toolCalls, binaryImages));
+ messages.add(new OllamaChatMessage(role, content, null, toolCalls, binaryImages));
return this;
}
- public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content,List<OllamaChatToolCalls> toolCalls, String... imageUrls) {
+ public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List<OllamaChatToolCalls> toolCalls, String... imageUrls) {
List<OllamaChatMessage> messages = this.request.getMessages();
List<byte[]> binaryImages = null;
if (imageUrls.length > 0) {
@@ -75,7 +81,7 @@ public class OllamaChatRequestBuilder {
}
}
- messages.add(new OllamaChatMessage(role, content,toolCalls, binaryImages));
+ messages.add(new OllamaChatMessage(role, content, null, toolCalls, binaryImages));
return this;
}
@@ -108,4 +114,8 @@ public class OllamaChatRequestBuilder {
return this;
}
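+
+ /**
+ * Sets whether the chat request should enable "thinking" tokens. A minimal
+ * sketch (the model name is illustrative):
+ *
+ * {@code
+ * OllamaChatRequest request = OllamaChatRequestBuilder.getInstance("deepseek-r1:1.5b")
+ *         .withThinking(true)
+ *         .withMessage(OllamaChatMessageRole.USER, "Solve 2 + 2 * 2 and explain.")
+ *         .build();
+ * }
+ */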
+ public OllamaChatRequestBuilder withThinking(boolean think) {
+ this.request.setThink(think);
+ return this;
+ }
}
diff --git a/src/main/java/io/github/ollama4j/models/chat/OllamaChatResult.java b/src/main/java/io/github/ollama4j/models/chat/OllamaChatResult.java
index f8ebb05..5fbf7e3 100644
--- a/src/main/java/io/github/ollama4j/models/chat/OllamaChatResult.java
+++ b/src/main/java/io/github/ollama4j/models/chat/OllamaChatResult.java
@@ -1,10 +1,10 @@
package io.github.ollama4j.models.chat;
-import java.util.List;
-
import com.fasterxml.jackson.core.JsonProcessingException;
import lombok.Getter;
+import java.util.List;
+
import static io.github.ollama4j.utils.Utils.getObjectMapper;
/**
diff --git a/src/main/java/io/github/ollama4j/models/chat/OllamaChatStreamObserver.java b/src/main/java/io/github/ollama4j/models/chat/OllamaChatStreamObserver.java
index af181da..2ccdb74 100644
--- a/src/main/java/io/github/ollama4j/models/chat/OllamaChatStreamObserver.java
+++ b/src/main/java/io/github/ollama4j/models/chat/OllamaChatStreamObserver.java
@@ -6,14 +6,46 @@ import lombok.RequiredArgsConstructor;
@RequiredArgsConstructor
public class OllamaChatStreamObserver implements OllamaTokenHandler {
- private final OllamaStreamHandler streamHandler;
+ private final OllamaStreamHandler thinkingStreamHandler;
+ private final OllamaStreamHandler responseStreamHandler;
+
private String message = "";
@Override
public void accept(OllamaChatResponseModel token) {
- if (streamHandler != null) {
- message += token.getMessage().getContent();
- streamHandler.accept(message);
+ if (token == null || token.getMessage() == null) {
+ return;
+ }
+
+ String thinking = token.getMessage().getThinking();
+ String content = token.getMessage().getContent();
+
+ boolean hasThinking = thinking != null && !thinking.isEmpty();
+ boolean hasContent = content != null && !content.isEmpty();
+
+ if (!hasContent && hasThinking && thinkingStreamHandler != null) {
+ // forward only the newly received thinking tokens rather than
+ // re-sending the accumulated text on every call
+ thinkingStreamHandler.accept(thinking);
+ } else if (hasContent && responseStreamHandler != null) {
+ // forward only the newly received response tokens
+ responseStreamHandler.accept(content);
}
}
}
diff --git a/src/main/java/io/github/ollama4j/models/embeddings/OllamaEmbeddingResponseModel.java b/src/main/java/io/github/ollama4j/models/embeddings/OllamaEmbeddingResponseModel.java
index dcf7b47..2d0d90a 100644
--- a/src/main/java/io/github/ollama4j/models/embeddings/OllamaEmbeddingResponseModel.java
+++ b/src/main/java/io/github/ollama4j/models/embeddings/OllamaEmbeddingResponseModel.java
@@ -1,9 +1,9 @@
package io.github.ollama4j.models.embeddings;
import com.fasterxml.jackson.annotation.JsonProperty;
+import lombok.Data;
import java.util.List;
-import lombok.Data;
@SuppressWarnings("unused")
@Data
diff --git a/src/main/java/io/github/ollama4j/models/embeddings/OllamaEmbeddingsRequestModel.java b/src/main/java/io/github/ollama4j/models/embeddings/OllamaEmbeddingsRequestModel.java
index d68624c..7d113f0 100644
--- a/src/main/java/io/github/ollama4j/models/embeddings/OllamaEmbeddingsRequestModel.java
+++ b/src/main/java/io/github/ollama4j/models/embeddings/OllamaEmbeddingsRequestModel.java
@@ -1,7 +1,5 @@
package io.github.ollama4j.models.embeddings;
-import static io.github.ollama4j.utils.Utils.getObjectMapper;
-import java.util.Map;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException;
import lombok.Data;
@@ -9,6 +7,10 @@ import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
+import java.util.Map;
+
+import static io.github.ollama4j.utils.Utils.getObjectMapper;
+
@Data
@RequiredArgsConstructor
@NoArgsConstructor
diff --git a/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateRequest.java b/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateRequest.java
index de767dc..3763f0a 100644
--- a/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateRequest.java
+++ b/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateRequest.java
@@ -3,12 +3,11 @@ package io.github.ollama4j.models.generate;
import io.github.ollama4j.models.request.OllamaCommonRequest;
import io.github.ollama4j.utils.OllamaRequestBody;
-
-import java.util.List;
-
import lombok.Getter;
import lombok.Setter;
+import java.util.List;
+
@Getter
@Setter
public class OllamaGenerateRequest extends OllamaCommonRequest implements OllamaRequestBody{
@@ -19,6 +18,7 @@ public class OllamaGenerateRequest extends OllamaCommonRequest implements Ollama
private String system;
private String context;
private boolean raw;
+ private boolean think;
public OllamaGenerateRequest() {
}
diff --git a/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateResponseModel.java b/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateResponseModel.java
index 9fb975e..a3d23ec 100644
--- a/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateResponseModel.java
+++ b/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateResponseModel.java
@@ -2,9 +2,9 @@ package io.github.ollama4j.models.generate;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
+import lombok.Data;
import java.util.List;
-import lombok.Data;
@Data
@JsonIgnoreProperties(ignoreUnknown = true)
@@ -12,12 +12,14 @@ public class OllamaGenerateResponseModel {
private String model;
private @JsonProperty("created_at") String createdAt;
private String response;
+ private String thinking;
private boolean done;
+ private @JsonProperty("done_reason") String doneReason;
private List<Integer> context;
private @JsonProperty("total_duration") Long totalDuration;
private @JsonProperty("load_duration") Long loadDuration;
- private @JsonProperty("prompt_eval_duration") Long promptEvalDuration;
- private @JsonProperty("eval_duration") Long evalDuration;
private @JsonProperty("prompt_eval_count") Integer promptEvalCount;
+ private @JsonProperty("prompt_eval_duration") Long promptEvalDuration;
private @JsonProperty("eval_count") Integer evalCount;
+ private @JsonProperty("eval_duration") Long evalDuration;
}
diff --git a/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateStreamObserver.java b/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateStreamObserver.java
index bc47fa0..67ae571 100644
--- a/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateStreamObserver.java
+++ b/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateStreamObserver.java
@@ -5,14 +5,16 @@ import java.util.List;
public class OllamaGenerateStreamObserver {
- private OllamaStreamHandler streamHandler;
+ private final OllamaStreamHandler thinkingStreamHandler;
+ private final OllamaStreamHandler responseStreamHandler;
- private List<OllamaGenerateResponseModel> responseParts = new ArrayList<>();
+ private final List<OllamaGenerateResponseModel> responseParts = new ArrayList<>();
private String message = "";
- public OllamaGenerateStreamObserver(OllamaStreamHandler streamHandler) {
- this.streamHandler = streamHandler;
+ public OllamaGenerateStreamObserver(OllamaStreamHandler thinkingStreamHandler, OllamaStreamHandler responseStreamHandler) {
+ this.responseStreamHandler = responseStreamHandler;
+ this.thinkingStreamHandler = thinkingStreamHandler;
}
public void notify(OllamaGenerateResponseModel currentResponsePart) {
@@ -21,9 +23,24 @@ public class OllamaGenerateStreamObserver {
}
protected void handleCurrentResponsePart(OllamaGenerateResponseModel currentResponsePart) {
- message = message + currentResponsePart.getResponse();
- streamHandler.accept(message);
+ String response = currentResponsePart.getResponse();
+ String thinking = currentResponsePart.getThinking();
+
+ boolean hasResponse = response != null && !response.isEmpty();
+ boolean hasThinking = thinking != null && !thinking.isEmpty();
+
+        if (!hasResponse && hasThinking && thinkingStreamHandler != null) {
+            // Forward only the newly received thinking tokens instead of
+            // re-sending the accumulated text with every part.
+            thinkingStreamHandler.accept(thinking);
+        } else if (hasResponse && responseStreamHandler != null) {
+            // Likewise, forward only the new response tokens.
+            responseStreamHandler.accept(response);
+        }
}
-
-
}
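
The observer now fans each streamed part out to whichever handler matches its payload. A short usage sketch, assuming `OllamaStreamHandler` is a functional interface accepting the token `String` (consistent with the `accept` calls above); the lambda bodies are illustrative:

```java
import io.github.ollama4j.models.generate.OllamaGenerateStreamObserver;
import io.github.ollama4j.models.generate.OllamaStreamHandler;

public class StreamObserverWiringSketch {
    public static void main(String[] args) {
        // Route thinking tokens and answer tokens to separate sinks.
        OllamaStreamHandler thinkingHandler = token -> System.out.print("[thinking] " + token);
        OllamaStreamHandler responseHandler = System.out::print;
        OllamaGenerateStreamObserver observer =
                new OllamaGenerateStreamObserver(thinkingHandler, responseHandler);
        // The endpoint caller then invokes observer.notify(part) once per streamed part.
    }
}
```
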
diff --git a/src/main/java/io/github/ollama4j/models/request/BasicAuth.java b/src/main/java/io/github/ollama4j/models/request/BasicAuth.java
index c58b240..13f6a59 100644
--- a/src/main/java/io/github/ollama4j/models/request/BasicAuth.java
+++ b/src/main/java/io/github/ollama4j/models/request/BasicAuth.java
@@ -1,13 +1,14 @@
package io.github.ollama4j.models.request;
-import java.util.Base64;
-
import lombok.AllArgsConstructor;
import lombok.Data;
-import lombok.NoArgsConstructor;
+import lombok.EqualsAndHashCode;
+
+import java.util.Base64;
@Data
@AllArgsConstructor
+@EqualsAndHashCode(callSuper = false)
public class BasicAuth extends Auth {
private String username;
private String password;
diff --git a/src/main/java/io/github/ollama4j/models/request/BearerAuth.java b/src/main/java/io/github/ollama4j/models/request/BearerAuth.java
index 8236042..4d876f2 100644
--- a/src/main/java/io/github/ollama4j/models/request/BearerAuth.java
+++ b/src/main/java/io/github/ollama4j/models/request/BearerAuth.java
@@ -2,18 +2,20 @@ package io.github.ollama4j.models.request;
import lombok.AllArgsConstructor;
import lombok.Data;
+import lombok.EqualsAndHashCode;
@Data
@AllArgsConstructor
+@EqualsAndHashCode(callSuper = false)
public class BearerAuth extends Auth {
- private String bearerToken;
+ private String bearerToken;
- /**
- * Get authentication header value.
- *
- * @return authentication header value with bearer token
- */
- public String getAuthHeaderValue() {
- return "Bearer "+ bearerToken;
- }
+ /**
+ * Get authentication header value.
+ *
+ * @return authentication header value with bearer token
+ */
+ public String getAuthHeaderValue() {
+ return "Bearer " + bearerToken;
+ }
}
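
For reference, a tiny sketch of the header value this class produces (the token is illustrative):

```java
import io.github.ollama4j.models.request.BearerAuth;

public class BearerAuthSketch {
    public static void main(String[] args) {
        BearerAuth auth = new BearerAuth("my-secret-token"); // illustrative token
        // Prints "Bearer my-secret-token" -- the value sent in the Authorization header.
        System.out.println(auth.getAuthHeaderValue());
    }
}
```
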
diff --git a/src/main/java/io/github/ollama4j/models/request/CustomModelFileContentsRequest.java b/src/main/java/io/github/ollama4j/models/request/CustomModelFileContentsRequest.java
index 6841476..52bc684 100644
--- a/src/main/java/io/github/ollama4j/models/request/CustomModelFileContentsRequest.java
+++ b/src/main/java/io/github/ollama4j/models/request/CustomModelFileContentsRequest.java
@@ -1,11 +1,11 @@
package io.github.ollama4j.models.request;
-import static io.github.ollama4j.utils.Utils.getObjectMapper;
-
import com.fasterxml.jackson.core.JsonProcessingException;
import lombok.AllArgsConstructor;
import lombok.Data;
+import static io.github.ollama4j.utils.Utils.getObjectMapper;
+
@Data
@AllArgsConstructor
public class CustomModelFileContentsRequest {
diff --git a/src/main/java/io/github/ollama4j/models/request/CustomModelFilePathRequest.java b/src/main/java/io/github/ollama4j/models/request/CustomModelFilePathRequest.java
index 2fcda43..578e1c0 100644
--- a/src/main/java/io/github/ollama4j/models/request/CustomModelFilePathRequest.java
+++ b/src/main/java/io/github/ollama4j/models/request/CustomModelFilePathRequest.java
@@ -1,11 +1,11 @@
package io.github.ollama4j.models.request;
-import static io.github.ollama4j.utils.Utils.getObjectMapper;
-
import com.fasterxml.jackson.core.JsonProcessingException;
import lombok.AllArgsConstructor;
import lombok.Data;
+import static io.github.ollama4j.utils.Utils.getObjectMapper;
+
@Data
@AllArgsConstructor
public class CustomModelFilePathRequest {
diff --git a/src/main/java/io/github/ollama4j/models/request/CustomModelRequest.java b/src/main/java/io/github/ollama4j/models/request/CustomModelRequest.java
index 15725f0..b2ecb91 100644
--- a/src/main/java/io/github/ollama4j/models/request/CustomModelRequest.java
+++ b/src/main/java/io/github/ollama4j/models/request/CustomModelRequest.java
@@ -1,17 +1,15 @@
package io.github.ollama4j.models.request;
-import static io.github.ollama4j.utils.Utils.getObjectMapper;
-
import com.fasterxml.jackson.core.JsonProcessingException;
import lombok.AllArgsConstructor;
-import lombok.Data;
-import lombok.Data;
-import lombok.AllArgsConstructor;
import lombok.Builder;
+import lombok.Data;
import java.util.List;
import java.util.Map;
+import static io.github.ollama4j.utils.Utils.getObjectMapper;
+
@Data
@AllArgsConstructor
diff --git a/src/main/java/io/github/ollama4j/models/request/ModelRequest.java b/src/main/java/io/github/ollama4j/models/request/ModelRequest.java
index 923cd87..eca4d41 100644
--- a/src/main/java/io/github/ollama4j/models/request/ModelRequest.java
+++ b/src/main/java/io/github/ollama4j/models/request/ModelRequest.java
@@ -1,11 +1,11 @@
package io.github.ollama4j.models.request;
-import static io.github.ollama4j.utils.Utils.getObjectMapper;
-
import com.fasterxml.jackson.core.JsonProcessingException;
import lombok.AllArgsConstructor;
import lombok.Data;
+import static io.github.ollama4j.utils.Utils.getObjectMapper;
+
@Data
@AllArgsConstructor
public class ModelRequest {
diff --git a/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java b/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java
index 09a3870..724e028 100644
--- a/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java
+++ b/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java
@@ -24,6 +24,7 @@ import java.util.List;
/**
* Specialization class for requests
*/
+@SuppressWarnings("resource")
public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
private static final Logger LOG = LoggerFactory.getLogger(OllamaChatEndpointCaller.class);
@@ -46,19 +47,24 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
* in case the JSON Object cannot be parsed to a {@link OllamaChatResponseModel}. Thus, the ResponseModel should
* never be null.
*
- * @param line streamed line of ollama stream response
+ * @param line streamed line of ollama stream response
* @param responseBuffer Stringbuffer to add latest response message part to
* @return TRUE, if ollama-Response has 'done' state
*/
@Override
- protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) {
+ protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer, StringBuilder thinkingBuffer) {
try {
OllamaChatResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class);
// it seems that under heavy load ollama responds with an empty chat message part in the streamed response
// thus, we null check the message and hope that the next streamed response has some message content again
OllamaChatMessage message = ollamaResponseModel.getMessage();
- if(message != null) {
- responseBuffer.append(message.getContent());
+ if (message != null) {
+ if (message.getThinking() != null) {
+ thinkingBuffer.append(message.getThinking());
+            } else {
+ responseBuffer.append(message.getContent());
+ }
if (tokenHandler != null) {
tokenHandler.accept(ollamaResponseModel);
}
@@ -85,13 +91,14 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
.POST(
body.getBodyPublisher());
HttpRequest request = requestBuilder.build();
- if (isVerbose()) LOG.info("Asking model: " + body);
+ if (isVerbose()) LOG.info("Asking model: {}", body);
HttpResponse<InputStream> response =
httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
int statusCode = response.statusCode();
InputStream responseBodyStream = response.body();
StringBuilder responseBuffer = new StringBuilder();
+ StringBuilder thinkingBuffer = new StringBuilder();
OllamaChatResponseModel ollamaChatResponseModel = null;
List<OllamaChatToolCalls> wantedToolsForStream = null;
try (BufferedReader reader =
@@ -115,14 +122,20 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line,
OllamaErrorResponse.class);
responseBuffer.append(ollamaResponseModel.getError());
+ } else if (statusCode == 500) {
+ LOG.warn("Status code: 500 (Internal Server Error)");
+ OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line,
+ OllamaErrorResponse.class);
+ responseBuffer.append(ollamaResponseModel.getError());
} else {
- boolean finished = parseResponseAndAddToBuffer(line, responseBuffer);
- ollamaChatResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class);
- if(body.stream && ollamaChatResponseModel.getMessage().getToolCalls() != null){
+ boolean finished = parseResponseAndAddToBuffer(line, responseBuffer, thinkingBuffer);
+ ollamaChatResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class);
+ if (body.stream && ollamaChatResponseModel.getMessage().getToolCalls() != null) {
wantedToolsForStream = ollamaChatResponseModel.getMessage().getToolCalls();
}
if (finished && body.stream) {
ollamaChatResponseModel.getMessage().setContent(responseBuffer.toString());
+ ollamaChatResponseModel.getMessage().setThinking(thinkingBuffer.toString());
break;
}
}
@@ -132,11 +145,11 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
LOG.error("Status code " + statusCode);
throw new OllamaBaseException(responseBuffer.toString());
} else {
- if(wantedToolsForStream != null) {
+ if (wantedToolsForStream != null) {
ollamaChatResponseModel.getMessage().setToolCalls(wantedToolsForStream);
}
OllamaChatResult ollamaResult =
- new OllamaChatResult(ollamaChatResponseModel,body.getMessages());
+ new OllamaChatResult(ollamaChatResponseModel, body.getMessages());
if (isVerbose()) LOG.info("Model response: " + ollamaResult);
return ollamaResult;
}
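
Because thinking and content now arrive in separate fields of each streamed chat message, a token handler can branch on them. A sketch of such a handler, assuming the `tokenHandler` above accepts an `OllamaChatResponseModel` per part (as the `accept` call suggests); class locations outside this diff are taken on trust:

```java
import io.github.ollama4j.models.chat.OllamaChatMessage;
import io.github.ollama4j.models.chat.OllamaChatResponseModel;

public class ChatTokenHandlerSketch {
    // Invoked once per streamed line; mirrors the buffer split above.
    static void onToken(OllamaChatResponseModel part) {
        OllamaChatMessage message = part.getMessage();
        if (message == null) {
            return; // empty parts can occur under heavy load, as noted above
        }
        if (message.getThinking() != null) {
            System.out.print("[thinking] " + message.getThinking());
        } else {
            System.out.print(message.getContent());
        }
    }
}
```
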
diff --git a/src/main/java/io/github/ollama4j/models/request/OllamaCommonRequest.java b/src/main/java/io/github/ollama4j/models/request/OllamaCommonRequest.java
index 0ab6cbc..879d801 100644
--- a/src/main/java/io/github/ollama4j/models/request/OllamaCommonRequest.java
+++ b/src/main/java/io/github/ollama4j/models/request/OllamaCommonRequest.java
@@ -1,15 +1,15 @@
package io.github.ollama4j.models.request;
-import java.util.Map;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
-
import io.github.ollama4j.utils.BooleanToJsonFormatFlagSerializer;
import io.github.ollama4j.utils.Utils;
import lombok.Data;
+import java.util.Map;
+
@Data
@JsonInclude(JsonInclude.Include.NON_NULL)
public abstract class OllamaCommonRequest {
diff --git a/src/main/java/io/github/ollama4j/models/request/OllamaEndpointCaller.java b/src/main/java/io/github/ollama4j/models/request/OllamaEndpointCaller.java
index 1f42ef8..c7bdba0 100644
--- a/src/main/java/io/github/ollama4j/models/request/OllamaEndpointCaller.java
+++ b/src/main/java/io/github/ollama4j/models/request/OllamaEndpointCaller.java
@@ -1,15 +1,15 @@
package io.github.ollama4j.models.request;
+import io.github.ollama4j.OllamaAPI;
+import io.github.ollama4j.utils.Constants;
+import lombok.Getter;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
import java.net.URI;
import java.net.http.HttpRequest;
import java.time.Duration;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import io.github.ollama4j.OllamaAPI;
-import lombok.Getter;
-
/**
* Abstract helper class to call the Ollama API server.
*/
@@ -32,7 +32,7 @@ public abstract class OllamaEndpointCaller {
protected abstract String getEndpointSuffix();
- protected abstract boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer);
+ protected abstract boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer, StringBuilder thinkingBuffer);
/**
@@ -44,7 +44,7 @@ public abstract class OllamaEndpointCaller {
protected HttpRequest.Builder getRequestBuilderDefault(URI uri) {
HttpRequest.Builder requestBuilder =
HttpRequest.newBuilder(uri)
- .header("Content-Type", "application/json")
+ .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON)
.timeout(Duration.ofSeconds(this.requestTimeoutSeconds));
if (isAuthCredentialsSet()) {
requestBuilder.header("Authorization", this.auth.getAuthHeaderValue());
diff --git a/src/main/java/io/github/ollama4j/models/request/OllamaGenerateEndpointCaller.java b/src/main/java/io/github/ollama4j/models/request/OllamaGenerateEndpointCaller.java
index 461ec75..a63a384 100644
--- a/src/main/java/io/github/ollama4j/models/request/OllamaGenerateEndpointCaller.java
+++ b/src/main/java/io/github/ollama4j/models/request/OllamaGenerateEndpointCaller.java
@@ -2,11 +2,11 @@ package io.github.ollama4j.models.request;
import com.fasterxml.jackson.core.JsonProcessingException;
import io.github.ollama4j.exceptions.OllamaBaseException;
-import io.github.ollama4j.models.response.OllamaErrorResponse;
-import io.github.ollama4j.models.response.OllamaResult;
import io.github.ollama4j.models.generate.OllamaGenerateResponseModel;
import io.github.ollama4j.models.generate.OllamaGenerateStreamObserver;
import io.github.ollama4j.models.generate.OllamaStreamHandler;
+import io.github.ollama4j.models.response.OllamaErrorResponse;
+import io.github.ollama4j.models.response.OllamaResult;
import io.github.ollama4j.utils.OllamaRequestBody;
import io.github.ollama4j.utils.Utils;
import org.slf4j.Logger;
@@ -22,11 +22,12 @@ import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
+@SuppressWarnings("resource")
public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
private static final Logger LOG = LoggerFactory.getLogger(OllamaGenerateEndpointCaller.class);
- private OllamaGenerateStreamObserver streamObserver;
+ private OllamaGenerateStreamObserver responseStreamObserver;
public OllamaGenerateEndpointCaller(String host, Auth basicAuth, long requestTimeoutSeconds, boolean verbose) {
super(host, basicAuth, requestTimeoutSeconds, verbose);
@@ -38,12 +39,17 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
}
@Override
- protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) {
+ protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer, StringBuilder thinkingBuffer) {
try {
OllamaGenerateResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class);
- responseBuffer.append(ollamaResponseModel.getResponse());
- if (streamObserver != null) {
- streamObserver.notify(ollamaResponseModel);
+ if (ollamaResponseModel.getResponse() != null) {
+ responseBuffer.append(ollamaResponseModel.getResponse());
+ }
+ if (ollamaResponseModel.getThinking() != null) {
+ thinkingBuffer.append(ollamaResponseModel.getThinking());
+ }
+ if (responseStreamObserver != null) {
+ responseStreamObserver.notify(ollamaResponseModel);
}
return ollamaResponseModel.isDone();
} catch (JsonProcessingException e) {
@@ -52,9 +58,8 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
}
}
- public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler)
- throws OllamaBaseException, IOException, InterruptedException {
- streamObserver = new OllamaGenerateStreamObserver(streamHandler);
+ public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler thinkingStreamHandler, OllamaStreamHandler responseStreamHandler) throws OllamaBaseException, IOException, InterruptedException {
+ responseStreamObserver = new OllamaGenerateStreamObserver(thinkingStreamHandler, responseStreamHandler);
return callSync(body);
}
@@ -67,46 +72,41 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
* @throws IOException in case the responseStream can not be read
* @throws InterruptedException in case the server is not reachable or network issues happen
*/
+ @SuppressWarnings("DuplicatedCode")
public OllamaResult callSync(OllamaRequestBody body) throws OllamaBaseException, IOException, InterruptedException {
// Create Request
long startTime = System.currentTimeMillis();
HttpClient httpClient = HttpClient.newHttpClient();
URI uri = URI.create(getHost() + getEndpointSuffix());
- HttpRequest.Builder requestBuilder =
- getRequestBuilderDefault(uri)
- .POST(
- body.getBodyPublisher());
+ HttpRequest.Builder requestBuilder = getRequestBuilderDefault(uri).POST(body.getBodyPublisher());
HttpRequest request = requestBuilder.build();
- if (isVerbose()) LOG.info("Asking model: " + body.toString());
- HttpResponse<InputStream> response =
- httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
+ if (isVerbose()) LOG.info("Asking model: {}", body);
+ HttpResponse<InputStream> response = httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
int statusCode = response.statusCode();
InputStream responseBodyStream = response.body();
StringBuilder responseBuffer = new StringBuilder();
- try (BufferedReader reader =
- new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
+ StringBuilder thinkingBuffer = new StringBuilder();
+ OllamaGenerateResponseModel ollamaGenerateResponseModel = null;
+ try (BufferedReader reader = new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
String line;
while ((line = reader.readLine()) != null) {
if (statusCode == 404) {
LOG.warn("Status code: 404 (Not Found)");
- OllamaErrorResponse ollamaResponseModel =
- Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class);
+ OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class);
responseBuffer.append(ollamaResponseModel.getError());
} else if (statusCode == 401) {
LOG.warn("Status code: 401 (Unauthorized)");
- OllamaErrorResponse ollamaResponseModel =
- Utils.getObjectMapper()
- .readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponse.class);
+ OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponse.class);
responseBuffer.append(ollamaResponseModel.getError());
} else if (statusCode == 400) {
LOG.warn("Status code: 400 (Bad Request)");
- OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line,
- OllamaErrorResponse.class);
+ OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class);
responseBuffer.append(ollamaResponseModel.getError());
} else {
- boolean finished = parseResponseAndAddToBuffer(line, responseBuffer);
+ boolean finished = parseResponseAndAddToBuffer(line, responseBuffer, thinkingBuffer);
if (finished) {
+ ollamaGenerateResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class);
break;
}
}
@@ -114,13 +114,25 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
}
if (statusCode != 200) {
- LOG.error("Status code " + statusCode);
+ LOG.error("Status code: {}", statusCode);
throw new OllamaBaseException(responseBuffer.toString());
} else {
long endTime = System.currentTimeMillis();
- OllamaResult ollamaResult =
- new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode);
- if (isVerbose()) LOG.info("Model response: " + ollamaResult);
+ OllamaResult ollamaResult = new OllamaResult(responseBuffer.toString(), thinkingBuffer.toString(), endTime - startTime, statusCode);
+
+                // Guard against a stream that ended without a final done=true part.
+                if (ollamaGenerateResponseModel != null) {
+                    ollamaResult.setModel(ollamaGenerateResponseModel.getModel());
+                    ollamaResult.setCreatedAt(ollamaGenerateResponseModel.getCreatedAt());
+                    ollamaResult.setDone(ollamaGenerateResponseModel.isDone());
+                    ollamaResult.setDoneReason(ollamaGenerateResponseModel.getDoneReason());
+                    ollamaResult.setContext(ollamaGenerateResponseModel.getContext());
+                    ollamaResult.setTotalDuration(ollamaGenerateResponseModel.getTotalDuration());
+                    ollamaResult.setLoadDuration(ollamaGenerateResponseModel.getLoadDuration());
+                    ollamaResult.setPromptEvalCount(ollamaGenerateResponseModel.getPromptEvalCount());
+                    ollamaResult.setPromptEvalDuration(ollamaGenerateResponseModel.getPromptEvalDuration());
+                    ollamaResult.setEvalCount(ollamaGenerateResponseModel.getEvalCount());
+                    ollamaResult.setEvalDuration(ollamaGenerateResponseModel.getEvalDuration());
+                }
+
+ if (isVerbose()) LOG.info("Model response: {}", ollamaResult);
return ollamaResult;
}
}
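
A sketch of driving the reworked caller end to end. The host, model name, and `setPrompt` accessor are assumptions (the prompt field is not shown in this diff), while the constructor and `call` signature come from the changes above:

```java
import io.github.ollama4j.models.generate.OllamaGenerateRequest;
import io.github.ollama4j.models.request.OllamaGenerateEndpointCaller;
import io.github.ollama4j.models.response.OllamaResult;

public class GenerateCallSketch {
    public static void main(String[] args) throws Exception {
        OllamaGenerateEndpointCaller caller =
                new OllamaGenerateEndpointCaller("http://localhost:11434", null, 120, true);
        OllamaGenerateRequest request = new OllamaGenerateRequest();
        request.setModel("qwen3");                 // illustrative model name
        request.setPrompt("Why is the sky blue?"); // assumed Lombok-generated setter
        request.setThink(true);                    // enable the new thinking stream
        // Thinking tokens and answer tokens are delivered to separate handlers.
        OllamaResult result = caller.call(request,
                token -> System.out.print("[thinking] " + token),
                System.out::print);
        System.out.println("\ndone reason: " + result.getDoneReason());
    }
}
```
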
diff --git a/src/main/java/io/github/ollama4j/models/response/LibraryModel.java b/src/main/java/io/github/ollama4j/models/response/LibraryModel.java
index 82aba42..c5f1627 100644
--- a/src/main/java/io/github/ollama4j/models/response/LibraryModel.java
+++ b/src/main/java/io/github/ollama4j/models/response/LibraryModel.java
@@ -1,9 +1,10 @@
package io.github.ollama4j.models.response;
-import java.util.ArrayList;
-import java.util.List;
import lombok.Data;
+import java.util.ArrayList;
+import java.util.List;
+
@Data
public class LibraryModel {
diff --git a/src/main/java/io/github/ollama4j/models/response/LibraryModelTag.java b/src/main/java/io/github/ollama4j/models/response/LibraryModelTag.java
index d720dd0..cd65d32 100644
--- a/src/main/java/io/github/ollama4j/models/response/LibraryModelTag.java
+++ b/src/main/java/io/github/ollama4j/models/response/LibraryModelTag.java
@@ -2,8 +2,6 @@ package io.github.ollama4j.models.response;
import lombok.Data;
-import java.util.List;
-
@Data
public class LibraryModelTag {
private String name;
diff --git a/src/main/java/io/github/ollama4j/models/response/ListModelsResponse.java b/src/main/java/io/github/ollama4j/models/response/ListModelsResponse.java
index 62f151b..e22b796 100644
--- a/src/main/java/io/github/ollama4j/models/response/ListModelsResponse.java
+++ b/src/main/java/io/github/ollama4j/models/response/ListModelsResponse.java
@@ -1,9 +1,9 @@
package io.github.ollama4j.models.response;
-import java.util.List;
-
import lombok.Data;
+import java.util.List;
+
@Data
public class ListModelsResponse {
private List<Model> models;
diff --git a/src/main/java/io/github/ollama4j/models/response/Model.java b/src/main/java/io/github/ollama4j/models/response/Model.java
index ae64f38..a616404 100644
--- a/src/main/java/io/github/ollama4j/models/response/Model.java
+++ b/src/main/java/io/github/ollama4j/models/response/Model.java
@@ -1,13 +1,13 @@
package io.github.ollama4j.models.response;
-import java.time.OffsetDateTime;
-
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException;
import io.github.ollama4j.utils.Utils;
import lombok.Data;
+import java.time.OffsetDateTime;
+
@Data
@JsonIgnoreProperties(ignoreUnknown = true)
public class Model {
diff --git a/src/main/java/io/github/ollama4j/models/response/OllamaAsyncResultStreamer.java b/src/main/java/io/github/ollama4j/models/response/OllamaAsyncResultStreamer.java
index fd43696..f4a68f7 100644
--- a/src/main/java/io/github/ollama4j/models/response/OllamaAsyncResultStreamer.java
+++ b/src/main/java/io/github/ollama4j/models/response/OllamaAsyncResultStreamer.java
@@ -3,6 +3,7 @@ package io.github.ollama4j.models.response;
import io.github.ollama4j.exceptions.OllamaBaseException;
import io.github.ollama4j.models.generate.OllamaGenerateRequest;
import io.github.ollama4j.models.generate.OllamaGenerateResponseModel;
+import io.github.ollama4j.utils.Constants;
import io.github.ollama4j.utils.Utils;
import lombok.Data;
import lombok.EqualsAndHashCode;
@@ -25,8 +26,10 @@ import java.time.Duration;
public class OllamaAsyncResultStreamer extends Thread {
private final HttpRequest.Builder requestBuilder;
private final OllamaGenerateRequest ollamaRequestModel;
- private final OllamaResultStream stream = new OllamaResultStream();
+ private final OllamaResultStream thinkingResponseStream = new OllamaResultStream();
+ private final OllamaResultStream responseStream = new OllamaResultStream();
private String completeResponse;
+ private String completeThinkingResponse;
/**
@@ -53,14 +56,11 @@ public class OllamaAsyncResultStreamer extends Thread {
@Getter
private long responseTime = 0;
- public OllamaAsyncResultStreamer(
- HttpRequest.Builder requestBuilder,
- OllamaGenerateRequest ollamaRequestModel,
- long requestTimeoutSeconds) {
+ public OllamaAsyncResultStreamer(HttpRequest.Builder requestBuilder, OllamaGenerateRequest ollamaRequestModel, long requestTimeoutSeconds) {
this.requestBuilder = requestBuilder;
this.ollamaRequestModel = ollamaRequestModel;
this.completeResponse = "";
- this.stream.add("");
+ this.responseStream.add("");
this.requestTimeoutSeconds = requestTimeoutSeconds;
}
@@ -68,47 +68,63 @@ public class OllamaAsyncResultStreamer extends Thread {
public void run() {
ollamaRequestModel.setStream(true);
HttpClient httpClient = HttpClient.newHttpClient();
+ long startTime = System.currentTimeMillis();
try {
- long startTime = System.currentTimeMillis();
- HttpRequest request =
- requestBuilder
- .POST(
- HttpRequest.BodyPublishers.ofString(
- Utils.getObjectMapper().writeValueAsString(ollamaRequestModel)))
- .header("Content-Type", "application/json")
- .timeout(Duration.ofSeconds(requestTimeoutSeconds))
- .build();
- HttpResponse<InputStream> response =
- httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
+ HttpRequest request = requestBuilder
+         .POST(HttpRequest.BodyPublishers.ofString(Utils.getObjectMapper().writeValueAsString(ollamaRequestModel)))
+         .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON)
+         .timeout(Duration.ofSeconds(requestTimeoutSeconds))
+         .build();
+ HttpResponse<InputStream> response = httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
int statusCode = response.statusCode();
this.httpStatusCode = statusCode;
InputStream responseBodyStream = response.body();
- try (BufferedReader reader =
- new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
+ BufferedReader reader = null;
+ try {
+ reader = new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8));
String line;
+ StringBuilder thinkingBuffer = new StringBuilder();
StringBuilder responseBuffer = new StringBuilder();
while ((line = reader.readLine()) != null) {
if (statusCode == 404) {
- OllamaErrorResponse ollamaResponseModel =
- Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class);
- stream.add(ollamaResponseModel.getError());
+ OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class);
+ responseStream.add(ollamaResponseModel.getError());
responseBuffer.append(ollamaResponseModel.getError());
} else {
- OllamaGenerateResponseModel ollamaResponseModel =
- Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class);
- String res = ollamaResponseModel.getResponse();
- stream.add(res);
+ OllamaGenerateResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class);
+ String thinkingTokens = ollamaResponseModel.getThinking();
+ String responseTokens = ollamaResponseModel.getResponse();
+ if (thinkingTokens == null) {
+ thinkingTokens = "";
+ }
+ if (responseTokens == null) {
+ responseTokens = "";
+ }
+ thinkingResponseStream.add(thinkingTokens);
+ responseStream.add(responseTokens);
if (!ollamaResponseModel.isDone()) {
- responseBuffer.append(res);
+ responseBuffer.append(responseTokens);
+ thinkingBuffer.append(thinkingTokens);
}
}
}
-
this.succeeded = true;
+ this.completeThinkingResponse = thinkingBuffer.toString();
this.completeResponse = responseBuffer.toString();
long endTime = System.currentTimeMillis();
responseTime = endTime - startTime;
+ } finally {
+ if (reader != null) {
+ try {
+ reader.close();
+ } catch (IOException e) {
+ // Best-effort close; a failure here is safe to ignore since the stream has been consumed.
+ }
+ }
+ if (responseBodyStream != null) {
+ try {
+ responseBodyStream.close();
+ } catch (IOException e) {
+ // Best-effort close; a failure here is safe to ignore.
+ }
+ }
}
if (statusCode != 200) {
throw new OllamaBaseException(this.completeResponse);
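
A usage sketch for the async path. Joining the thread and reading the aggregate fields keeps the example short, though real callers would typically poll the result streams while the thread runs; the getters and `setPrompt` are assumed from the class's Lombok annotations:

```java
import io.github.ollama4j.models.generate.OllamaGenerateRequest;
import io.github.ollama4j.models.response.OllamaAsyncResultStreamer;

import java.net.URI;
import java.net.http.HttpRequest;

public class AsyncStreamerSketch {
    public static void main(String[] args) throws Exception {
        OllamaGenerateRequest request = new OllamaGenerateRequest();
        request.setModel("qwen3");                 // illustrative model name
        request.setPrompt("Why is the sky blue?"); // assumed Lombok-generated setter
        HttpRequest.Builder builder =
                HttpRequest.newBuilder(URI.create("http://localhost:11434/api/generate"));
        OllamaAsyncResultStreamer streamer =
                new OllamaAsyncResultStreamer(builder, request, 120);
        streamer.start(); // runs the HTTP call on its own thread
        streamer.join();  // wait for completion
        // Thinking and answer text are accumulated into separate fields.
        System.out.println("[thinking] " + streamer.getCompleteThinkingResponse());
        System.out.println(streamer.getCompleteResponse());
    }
}
```
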
diff --git a/src/main/java/io/github/ollama4j/models/response/OllamaResult.java b/src/main/java/io/github/ollama4j/models/response/OllamaResult.java
index 4b538f9..ce6d5e3 100644
--- a/src/main/java/io/github/ollama4j/models/response/OllamaResult.java
+++ b/src/main/java/io/github/ollama4j/models/response/OllamaResult.java
@@ -3,116 +3,136 @@ package io.github.ollama4j.models.response;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
-
import lombok.Data;
import lombok.Getter;
-import static io.github.ollama4j.utils.Utils.getObjectMapper;
-
import java.util.HashMap;
+import java.util.List;
import java.util.Map;
-/** The type Ollama result. */
+import static io.github.ollama4j.utils.Utils.getObjectMapper;
+
+/**
+ * The type Ollama result.
+ */
@Getter
@SuppressWarnings("unused")
@Data
@JsonIgnoreProperties(ignoreUnknown = true)
public class OllamaResult {
- /**
- * -- GETTER --
- * Get the completion/response text
- *
- * @return String completion/response text
- */
- private final String response;
+ /**
+ * Get the completion/response text
+ */
+ private final String response;
+ /**
+ * Get the thinking text (if available)
+ */
+ private final String thinking;
+ /**
+ * Get the response status code.
+ */
+ private int httpStatusCode;
+ /**
+ * Get the response time in milliseconds.
+ */
+ private long responseTime = 0;
- /**
- * -- GETTER --
- * Get the response status code.
- *
- * @return int - response status code
- */
- private int httpStatusCode;
+ private String model;
+ private String createdAt;
+ private boolean done;
+ private String doneReason;
+ private List<Integer> context;
+ private Long totalDuration;
+ private Long loadDuration;
+ private Integer promptEvalCount;
+ private Long promptEvalDuration;
+ private Integer evalCount;
+ private Long evalDuration;
- /**
- * -- GETTER --
- * Get the response time in milliseconds.
- *
- * @return long - response time in milliseconds
- */
- private long responseTime = 0;
-
- public OllamaResult(String response, long responseTime, int httpStatusCode) {
- this.response = response;
- this.responseTime = responseTime;
- this.httpStatusCode = httpStatusCode;
- }
-
- @Override
- public String toString() {
- try {
- Map<String, Object> responseMap = new HashMap<>();
- responseMap.put("response", this.response);
- responseMap.put("httpStatusCode", this.httpStatusCode);
- responseMap.put("responseTime", this.responseTime);
- return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(responseMap);
- } catch (JsonProcessingException e) {
- throw new RuntimeException(e);
- }
- }
-
- /**
- * Get the structured response if the response is a JSON object.
- *
- * @return Map - structured response
- * @throws IllegalArgumentException if the response is not a valid JSON object
- */
- public Map<String, Object> getStructuredResponse() {
- String responseStr = this.getResponse();
- if (responseStr == null || responseStr.trim().isEmpty()) {
- throw new IllegalArgumentException("Response is empty or null");
+ public OllamaResult(String response, String thinking, long responseTime, int httpStatusCode) {
+ this.response = response;
+ this.thinking = thinking;
+ this.responseTime = responseTime;
+ this.httpStatusCode = httpStatusCode;
}
- try {
- // Check if the response is a valid JSON
- if ((!responseStr.trim().startsWith("{") && !responseStr.trim().startsWith("[")) ||
- (!responseStr.trim().endsWith("}") && !responseStr.trim().endsWith("]"))) {
- throw new IllegalArgumentException("Response is not a valid JSON object");
- }
-
- Map<String, Object> response = getObjectMapper().readValue(responseStr,
- new TypeReference