forked from Mirror/ollama4j
- Addresses #231 - Updated Ollama class and related methods to replace boolean "think" with ThinkMode enum for better clarity and control over thinking levels. - Modified MetricsRecorder to accept ThinkMode instead of boolean for metrics recording. - Adjusted OllamaChatRequest and OllamaGenerateRequest to utilize ThinkMode, including serialization support. - Updated integration and unit tests to reflect changes in the "think" parameter handling. - Introduced ThinkMode and ThinkModeSerializer classes to manage the new thinking parameter structure.
1318 lines
54 KiB
Java
1318 lines
54 KiB
Java
/*
|
|
* Ollama4j - Java library for interacting with Ollama server.
|
|
* Copyright (c) 2025 Amith Koujalgi and contributors.
|
|
*
|
|
* Licensed under the MIT License (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
*
|
|
*/
|
|
package io.github.ollama4j;
|
|
|
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
|
import io.github.ollama4j.exceptions.OllamaException;
|
|
import io.github.ollama4j.exceptions.RoleNotFoundException;
|
|
import io.github.ollama4j.exceptions.ToolInvocationException;
|
|
import io.github.ollama4j.metrics.MetricsRecorder;
|
|
import io.github.ollama4j.models.chat.*;
|
|
import io.github.ollama4j.models.chat.OllamaChatTokenHandler;
|
|
import io.github.ollama4j.models.embed.OllamaEmbedRequest;
|
|
import io.github.ollama4j.models.embed.OllamaEmbedResult;
|
|
import io.github.ollama4j.models.generate.OllamaGenerateRequest;
|
|
import io.github.ollama4j.models.generate.OllamaGenerateStreamObserver;
|
|
import io.github.ollama4j.models.generate.OllamaGenerateTokenHandler;
|
|
import io.github.ollama4j.models.ps.ModelProcessesResult;
|
|
import io.github.ollama4j.models.request.*;
|
|
import io.github.ollama4j.models.response.*;
|
|
import io.github.ollama4j.tools.*;
|
|
import io.github.ollama4j.tools.annotations.OllamaToolService;
|
|
import io.github.ollama4j.tools.annotations.ToolProperty;
|
|
import io.github.ollama4j.tools.annotations.ToolSpec;
|
|
import io.github.ollama4j.utils.Constants;
|
|
import io.github.ollama4j.utils.Utils;
|
|
import java.io.*;
|
|
import java.lang.reflect.InvocationTargetException;
|
|
import java.lang.reflect.Method;
|
|
import java.lang.reflect.Parameter;
|
|
import java.net.URI;
|
|
import java.net.URISyntaxException;
|
|
import java.net.http.HttpClient;
|
|
import java.net.http.HttpRequest;
|
|
import java.net.http.HttpResponse;
|
|
import java.nio.charset.StandardCharsets;
|
|
import java.nio.file.Files;
|
|
import java.time.Duration;
|
|
import java.util.*;
|
|
import java.util.stream.Collectors;
|
|
import lombok.Setter;
|
|
import org.slf4j.Logger;
|
|
import org.slf4j.LoggerFactory;
|
|
|
|
/**
|
|
* The main API class for interacting with the Ollama server.
|
|
*
|
|
* <p>This class provides methods for model management, chat, embeddings, tool registration, and more.
|
|
*/
|
|
@SuppressWarnings({"DuplicatedCode", "resource", "SpellCheckingInspection"})
|
|
public class Ollama {
|
|
|
|
private static final Logger LOG = LoggerFactory.getLogger(Ollama.class);
|
|
|
|
private final String host;
|
|
private Auth auth;
|
|
|
|
private final ToolRegistry toolRegistry = new ToolRegistry();
|
|
|
|
/**
|
|
* The request timeout in seconds for API calls.
|
|
* <p>
|
|
* Default is 10 seconds. This value determines how long the client will wait for a response
|
|
* from the Ollama server before timing out.
|
|
*/
|
|
@Setter private long requestTimeoutSeconds = 10;
|
|
|
|
/**
|
|
* The read timeout in seconds for image URLs.
|
|
*/
|
|
@Setter private int imageURLReadTimeoutSeconds = 10;
|
|
|
|
/**
|
|
* The connect timeout in seconds for image URLs.
|
|
*/
|
|
@Setter private int imageURLConnectTimeoutSeconds = 10;
|
|
|
|
/**
|
|
* The maximum number of retries for tool calls during chat interactions.
|
|
* <p>
|
|
* This value controls how many times the API will attempt to call a tool in the event of a
|
|
* failure. Default is 3.
|
|
*/
|
|
@Setter private int maxChatToolCallRetries = 3;
|
|
|
|
/**
|
|
* The number of retries to attempt when pulling a model from the Ollama server.
|
|
* <p>
|
|
* If set to 0, no retries will be performed. If greater than 0, the API will retry pulling
|
|
* the model up to the specified number of times in case of failure.
|
|
* <p>
|
|
* Default is 0 (no retries).
|
|
*/
|
|
@Setter
|
|
@SuppressWarnings({"FieldMayBeFinal", "FieldCanBeLocal"})
|
|
private int numberOfRetriesForModelPull = 0;
|
|
|
|
/**
|
|
* Enable or disable Prometheus metrics collection.
|
|
* <p>
|
|
* When enabled, the API will collect and expose metrics for request counts, durations, model
|
|
* usage, and other operational statistics. Default is false.
|
|
*/
|
|
@Setter private boolean metricsEnabled = false;
|
|
|
|
/**
 * Instantiates the Ollama API client pointed at the default Ollama host:
 * {@code http://localhost:11434}.
 */
public Ollama() {
    // Default endpoint used by a locally running Ollama server.
    this.host = "http://localhost:11434";
}
|
|
|
|
/**
 * Instantiates the Ollama API client against a specific server address.
 *
 * @param host the host address of the Ollama server, with or without a trailing slash
 */
public Ollama(String host) {
    // Strip a single trailing slash so later path concatenation
    // ("host" + "/api/...") never produces a double slash.
    this.host = host.endsWith("/") ? host.substring(0, host.length() - 1) : host;
    LOG.info("Ollama4j client initialized. Connected to Ollama server at: {}", this.host);
}
|
|
|
|
/**
 * Set basic authentication for accessing an Ollama server that's behind a
 * reverse-proxy/gateway.
 *
 * @param username the username
 * @param password the password
 */
public void setBasicAuth(String username, String password) {
    // Replaces any previously configured auth (basic or bearer).
    this.auth = new BasicAuth(username, password);
}
|
|
|
|
/**
 * Set Bearer authentication for accessing an Ollama server that's behind a
 * reverse-proxy/gateway.
 *
 * @param bearerToken the Bearer authentication token to provide
 */
public void setBearerAuth(String bearerToken) {
    // Replaces any previously configured auth (basic or bearer).
    this.auth = new BearerAuth(bearerToken);
}
|
|
|
|
/**
 * Checks the reachability of the Ollama server by issuing a GET to {@code /api/tags}.
 *
 * @return true if the server responds with HTTP 200, false otherwise
 * @throws OllamaException if the ping fails or the thread is interrupted
 */
public boolean ping() throws OllamaException {
    long startTime = System.currentTimeMillis();
    String url = "/api/tags";
    int statusCode = -1;
    Object out = null;
    try {
        HttpClient httpClient = HttpClient.newHttpClient();
        HttpRequest httpRequest;
        HttpResponse<String> response;
        httpRequest =
                getRequestBuilderDefault(new URI(this.host + url))
                        .header(
                                Constants.HttpConstants.HEADER_KEY_ACCEPT,
                                Constants.HttpConstants.APPLICATION_JSON)
                        .header(
                                Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE,
                                Constants.HttpConstants.APPLICATION_JSON)
                        .GET()
                        .build();
        response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
        statusCode = response.statusCode();
        // Any 200 response counts as reachable; the body is ignored.
        return statusCode == 200;
    } catch (InterruptedException ie) {
        // Restore the interrupt flag before surfacing the failure.
        Thread.currentThread().interrupt();
        throw new OllamaException("Ping interrupted", ie);
    } catch (Exception e) {
        throw new OllamaException("Ping failed", e);
    } finally {
        // Metrics are recorded on every path (success, failure, interrupt).
        MetricsRecorder.record(
                url,
                "",
                false,
                ThinkMode.DISABLED,
                false,
                null,
                null,
                startTime,
                statusCode,
                out);
    }
}
|
|
|
|
/**
 * Provides a list of running models and details about each model currently loaded into
 * memory, via {@code /api/ps}.
 *
 * @return ModelProcessesResult containing details about the running models
 * @throws OllamaException if the response indicates an error status
 */
public ModelProcessesResult ps() throws OllamaException {
    long startTime = System.currentTimeMillis();
    String url = "/api/ps";
    int statusCode = -1;
    Object out = null;
    try {
        HttpClient httpClient = HttpClient.newHttpClient();
        HttpRequest httpRequest = null;
        try {
            httpRequest =
                    getRequestBuilderDefault(new URI(this.host + url))
                            .header(
                                    Constants.HttpConstants.HEADER_KEY_ACCEPT,
                                    Constants.HttpConstants.APPLICATION_JSON)
                            .header(
                                    Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE,
                                    Constants.HttpConstants.APPLICATION_JSON)
                            .GET()
                            .build();
        } catch (URISyntaxException e) {
            // A malformed host produces a URISyntaxException before any I/O happens.
            throw new OllamaException(e.getMessage(), e);
        }
        HttpResponse<String> response = null;
        response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
        statusCode = response.statusCode();
        String responseString = response.body();
        if (statusCode == 200) {
            return Utils.getObjectMapper()
                    .readValue(responseString, ModelProcessesResult.class);
        } else {
            // Non-200: surface the status code and raw body to the caller.
            throw new OllamaException(statusCode + " - " + responseString);
        }
    } catch (InterruptedException ie) {
        // Restore the interrupt flag before surfacing the failure.
        Thread.currentThread().interrupt();
        throw new OllamaException("ps interrupted", ie);
    } catch (Exception e) {
        throw new OllamaException("ps failed", e);
    } finally {
        // Metrics are recorded on every path (success, failure, interrupt).
        MetricsRecorder.record(
                url,
                "",
                false,
                ThinkMode.DISABLED,
                false,
                null,
                null,
                startTime,
                statusCode,
                out);
    }
}
|
|
|
|
/**
 * Lists available models from the Ollama server via {@code /api/tags}.
 *
 * @return a list of models available on the server
 * @throws OllamaException if the response indicates an error status
 */
public List<Model> listModels() throws OllamaException {
    long startTime = System.currentTimeMillis();
    String url = "/api/tags";
    int statusCode = -1;
    Object out = null;
    try {
        HttpClient httpClient = HttpClient.newHttpClient();
        HttpRequest httpRequest =
                getRequestBuilderDefault(new URI(this.host + url))
                        .header(
                                Constants.HttpConstants.HEADER_KEY_ACCEPT,
                                Constants.HttpConstants.APPLICATION_JSON)
                        .header(
                                Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE,
                                Constants.HttpConstants.APPLICATION_JSON)
                        .GET()
                        .build();
        HttpResponse<String> response =
                httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
        statusCode = response.statusCode();
        String responseString = response.body();
        if (statusCode == 200) {
            // The /api/tags payload wraps the list in a "models" field.
            return Utils.getObjectMapper()
                    .readValue(responseString, ListModelsResponse.class)
                    .getModels();
        } else {
            throw new OllamaException(statusCode + " - " + responseString);
        }
    } catch (InterruptedException ie) {
        // Restore the interrupt flag before surfacing the failure.
        Thread.currentThread().interrupt();
        throw new OllamaException("listModels interrupted", ie);
    } catch (Exception e) {
        throw new OllamaException(e.getMessage(), e);
    } finally {
        // Metrics are recorded on every path (success, failure, interrupt).
        MetricsRecorder.record(
                url,
                "",
                false,
                ThinkMode.DISABLED,
                false,
                null,
                null,
                startTime,
                statusCode,
                out);
    }
}
|
|
|
|
/**
|
|
* Handles retry backoff for pullModel.
|
|
*
|
|
* @param modelName the name of the model being pulled
|
|
* @param currentRetry the current retry attempt (zero-based)
|
|
* @param maxRetries the maximum number of retries allowed
|
|
* @param baseDelayMillis the base delay in milliseconds for exponential backoff
|
|
* @throws InterruptedException if the thread is interrupted during sleep
|
|
*/
|
|
private void handlePullRetry(
|
|
String modelName, int currentRetry, int maxRetries, long baseDelayMillis)
|
|
throws InterruptedException {
|
|
int attempt = currentRetry + 1;
|
|
if (attempt < maxRetries) {
|
|
long backoffMillis = baseDelayMillis * (1L << currentRetry);
|
|
LOG.error(
|
|
"Failed to pull model {}, retrying in {}s... (attempt {}/{})",
|
|
modelName,
|
|
backoffMillis / 1000,
|
|
attempt,
|
|
maxRetries);
|
|
try {
|
|
Thread.sleep(backoffMillis);
|
|
} catch (InterruptedException ie) {
|
|
Thread.currentThread().interrupt();
|
|
throw ie;
|
|
}
|
|
} else {
|
|
LOG.error(
|
|
"Failed to pull model {} after {} attempts, no more retries.",
|
|
modelName,
|
|
maxRetries);
|
|
}
|
|
}
|
|
|
|
/**
|
|
* Internal method to pull a model from the Ollama server.
|
|
*
|
|
* @param modelName the name of the model to pull
|
|
* @throws OllamaException if the pull fails
|
|
*/
|
|
private void doPullModel(String modelName) throws OllamaException {
|
|
long startTime = System.currentTimeMillis();
|
|
String url = "/api/pull";
|
|
int statusCode = -1;
|
|
Object out = null;
|
|
try {
|
|
String jsonData = new ModelRequest(modelName).toString();
|
|
HttpRequest request =
|
|
getRequestBuilderDefault(new URI(this.host + url))
|
|
.POST(HttpRequest.BodyPublishers.ofString(jsonData))
|
|
.header(
|
|
Constants.HttpConstants.HEADER_KEY_ACCEPT,
|
|
Constants.HttpConstants.APPLICATION_JSON)
|
|
.header(
|
|
Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE,
|
|
Constants.HttpConstants.APPLICATION_JSON)
|
|
.build();
|
|
HttpClient client = HttpClient.newHttpClient();
|
|
HttpResponse<InputStream> response =
|
|
client.send(request, HttpResponse.BodyHandlers.ofInputStream());
|
|
statusCode = response.statusCode();
|
|
InputStream responseBodyStream = response.body();
|
|
String responseString = "";
|
|
boolean success = false; // Flag to check the pull success.
|
|
|
|
try (BufferedReader reader =
|
|
new BufferedReader(
|
|
new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
|
|
String line;
|
|
while ((line = reader.readLine()) != null) {
|
|
ModelPullResponse modelPullResponse =
|
|
Utils.getObjectMapper().readValue(line, ModelPullResponse.class);
|
|
success = processModelPullResponse(modelPullResponse, modelName) || success;
|
|
}
|
|
}
|
|
if (!success) {
|
|
LOG.error("Model pull failed or returned invalid status.");
|
|
throw new OllamaException("Model pull failed or returned invalid status.");
|
|
}
|
|
if (statusCode != 200) {
|
|
throw new OllamaException(statusCode + " - " + responseString);
|
|
}
|
|
} catch (InterruptedException ie) {
|
|
Thread.currentThread().interrupt();
|
|
throw new OllamaException("Thread was interrupted during model pull.", ie);
|
|
} catch (Exception e) {
|
|
throw new OllamaException(e.getMessage(), e);
|
|
} finally {
|
|
MetricsRecorder.record(
|
|
url,
|
|
"",
|
|
false,
|
|
ThinkMode.DISABLED,
|
|
false,
|
|
null,
|
|
null,
|
|
startTime,
|
|
statusCode,
|
|
out);
|
|
}
|
|
}
|
|
|
|
/**
|
|
* Processes a single ModelPullResponse, handling errors and logging status.
|
|
* Returns true if the response indicates a successful pull.
|
|
*
|
|
* @param modelPullResponse the response from the model pull
|
|
* @param modelName the name of the model
|
|
* @return true if the pull was successful, false otherwise
|
|
* @throws OllamaException if the response contains an error
|
|
*/
|
|
@SuppressWarnings("RedundantIfStatement")
|
|
private boolean processModelPullResponse(ModelPullResponse modelPullResponse, String modelName)
|
|
throws OllamaException {
|
|
if (modelPullResponse == null) {
|
|
LOG.error("Received null response for model pull.");
|
|
return false;
|
|
}
|
|
String error = modelPullResponse.getError();
|
|
if (error != null && !error.trim().isEmpty()) {
|
|
throw new OllamaException("Model pull failed: " + error);
|
|
}
|
|
String status = modelPullResponse.getStatus();
|
|
if (status != null) {
|
|
LOG.debug("{}: {}", modelName, status);
|
|
if ("success".equalsIgnoreCase(status)) {
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
/**
 * Gets the Ollama server version via {@code /api/version}.
 *
 * @return the version string reported by the server
 * @throws OllamaException if the request fails
 */
public String getVersion() throws OllamaException {
    String url = "/api/version";
    long startTime = System.currentTimeMillis();
    int statusCode = -1;
    Object out = null;
    try {
        HttpClient httpClient = HttpClient.newHttpClient();
        HttpRequest httpRequest =
                getRequestBuilderDefault(new URI(this.host + url))
                        .header(
                                Constants.HttpConstants.HEADER_KEY_ACCEPT,
                                Constants.HttpConstants.APPLICATION_JSON)
                        .header(
                                Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE,
                                Constants.HttpConstants.APPLICATION_JSON)
                        .GET()
                        .build();
        HttpResponse<String> response =
                httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
        statusCode = response.statusCode();
        String responseString = response.body();
        if (statusCode == 200) {
            // The payload wraps the version string in a "version" field.
            return Utils.getObjectMapper()
                    .readValue(responseString, OllamaVersion.class)
                    .getVersion();
        } else {
            throw new OllamaException(statusCode + " - " + responseString);
        }
    } catch (InterruptedException ie) {
        // Restore the interrupt flag before surfacing the failure.
        Thread.currentThread().interrupt();
        throw new OllamaException("Thread was interrupted", ie);
    } catch (Exception e) {
        throw new OllamaException(e.getMessage(), e);
    } finally {
        // Metrics are recorded on every path (success, failure, interrupt).
        MetricsRecorder.record(
                url,
                "",
                false,
                ThinkMode.DISABLED,
                false,
                null,
                null,
                startTime,
                statusCode,
                out);
    }
}
|
|
|
|
/**
|
|
* Pulls a model using the specified Ollama library model tag.
|
|
* The model is identified by a name and a tag, which are combined into a single identifier
|
|
* in the format "name:tag" to pull the corresponding model.
|
|
*
|
|
* @param modelName the name/tag of the model to be pulled. Ex: llama3:latest
|
|
* @throws OllamaException if the response indicates an error status
|
|
*/
|
|
public void pullModel(String modelName) throws OllamaException {
|
|
try {
|
|
if (numberOfRetriesForModelPull == 0) {
|
|
this.doPullModel(modelName);
|
|
return;
|
|
}
|
|
int numberOfRetries = 0;
|
|
long baseDelayMillis = 3000L; // 3 seconds base delay
|
|
while (numberOfRetries < numberOfRetriesForModelPull) {
|
|
try {
|
|
this.doPullModel(modelName);
|
|
return;
|
|
} catch (OllamaException e) {
|
|
handlePullRetry(
|
|
modelName,
|
|
numberOfRetries,
|
|
numberOfRetriesForModelPull,
|
|
baseDelayMillis);
|
|
numberOfRetries++;
|
|
}
|
|
}
|
|
throw new OllamaException(
|
|
"Failed to pull model "
|
|
+ modelName
|
|
+ " after "
|
|
+ numberOfRetriesForModelPull
|
|
+ " retries");
|
|
} catch (InterruptedException ie) {
|
|
Thread.currentThread().interrupt();
|
|
throw new OllamaException("Thread was interrupted", ie);
|
|
} catch (Exception e) {
|
|
throw new OllamaException(e.getMessage(), e);
|
|
}
|
|
}
|
|
|
|
/**
 * Gets model details from the Ollama server via {@code /api/show}.
 *
 * @param modelName the model name
 * @return the model details
 * @throws OllamaException if the response indicates an error status
 */
public ModelDetail getModelDetails(String modelName) throws OllamaException {
    long startTime = System.currentTimeMillis();
    String url = "/api/show";
    int statusCode = -1;
    Object out = null;
    try {
        // ModelRequest.toString() is expected to produce the JSON body.
        String jsonData = new ModelRequest(modelName).toString();
        HttpRequest request =
                getRequestBuilderDefault(new URI(this.host + url))
                        .header(
                                Constants.HttpConstants.HEADER_KEY_ACCEPT,
                                Constants.HttpConstants.APPLICATION_JSON)
                        .header(
                                Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE,
                                Constants.HttpConstants.APPLICATION_JSON)
                        .POST(HttpRequest.BodyPublishers.ofString(jsonData))
                        .build();
        HttpClient client = HttpClient.newHttpClient();
        HttpResponse<String> response =
                client.send(request, HttpResponse.BodyHandlers.ofString());
        statusCode = response.statusCode();
        String responseBody = response.body();
        if (statusCode == 200) {
            return Utils.getObjectMapper().readValue(responseBody, ModelDetail.class);
        } else {
            throw new OllamaException(statusCode + " - " + responseBody);
        }
    } catch (InterruptedException ie) {
        // Restore the interrupt flag before surfacing the failure.
        Thread.currentThread().interrupt();
        throw new OllamaException("Thread was interrupted", ie);
    } catch (Exception e) {
        throw new OllamaException(e.getMessage(), e);
    } finally {
        // Metrics are recorded on every path (success, failure, interrupt).
        MetricsRecorder.record(
                url,
                "",
                false,
                ThinkMode.DISABLED,
                false,
                null,
                null,
                startTime,
                statusCode,
                out);
    }
}
|
|
|
|
/**
 * Creates a custom model. Read more about custom model creation
 * <a href="https://github.com/ollama/ollama/blob/main/docs/api.md#create-a-model">here</a>.
 *
 * <p>The {@code /api/create} endpoint streams newline-delimited JSON status objects; any
 * line carrying an error aborts the creation.
 *
 * @param customModelRequest custom model spec
 * @throws OllamaException if the response indicates an error status
 */
public void createModel(CustomModelRequest customModelRequest) throws OllamaException {
    long startTime = System.currentTimeMillis();
    String url = "/api/create";
    int statusCode = -1;
    Object out = null;
    try {
        String jsonData = customModelRequest.toString();
        HttpRequest request =
                getRequestBuilderDefault(new URI(this.host + url))
                        .header(
                                Constants.HttpConstants.HEADER_KEY_ACCEPT,
                                Constants.HttpConstants.APPLICATION_JSON)
                        .header(
                                Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE,
                                Constants.HttpConstants.APPLICATION_JSON)
                        .POST(
                                HttpRequest.BodyPublishers.ofString(
                                        jsonData, StandardCharsets.UTF_8))
                        .build();
        HttpClient client = HttpClient.newHttpClient();
        HttpResponse<InputStream> response =
                client.send(request, HttpResponse.BodyHandlers.ofInputStream());
        statusCode = response.statusCode();
        if (statusCode != 200) {
            // Drain the error body so the exception (and metrics) carry it.
            String errorBody =
                    new String(response.body().readAllBytes(), StandardCharsets.UTF_8);
            out = errorBody;
            throw new OllamaException(statusCode + " - " + errorBody);
        }
        try (BufferedReader reader =
                new BufferedReader(
                        new InputStreamReader(response.body(), StandardCharsets.UTF_8))) {
            String line;
            StringBuilder lines = new StringBuilder();
            while ((line = reader.readLine()) != null) {
                // Each line is a standalone JSON status object.
                ModelPullResponse res =
                        Utils.getObjectMapper().readValue(line, ModelPullResponse.class);
                lines.append(line);
                LOG.debug(res.getStatus());
                if (res.getError() != null) {
                    out = res.getError();
                    throw new OllamaException(res.getError());
                }
            }
            // Record the concatenated stream for metrics.
            out = lines;
        }
    } catch (InterruptedException e) {
        // Restore the interrupt flag before surfacing the failure.
        Thread.currentThread().interrupt();
        throw new OllamaException("Thread was interrupted", e);
    } catch (Exception e) {
        throw new OllamaException(e.getMessage(), e);
    } finally {
        // Metrics are recorded on every path (success, failure, interrupt).
        MetricsRecorder.record(
                url,
                "",
                false,
                ThinkMode.DISABLED,
                false,
                null,
                null,
                startTime,
                statusCode,
                out);
    }
}
|
|
|
|
/**
|
|
* Deletes a model from the Ollama server.
|
|
*
|
|
* @param modelName the name of the model to be deleted
|
|
* @param ignoreIfNotPresent ignore errors if the specified model is not present on the Ollama server
|
|
* @throws OllamaException if the response indicates an error status
|
|
*/
|
|
public void deleteModel(String modelName, boolean ignoreIfNotPresent) throws OllamaException {
|
|
long startTime = System.currentTimeMillis();
|
|
String url = "/api/delete";
|
|
int statusCode = -1;
|
|
Object out = null;
|
|
try {
|
|
String jsonData = new ModelRequest(modelName).toString();
|
|
HttpRequest request =
|
|
getRequestBuilderDefault(new URI(this.host + url))
|
|
.method(
|
|
"DELETE",
|
|
HttpRequest.BodyPublishers.ofString(
|
|
jsonData, StandardCharsets.UTF_8))
|
|
.header(
|
|
Constants.HttpConstants.HEADER_KEY_ACCEPT,
|
|
Constants.HttpConstants.APPLICATION_JSON)
|
|
.header(
|
|
Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE,
|
|
Constants.HttpConstants.APPLICATION_JSON)
|
|
.build();
|
|
HttpClient client = HttpClient.newHttpClient();
|
|
HttpResponse<String> response =
|
|
client.send(request, HttpResponse.BodyHandlers.ofString());
|
|
statusCode = response.statusCode();
|
|
String responseBody = response.body();
|
|
out = responseBody;
|
|
if (statusCode == 404
|
|
&& responseBody.contains("model")
|
|
&& responseBody.contains("not found")) {
|
|
return;
|
|
}
|
|
if (statusCode != 200) {
|
|
throw new OllamaException(statusCode + " - " + responseBody);
|
|
}
|
|
} catch (InterruptedException e) {
|
|
Thread.currentThread().interrupt();
|
|
throw new OllamaException("Thread was interrupted", e);
|
|
} catch (Exception e) {
|
|
throw new OllamaException(statusCode + " - " + out, e);
|
|
} finally {
|
|
MetricsRecorder.record(
|
|
url,
|
|
"",
|
|
false,
|
|
ThinkMode.DISABLED,
|
|
false,
|
|
null,
|
|
null,
|
|
startTime,
|
|
statusCode,
|
|
out);
|
|
}
|
|
}
|
|
|
|
/**
|
|
* Unloads a model from memory.
|
|
* <p>
|
|
* If an empty prompt is provided and the keep_alive parameter is set to 0, a model will be
|
|
* unloaded from memory.
|
|
*
|
|
* @param modelName the name of the model to unload
|
|
* @throws OllamaException if the response indicates an error status
|
|
*/
|
|
public void unloadModel(String modelName) throws OllamaException {
|
|
long startTime = System.currentTimeMillis();
|
|
String url = "/api/generate";
|
|
int statusCode = -1;
|
|
Object out = null;
|
|
try {
|
|
ObjectMapper objectMapper = new ObjectMapper();
|
|
Map<String, Object> jsonMap = new java.util.HashMap<>();
|
|
jsonMap.put("model", modelName);
|
|
jsonMap.put("keep_alive", 0);
|
|
String jsonData = objectMapper.writeValueAsString(jsonMap);
|
|
HttpRequest request =
|
|
getRequestBuilderDefault(new URI(this.host + url))
|
|
.method(
|
|
"POST",
|
|
HttpRequest.BodyPublishers.ofString(
|
|
jsonData, StandardCharsets.UTF_8))
|
|
.header(
|
|
Constants.HttpConstants.HEADER_KEY_ACCEPT,
|
|
Constants.HttpConstants.APPLICATION_JSON)
|
|
.header(
|
|
Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE,
|
|
Constants.HttpConstants.APPLICATION_JSON)
|
|
.build();
|
|
LOG.debug("Unloading model with request: {}", jsonData);
|
|
HttpClient client = HttpClient.newHttpClient();
|
|
HttpResponse<String> response =
|
|
client.send(request, HttpResponse.BodyHandlers.ofString());
|
|
statusCode = response.statusCode();
|
|
String responseBody = response.body();
|
|
if (statusCode == 404
|
|
&& responseBody.contains("model")
|
|
&& responseBody.contains("not found")) {
|
|
LOG.debug("Unload response: {} - {}", statusCode, responseBody);
|
|
return;
|
|
}
|
|
if (statusCode != 200) {
|
|
LOG.debug("Unload response: {} - {}", statusCode, responseBody);
|
|
throw new OllamaException(statusCode + " - " + responseBody);
|
|
}
|
|
} catch (InterruptedException e) {
|
|
Thread.currentThread().interrupt();
|
|
LOG.debug("Unload interrupted: {} - {}", statusCode, out);
|
|
throw new OllamaException(statusCode + " - " + out, e);
|
|
} catch (Exception e) {
|
|
LOG.debug("Unload failed: {} - {}", statusCode, out);
|
|
throw new OllamaException(statusCode + " - " + out, e);
|
|
} finally {
|
|
MetricsRecorder.record(
|
|
url,
|
|
"",
|
|
false,
|
|
ThinkMode.DISABLED,
|
|
false,
|
|
null,
|
|
null,
|
|
startTime,
|
|
statusCode,
|
|
out);
|
|
}
|
|
}
|
|
|
|
/**
|
|
* Generate embeddings using a {@link OllamaEmbedRequest}.
|
|
*
|
|
* @param modelRequest request for '/api/embed' endpoint
|
|
* @return embeddings
|
|
* @throws OllamaException if the response indicates an error status
|
|
*/
|
|
public OllamaEmbedResult embed(OllamaEmbedRequest modelRequest) throws OllamaException {
|
|
long startTime = System.currentTimeMillis();
|
|
String url = "/api/embed";
|
|
int statusCode = -1;
|
|
Object out = null;
|
|
try {
|
|
String jsonData = Utils.getObjectMapper().writeValueAsString(modelRequest);
|
|
HttpClient httpClient = HttpClient.newHttpClient();
|
|
HttpRequest request =
|
|
HttpRequest.newBuilder(new URI(this.host + url))
|
|
.header(
|
|
Constants.HttpConstants.HEADER_KEY_ACCEPT,
|
|
Constants.HttpConstants.APPLICATION_JSON)
|
|
.POST(HttpRequest.BodyPublishers.ofString(jsonData))
|
|
.build();
|
|
HttpResponse<String> response =
|
|
httpClient.send(request, HttpResponse.BodyHandlers.ofString());
|
|
statusCode = response.statusCode();
|
|
String responseBody = response.body();
|
|
if (statusCode == 200) {
|
|
return Utils.getObjectMapper().readValue(responseBody, OllamaEmbedResult.class);
|
|
} else {
|
|
throw new OllamaException(statusCode + " - " + responseBody);
|
|
}
|
|
} catch (InterruptedException e) {
|
|
Thread.currentThread().interrupt();
|
|
throw new OllamaException("Thread was interrupted", e);
|
|
} catch (Exception e) {
|
|
throw new OllamaException(e.getMessage(), e);
|
|
} finally {
|
|
MetricsRecorder.record(
|
|
url,
|
|
"",
|
|
false,
|
|
ThinkMode.DISABLED,
|
|
false,
|
|
null,
|
|
null,
|
|
startTime,
|
|
statusCode,
|
|
out);
|
|
}
|
|
}
|
|
|
|
/**
|
|
* Generates a response from a model using the specified parameters and stream observer.
|
|
* If {@code streamObserver} is provided, streaming is enabled; otherwise, a synchronous call is made.
|
|
*
|
|
* @param request the generation request
|
|
* @param streamObserver the stream observer for streaming responses, or null for synchronous
|
|
* @return the result of the generation
|
|
* @throws OllamaException if the request fails
|
|
*/
|
|
public OllamaResult generate(
|
|
OllamaGenerateRequest request, OllamaGenerateStreamObserver streamObserver)
|
|
throws OllamaException {
|
|
try {
|
|
if (request.isUseTools()) {
|
|
return generateWithToolsInternal(request, streamObserver);
|
|
}
|
|
|
|
if (streamObserver != null) {
|
|
if (!request.getThink().equals(ThinkMode.DISABLED)) {
|
|
return generateSyncForOllamaRequestModel(
|
|
request,
|
|
streamObserver.getThinkingStreamHandler(),
|
|
streamObserver.getResponseStreamHandler());
|
|
} else {
|
|
return generateSyncForOllamaRequestModel(
|
|
request, null, streamObserver.getResponseStreamHandler());
|
|
}
|
|
}
|
|
return generateSyncForOllamaRequestModel(request, null, null);
|
|
} catch (Exception e) {
|
|
throw new OllamaException(e.getMessage(), e);
|
|
}
|
|
}
|
|
|
|
// (No javadoc for private helper, as is standard)
|
|
// Routes a generate request through the chat endpoint so registered tools can be invoked.
// The prompt becomes a single USER chat message; the chat result is adapted back into an
// OllamaResult (with a placeholder HTTP status of -1, since chat() does not expose one).
// NOTE(review): the request's think mode, options, and raw flag are not copied onto the
// chat request — confirm whether that is intended.
private OllamaResult generateWithToolsInternal(
        OllamaGenerateRequest request, OllamaGenerateStreamObserver streamObserver)
        throws OllamaException {
    ArrayList<OllamaChatMessage> msgs = new ArrayList<>();
    OllamaChatRequest chatRequest = new OllamaChatRequest();
    chatRequest.setModel(request.getModel());
    OllamaChatMessage ocm = new OllamaChatMessage();
    ocm.setRole(OllamaChatMessageRole.USER);
    ocm.setResponse(request.getPrompt());
    // msgs is shared by reference, so adding after setMessages still takes effect.
    chatRequest.setMessages(msgs);
    msgs.add(ocm);

    // Merge request's tools and globally registered tools into a new list to avoid mutating the
    // original request
    List<Tools.Tool> allTools = new ArrayList<>();
    if (request.getTools() != null) {
        allTools.addAll(request.getTools());
    }
    List<Tools.Tool> registeredTools = this.getRegisteredTools();
    if (registeredTools != null) {
        allTools.addAll(registeredTools);
    }

    OllamaChatTokenHandler hdlr = null;
    chatRequest.setUseTools(true);
    chatRequest.setTools(allTools);
    if (streamObserver != null) {
        chatRequest.setStream(true);
        if (streamObserver.getResponseStreamHandler() != null) {
            // Adapt the generate-style response handler to the chat token callback.
            hdlr =
                    chatResponseModel ->
                            streamObserver
                                    .getResponseStreamHandler()
                                    .accept(chatResponseModel.getMessage().getResponse());
        }
    }
    OllamaChatResult res = chat(chatRequest, hdlr);
    return new OllamaResult(
            res.getResponseModel().getMessage().getResponse(),
            res.getResponseModel().getMessage().getThinking(),
            res.getResponseModel().getTotalDuration(),
            -1);
}
|
|
|
|
/**
 * Generates a response from a model asynchronously, returning a streamer for results.
 *
 * @param model the model name
 * @param prompt the prompt to send
 * @param raw whether to use raw mode
 * @param think the thinking mode to apply to the request
 * @return an OllamaAsyncResultStreamer for streaming results
 * @throws OllamaException if the request fails
 */
public OllamaAsyncResultStreamer generateAsync(
        String model, String prompt, boolean raw, ThinkMode think) throws OllamaException {
    long startTime = System.currentTimeMillis();
    String url = "/api/generate";
    int statusCode = -1;
    try {
        OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt);
        ollamaRequestModel.setRaw(raw);
        ollamaRequestModel.setThink(think);
        OllamaAsyncResultStreamer ollamaAsyncResultStreamer =
                new OllamaAsyncResultStreamer(
                        getRequestBuilderDefault(new URI(this.host + url)),
                        ollamaRequestModel,
                        requestTimeoutSeconds);
        ollamaAsyncResultStreamer.start();
        // NOTE(review): the status code is read immediately after start(); the async
        // request may not have completed yet, so this value may not reflect the final
        // HTTP status — confirm intended metrics semantics.
        statusCode = ollamaAsyncResultStreamer.getHttpStatusCode();
        return ollamaAsyncResultStreamer;
    } catch (Exception e) {
        throw new OllamaException(e.getMessage(), e);
    } finally {
        // Metrics are recorded on every path (success and failure).
        MetricsRecorder.record(
                url, model, raw, think, true, null, null, startTime, statusCode, null);
    }
}
|
|
|
|
/**
|
|
* Sends a chat request to a model using an {@link OllamaChatRequest} and sets up streaming response.
|
|
* This can be constructed using an {@link OllamaChatRequest#builder()}.
|
|
*
|
|
* <p>Note: the OllamaChatRequestModel#getStream() property is not implemented.
|
|
*
|
|
* @param request request object to be sent to the server
|
|
* @param tokenHandler callback handler to handle the last token from stream (caution: the
|
|
* previous tokens from stream will not be concatenated)
|
|
* @return {@link OllamaChatResult}
|
|
* @throws OllamaException if the response indicates an error status
|
|
*/
|
|
public OllamaChatResult chat(OllamaChatRequest request, OllamaChatTokenHandler tokenHandler)
|
|
throws OllamaException {
|
|
try {
|
|
OllamaChatEndpointCaller requestCaller =
|
|
new OllamaChatEndpointCaller(host, auth, requestTimeoutSeconds);
|
|
OllamaChatResult result;
|
|
|
|
// only add tools if tools flag is set
|
|
if (request.isUseTools()) {
|
|
// add all registered tools to request
|
|
request.getTools().addAll(toolRegistry.getRegisteredTools());
|
|
}
|
|
|
|
if (tokenHandler != null) {
|
|
request.setStream(true);
|
|
result = requestCaller.call(request, tokenHandler);
|
|
} else {
|
|
result = requestCaller.callSync(request);
|
|
}
|
|
|
|
// check if toolCallIsWanted
|
|
List<OllamaChatToolCalls> toolCalls =
|
|
result.getResponseModel().getMessage().getToolCalls();
|
|
int toolCallTries = 0;
|
|
while (toolCalls != null
|
|
&& !toolCalls.isEmpty()
|
|
&& toolCallTries < maxChatToolCallRetries) {
|
|
for (OllamaChatToolCalls toolCall : toolCalls) {
|
|
String toolName = toolCall.getFunction().getName();
|
|
for (Tools.Tool t : request.getTools()) {
|
|
if (t.getToolSpec().getName().equals(toolName)) {
|
|
ToolFunction toolFunction = t.getToolFunction();
|
|
if (toolFunction == null) {
|
|
throw new ToolInvocationException(
|
|
"Tool function not found: " + toolName);
|
|
}
|
|
LOG.debug(
|
|
"Invoking tool {} with arguments: {}",
|
|
toolCall.getFunction().getName(),
|
|
toolCall.getFunction().getArguments());
|
|
Map<String, Object> arguments = toolCall.getFunction().getArguments();
|
|
Object res = toolFunction.apply(arguments);
|
|
String argumentKeys =
|
|
arguments.keySet().stream()
|
|
.map(Object::toString)
|
|
.collect(Collectors.joining(", "));
|
|
request.getMessages()
|
|
.add(
|
|
new OllamaChatMessage(
|
|
OllamaChatMessageRole.TOOL,
|
|
"[TOOL_RESULTS] "
|
|
+ toolName
|
|
+ "("
|
|
+ argumentKeys
|
|
+ "): "
|
|
+ res
|
|
+ " [/TOOL_RESULTS]"));
|
|
}
|
|
}
|
|
}
|
|
if (tokenHandler != null) {
|
|
result = requestCaller.call(request, tokenHandler);
|
|
} else {
|
|
result = requestCaller.callSync(request);
|
|
}
|
|
toolCalls = result.getResponseModel().getMessage().getToolCalls();
|
|
toolCallTries++;
|
|
}
|
|
return result;
|
|
} catch (InterruptedException e) {
|
|
Thread.currentThread().interrupt();
|
|
throw new OllamaException("Thread was interrupted", e);
|
|
} catch (Exception e) {
|
|
throw new OllamaException(e.getMessage(), e);
|
|
}
|
|
}
|
|
|
|
/**
|
|
* Registers a single tool in the tool registry.
|
|
*
|
|
* @param tool the tool to register. Contains the tool's specification and function.
|
|
*/
|
|
public void registerTool(Tools.Tool tool) {
|
|
toolRegistry.addTool(tool);
|
|
LOG.debug("Registered tool: {}", tool.getToolSpec().getName());
|
|
}
|
|
|
|
/**
|
|
* Registers multiple tools in the tool registry.
|
|
*
|
|
* @param tools a list of {@link Tools.Tool} objects to register. Each tool contains its
|
|
* specification and function.
|
|
*/
|
|
public void registerTools(List<Tools.Tool> tools) {
|
|
toolRegistry.addTools(tools);
|
|
}
|
|
|
|
/**
 * Returns the tools currently held by the tool registry.
 *
 * @return the list of registered {@link Tools.Tool} objects
 */
public List<Tools.Tool> getRegisteredTools() {
    return toolRegistry.getRegisteredTools();
}
|
|
|
|
/**
|
|
* Deregisters all tools from the tool registry. This method removes all registered tools,
|
|
* effectively clearing the registry.
|
|
*/
|
|
public void deregisterTools() {
|
|
toolRegistry.clear();
|
|
LOG.debug("All tools have been deregistered.");
|
|
}
|
|
|
|
/**
|
|
* Registers tools based on the annotations found on the methods of the caller's class and its
|
|
* providers. This method scans the caller's class for the {@link OllamaToolService} annotation
|
|
* and recursively registers annotated tools from all the providers specified in the annotation.
|
|
*
|
|
* @throws OllamaException if the caller's class is not annotated with {@link
|
|
* OllamaToolService} or if reflection-based instantiation or invocation fails
|
|
*/
|
|
public void registerAnnotatedTools() throws OllamaException {
|
|
try {
|
|
Class<?> callerClass = null;
|
|
try {
|
|
callerClass =
|
|
Class.forName(Thread.currentThread().getStackTrace()[2].getClassName());
|
|
} catch (ClassNotFoundException e) {
|
|
throw new OllamaException(e.getMessage(), e);
|
|
}
|
|
|
|
OllamaToolService ollamaToolServiceAnnotation =
|
|
callerClass.getDeclaredAnnotation(OllamaToolService.class);
|
|
if (ollamaToolServiceAnnotation == null) {
|
|
throw new IllegalStateException(
|
|
callerClass + " is not annotated as " + OllamaToolService.class);
|
|
}
|
|
|
|
Class<?>[] providers = ollamaToolServiceAnnotation.providers();
|
|
for (Class<?> provider : providers) {
|
|
registerAnnotatedTools(provider.getDeclaredConstructor().newInstance());
|
|
}
|
|
} catch (InstantiationException
|
|
| NoSuchMethodException
|
|
| IllegalAccessException
|
|
| InvocationTargetException e) {
|
|
throw new OllamaException(e.getMessage());
|
|
}
|
|
}
|
|
|
|
/**
|
|
* Registers tools based on the annotations found on the methods of the provided object.
|
|
* This method scans the methods of the given object and registers tools using the {@link ToolSpec}
|
|
* annotation and associated {@link ToolProperty} annotations. It constructs tool specifications
|
|
* and stores them in a tool registry.
|
|
*
|
|
* @param object the object whose methods are to be inspected for annotated tools
|
|
* @throws RuntimeException if any reflection-based instantiation or invocation fails
|
|
*/
|
|
public void registerAnnotatedTools(Object object) {
|
|
Class<?> objectClass = object.getClass();
|
|
Method[] methods = objectClass.getMethods();
|
|
for (Method m : methods) {
|
|
ToolSpec toolSpec = m.getDeclaredAnnotation(ToolSpec.class);
|
|
if (toolSpec == null) {
|
|
continue;
|
|
}
|
|
String operationName = !toolSpec.name().isBlank() ? toolSpec.name() : m.getName();
|
|
String operationDesc = !toolSpec.desc().isBlank() ? toolSpec.desc() : operationName;
|
|
|
|
final Map<String, Tools.Property> params = new HashMap<String, Tools.Property>() {};
|
|
LinkedHashMap<String, String> methodParams = new LinkedHashMap<>();
|
|
for (Parameter parameter : m.getParameters()) {
|
|
final ToolProperty toolPropertyAnn =
|
|
parameter.getDeclaredAnnotation(ToolProperty.class);
|
|
String propType = parameter.getType().getTypeName();
|
|
if (toolPropertyAnn == null) {
|
|
methodParams.put(parameter.getName(), null);
|
|
continue;
|
|
}
|
|
String propName =
|
|
!toolPropertyAnn.name().isBlank()
|
|
? toolPropertyAnn.name()
|
|
: parameter.getName();
|
|
methodParams.put(propName, propType);
|
|
params.put(
|
|
propName,
|
|
Tools.Property.builder()
|
|
.type(propType)
|
|
.description(toolPropertyAnn.desc())
|
|
.required(toolPropertyAnn.required())
|
|
.build());
|
|
}
|
|
Tools.ToolSpec toolSpecification =
|
|
Tools.ToolSpec.builder()
|
|
.name(operationName)
|
|
.description(operationDesc)
|
|
.parameters(Tools.Parameters.of(params))
|
|
.build();
|
|
ReflectionalToolFunction reflectionalToolFunction =
|
|
new ReflectionalToolFunction(object, m, methodParams);
|
|
toolRegistry.addTool(
|
|
Tools.Tool.builder()
|
|
.toolFunction(reflectionalToolFunction)
|
|
.toolSpec(toolSpecification)
|
|
.build());
|
|
}
|
|
}
|
|
|
|
/**
|
|
* Adds a custom role.
|
|
*
|
|
* @param roleName the name of the custom role to be added
|
|
* @return the newly created OllamaChatMessageRole
|
|
*/
|
|
public OllamaChatMessageRole addCustomRole(String roleName) {
|
|
return OllamaChatMessageRole.newCustomRole(roleName);
|
|
}
|
|
|
|
/**
|
|
* Lists all available roles.
|
|
*
|
|
* @return a list of available OllamaChatMessageRole objects
|
|
*/
|
|
public List<OllamaChatMessageRole> listRoles() {
|
|
return OllamaChatMessageRole.getRoles();
|
|
}
|
|
|
|
/**
|
|
* Retrieves a specific role by name.
|
|
*
|
|
* @param roleName the name of the role to retrieve
|
|
* @return the OllamaChatMessageRole associated with the given name
|
|
* @throws RoleNotFoundException if the role with the specified name does not exist
|
|
*/
|
|
public OllamaChatMessageRole getRole(String roleName) throws RoleNotFoundException {
|
|
return OllamaChatMessageRole.getRole(roleName);
|
|
}
|
|
|
|
// technical private methods //
|
|
|
|
/**
|
|
* Utility method to encode a file into a Base64 encoded string.
|
|
*
|
|
* @param file the file to be encoded into Base64
|
|
* @return a Base64 encoded string representing the contents of the file
|
|
* @throws IOException if an I/O error occurs during reading the file
|
|
*/
|
|
private static String encodeFileToBase64(File file) throws IOException {
|
|
return Base64.getEncoder().encodeToString(Files.readAllBytes(file.toPath()));
|
|
}
|
|
|
|
/**
|
|
* Utility method to encode a byte array into a Base64 encoded string.
|
|
*
|
|
* @param bytes the byte array to be encoded into Base64
|
|
* @return a Base64 encoded string representing the byte array
|
|
*/
|
|
private static String encodeByteArrayToBase64(byte[] bytes) {
|
|
return Base64.getEncoder().encodeToString(bytes);
|
|
}
|
|
|
|
/**
|
|
* Generates a request for the Ollama API and returns the result.
|
|
* This method synchronously calls the Ollama API. If a stream handler is provided,
|
|
* the request will be streamed; otherwise, a regular synchronous request will be made.
|
|
*
|
|
* @param ollamaRequestModel the request model containing necessary parameters for the Ollama API request
|
|
* @param thinkingStreamHandler the stream handler for "thinking" tokens, or null if not used
|
|
* @param responseStreamHandler the stream handler to process streaming responses, or null for non-streaming requests
|
|
* @return the result of the Ollama API request
|
|
* @throws OllamaException if the request fails due to an issue with the Ollama API
|
|
*/
|
|
private OllamaResult generateSyncForOllamaRequestModel(
|
|
OllamaGenerateRequest ollamaRequestModel,
|
|
OllamaGenerateTokenHandler thinkingStreamHandler,
|
|
OllamaGenerateTokenHandler responseStreamHandler)
|
|
throws OllamaException {
|
|
long startTime = System.currentTimeMillis();
|
|
int statusCode = -1;
|
|
Object out = null;
|
|
try {
|
|
OllamaGenerateEndpointCaller requestCaller =
|
|
new OllamaGenerateEndpointCaller(host, auth, requestTimeoutSeconds);
|
|
OllamaResult result;
|
|
if (responseStreamHandler != null) {
|
|
ollamaRequestModel.setStream(true);
|
|
result =
|
|
requestCaller.call(
|
|
ollamaRequestModel, thinkingStreamHandler, responseStreamHandler);
|
|
} else {
|
|
result = requestCaller.callSync(ollamaRequestModel);
|
|
}
|
|
statusCode = result.getHttpStatusCode();
|
|
out = result;
|
|
return result;
|
|
} catch (InterruptedException e) {
|
|
Thread.currentThread().interrupt();
|
|
throw new OllamaException("Thread was interrupted", e);
|
|
} catch (Exception e) {
|
|
throw new OllamaException(e.getMessage(), e);
|
|
} finally {
|
|
MetricsRecorder.record(
|
|
OllamaGenerateEndpointCaller.endpoint,
|
|
ollamaRequestModel.getModel(),
|
|
ollamaRequestModel.isRaw(),
|
|
ollamaRequestModel.getThink(),
|
|
ollamaRequestModel.isStream(),
|
|
ollamaRequestModel.getOptions(),
|
|
ollamaRequestModel.getFormat(),
|
|
startTime,
|
|
statusCode,
|
|
out);
|
|
}
|
|
}
|
|
|
|
/**
|
|
* Get default request builder.
|
|
*
|
|
* @param uri URI to get a HttpRequest.Builder
|
|
* @return HttpRequest.Builder
|
|
*/
|
|
private HttpRequest.Builder getRequestBuilderDefault(URI uri) {
|
|
HttpRequest.Builder requestBuilder =
|
|
HttpRequest.newBuilder(uri)
|
|
.header(
|
|
Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE,
|
|
Constants.HttpConstants.APPLICATION_JSON)
|
|
.timeout(Duration.ofSeconds(requestTimeoutSeconds));
|
|
if (isAuthSet()) {
|
|
requestBuilder.header("Authorization", auth.getAuthHeaderValue());
|
|
}
|
|
return requestBuilder;
|
|
}
|
|
|
|
/**
|
|
* Check if auth param is set.
|
|
*
|
|
* @return true when auth param is set
|
|
*/
|
|
private boolean isAuthSet() {
|
|
return auth != null;
|
|
}
|
|
}
|