Merge remote-tracking branch 'origin/main'

This commit is contained in:
Amith Koujalgi 2023-12-14 17:30:27 +05:30
commit f6af7025a8
6 changed files with 558 additions and 378 deletions

View File

@ -91,7 +91,13 @@ For simplest way to get started, I prefer to use the Ollama docker setup.
Start the Ollama docker container: Start the Ollama docker container:
```shell ```shell
docker run -v ~/ollama:/root/.ollama -p 11434:11434 ollama/ollama docker run -it -v ~/ollama:/root/.ollama -p 11434:11434 ollama/ollama
```
With GPUs
```shell
docker run -it --gpus=all -v ~/ollama:/root/.ollama -p 11434:11434 ollama/ollama
``` ```
Instantiate `OllamaAPI` Instantiate `OllamaAPI`
@ -347,6 +353,7 @@ Find the full `Javadoc` (API specifications) [here](https://amithkoujalgi.github
conversational memory conversational memory
- `stream`: Add support for streaming responses from the model - `stream`: Add support for streaming responses from the model
- [x] Setup logging - [x] Setup logging
- [ ] Use lombok
- [ ] Add test cases - [ ] Add test cases
- [ ] Handle exceptions better (maybe throw more appropriate exceptions) - [ ] Handle exceptions better (maybe throw more appropriate exceptions)

View File

@ -3,9 +3,6 @@ package io.github.amithkoujalgi.ollama4j.core;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.*; import io.github.amithkoujalgi.ollama4j.core.models.*;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils; import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.BufferedReader; import java.io.BufferedReader;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
@ -17,226 +14,298 @@ import java.net.http.HttpRequest;
import java.net.http.HttpResponse; import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.util.List; import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/** /** The base Ollama API class. */
* The base Ollama API class.
*/
@SuppressWarnings("DuplicatedCode") @SuppressWarnings("DuplicatedCode")
public class OllamaAPI { public class OllamaAPI {
private static final Logger logger = LoggerFactory.getLogger(OllamaAPI.class); private static final Logger logger = LoggerFactory.getLogger(OllamaAPI.class);
private final String host; private final String host;
private boolean verbose = false; private boolean verbose = false;
/** /**
* Instantiates the Ollama API. * Instantiates the Ollama API.
* *
* @param host the host address of Ollama server * @param host the host address of Ollama server
*/ */
public OllamaAPI(String host) { public OllamaAPI(String host) {
if (host.endsWith("/")) { if (host.endsWith("/")) {
this.host = host.substring(0, host.length() - 1); this.host = host.substring(0, host.length() - 1);
} else { } else {
this.host = host; this.host = host;
}
} }
}
/** /**
* Set/unset logging of responses * Set/unset logging of responses
* @param verbose true/false *
*/ * @param verbose true/false
public void setVerbose(boolean verbose) { */
this.verbose = verbose; public void setVerbose(boolean verbose) {
this.verbose = verbose;
}
/**
* List available models from Ollama server.
*
* @return the list
*/
public List<Model> listModels()
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
String url = this.host + "/api/tags";
HttpClient httpClient = HttpClient.newHttpClient();
HttpRequest httpRequest =
HttpRequest.newBuilder()
.uri(new URI(url))
.header("Accept", "application/json")
.header("Content-type", "application/json")
.GET()
.build();
HttpResponse<String> response =
httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode();
String responseString = response.body();
if (statusCode == 200) {
return Utils.getObjectMapper()
.readValue(responseString, ListModelsResponse.class)
.getModels();
} else {
throw new OllamaBaseException(statusCode + " - " + responseString);
} }
}
/** /**
* List available models from Ollama server. * Pull a model on the Ollama server from the list of <a
* * href="https://ollama.ai/library">available models</a>.
* @return the list *
*/ * @param model the name of the model
public List<Model> listModels() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { */
String url = this.host + "/api/tags"; public void pullModel(String model)
HttpClient httpClient = HttpClient.newHttpClient(); throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
HttpRequest httpRequest = HttpRequest.newBuilder().uri(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build(); String url = this.host + "/api/pull";
HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString()); String jsonData = String.format("{\"name\": \"%s\"}", model);
int statusCode = response.statusCode(); HttpRequest request =
String responseString = response.body(); HttpRequest.newBuilder()
if (statusCode == 200) { .uri(new URI(url))
return Utils.getObjectMapper().readValue(responseString, ListModelsResponse.class).getModels(); .POST(HttpRequest.BodyPublishers.ofString(jsonData))
} else { .header("Accept", "application/json")
throw new OllamaBaseException(statusCode + " - " + responseString); .header("Content-type", "application/json")
} .build();
} HttpClient client = HttpClient.newHttpClient();
HttpResponse<InputStream> response =
/** client.send(request, HttpResponse.BodyHandlers.ofInputStream());
* Pull a model on the Ollama server from the list of <a href="https://ollama.ai/library">available models</a>. int statusCode = response.statusCode();
* InputStream responseBodyStream = response.body();
* @param model the name of the model String responseString = "";
*/ try (BufferedReader reader =
public void pullModel(String model) throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
String url = this.host + "/api/pull"; String line;
String jsonData = String.format("{\"name\": \"%s\"}", model); while ((line = reader.readLine()) != null) {
HttpRequest request = HttpRequest.newBuilder().uri(new URI(url)).POST(HttpRequest.BodyPublishers.ofString(jsonData)).header("Accept", "application/json").header("Content-type", "application/json").build(); ModelPullResponse modelPullResponse =
HttpClient client = HttpClient.newHttpClient(); Utils.getObjectMapper().readValue(line, ModelPullResponse.class);
HttpResponse<InputStream> response = client.send(request, HttpResponse.BodyHandlers.ofInputStream());
int statusCode = response.statusCode();
InputStream responseBodyStream = response.body();
String responseString = "";
try (BufferedReader reader = new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
String line;
while ((line = reader.readLine()) != null) {
ModelPullResponse modelPullResponse = Utils.getObjectMapper().readValue(line, ModelPullResponse.class);
if (verbose) {
logger.info(modelPullResponse.getStatus());
}
}
}
if (statusCode != 200) {
throw new OllamaBaseException(statusCode + " - " + responseString);
}
}
/**
* Gets model details from the Ollama server.
*
* @param modelName the model
* @return the model details
*/
public ModelDetail getModelDetails(String modelName) throws IOException, OllamaBaseException, InterruptedException {
String url = this.host + "/api/show";
String jsonData = String.format("{\"name\": \"%s\"}", modelName);
HttpRequest request = HttpRequest.newBuilder().uri(URI.create(url)).header("Accept", "application/json").header("Content-type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
HttpClient client = HttpClient.newHttpClient();
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode();
String responseBody = response.body();
if (statusCode == 200) {
return Utils.getObjectMapper().readValue(responseBody, ModelDetail.class);
} else {
throw new OllamaBaseException(statusCode + " - " + responseBody);
}
}
/**
* Create a custom model from a model file.
* Read more about custom model file creation <a href="https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md">here</a>.
*
* @param modelName the name of the custom model to be created.
* @param modelFilePath the path to model file that exists on the Ollama server.
*/
public void createModel(String modelName, String modelFilePath) throws IOException, InterruptedException, OllamaBaseException {
String url = this.host + "/api/create";
String jsonData = String.format("{\"name\": \"%s\", \"path\": \"%s\"}", modelName, modelFilePath);
HttpRequest request = HttpRequest.newBuilder().uri(URI.create(url)).header("Accept", "application/json").header("Content-Type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build();
HttpClient client = HttpClient.newHttpClient();
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode();
String responseString = response.body();
if (statusCode != 200) {
throw new OllamaBaseException(statusCode + " - " + responseString);
}
// FIXME: Ollama API returns HTTP status code 200 for model creation failure cases. Correct this if the issue is fixed in the Ollama API server.
if (responseString.contains("error")) {
throw new OllamaBaseException(responseString);
}
if (verbose) { if (verbose) {
logger.info(responseString); logger.info(modelPullResponse.getStatus());
} }
}
} }
if (statusCode != 200) {
/** throw new OllamaBaseException(statusCode + " - " + responseString);
* Delete a model from Ollama server.
*
* @param name the name of the model to be deleted.
* @param ignoreIfNotPresent - ignore errors if the specified model is not present on Ollama server.
*/
public void deleteModel(String name, boolean ignoreIfNotPresent) throws IOException, InterruptedException, OllamaBaseException {
String url = this.host + "/api/delete";
String jsonData = String.format("{\"name\": \"%s\"}", name);
HttpRequest request = HttpRequest.newBuilder().uri(URI.create(url)).method("DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).header("Accept", "application/json").header("Content-type", "application/json").build();
HttpClient client = HttpClient.newHttpClient();
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode();
String responseBody = response.body();
if (statusCode == 404 && responseBody.contains("model") && responseBody.contains("not found")) {
return;
}
if (statusCode != 200) {
throw new OllamaBaseException(statusCode + " - " + responseBody);
}
} }
}
/** /**
* Generate embeddings for a given text from a model * Gets model details from the Ollama server.
* *
* @param model name of model to generate embeddings from * @param modelName the model
* @param prompt text to generate embeddings for * @return the model details
* @return embeddings */
*/ public ModelDetail getModelDetails(String modelName)
public List<Double> generateEmbeddings(String model, String prompt) throws IOException, InterruptedException, OllamaBaseException { throws IOException, OllamaBaseException, InterruptedException {
String url = this.host + "/api/embeddings"; String url = this.host + "/api/show";
String jsonData = String.format("{\"model\": \"%s\", \"prompt\": \"%s\"}", model, prompt); String jsonData = String.format("{\"name\": \"%s\"}", modelName);
HttpClient httpClient = HttpClient.newHttpClient(); HttpRequest request =
HttpRequest request = HttpRequest.newBuilder().uri(URI.create(url)).header("Accept", "application/json").header("Content-type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build(); HttpRequest.newBuilder()
HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); .uri(URI.create(url))
int statusCode = response.statusCode(); .header("Accept", "application/json")
String responseBody = response.body(); .header("Content-type", "application/json")
if (statusCode == 200) { .POST(HttpRequest.BodyPublishers.ofString(jsonData))
EmbeddingResponse embeddingResponse = Utils.getObjectMapper().readValue(responseBody, EmbeddingResponse.class); .build();
return embeddingResponse.getEmbedding(); HttpClient client = HttpClient.newHttpClient();
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode();
String responseBody = response.body();
if (statusCode == 200) {
return Utils.getObjectMapper().readValue(responseBody, ModelDetail.class);
} else {
throw new OllamaBaseException(statusCode + " - " + responseBody);
}
}
/**
* Create a custom model from a model file. Read more about custom model file creation <a
* href="https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md">here</a>.
*
* @param modelName the name of the custom model to be created.
* @param modelFilePath the path to model file that exists on the Ollama server.
*/
public void createModel(String modelName, String modelFilePath)
throws IOException, InterruptedException, OllamaBaseException {
String url = this.host + "/api/create";
String jsonData =
String.format("{\"name\": \"%s\", \"path\": \"%s\"}", modelName, modelFilePath);
HttpRequest request =
HttpRequest.newBuilder()
.uri(URI.create(url))
.header("Accept", "application/json")
.header("Content-Type", "application/json")
.POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8))
.build();
HttpClient client = HttpClient.newHttpClient();
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode();
String responseString = response.body();
if (statusCode != 200) {
throw new OllamaBaseException(statusCode + " - " + responseString);
}
// FIXME: Ollama API returns HTTP status code 200 for model creation failure cases. Correct this
// if the issue is fixed in the Ollama API server.
if (responseString.contains("error")) {
throw new OllamaBaseException(responseString);
}
if (verbose) {
logger.info(responseString);
}
}
/**
* Delete a model from Ollama server.
*
* @param name the name of the model to be deleted.
* @param ignoreIfNotPresent - ignore errors if the specified model is not present on Ollama
* server.
*/
public void deleteModel(String name, boolean ignoreIfNotPresent)
throws IOException, InterruptedException, OllamaBaseException {
String url = this.host + "/api/delete";
String jsonData = String.format("{\"name\": \"%s\"}", name);
HttpRequest request =
HttpRequest.newBuilder()
.uri(URI.create(url))
.method("DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8))
.header("Accept", "application/json")
.header("Content-type", "application/json")
.build();
HttpClient client = HttpClient.newHttpClient();
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode();
String responseBody = response.body();
if (statusCode == 404 && responseBody.contains("model") && responseBody.contains("not found")) {
return;
}
if (statusCode != 200) {
throw new OllamaBaseException(statusCode + " - " + responseBody);
}
}
/**
* Generate embeddings for a given text from a model
*
* @param model name of model to generate embeddings from
* @param prompt text to generate embeddings for
* @return embeddings
*/
public List<Double> generateEmbeddings(String model, String prompt)
throws IOException, InterruptedException, OllamaBaseException {
String url = this.host + "/api/embeddings";
String jsonData = String.format("{\"model\": \"%s\", \"prompt\": \"%s\"}", model, prompt);
HttpClient httpClient = HttpClient.newHttpClient();
HttpRequest request =
HttpRequest.newBuilder()
.uri(URI.create(url))
.header("Accept", "application/json")
.header("Content-type", "application/json")
.POST(HttpRequest.BodyPublishers.ofString(jsonData))
.build();
HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode();
String responseBody = response.body();
if (statusCode == 200) {
EmbeddingResponse embeddingResponse =
Utils.getObjectMapper().readValue(responseBody, EmbeddingResponse.class);
return embeddingResponse.getEmbedding();
} else {
throw new OllamaBaseException(statusCode + " - " + responseBody);
}
}
/**
* Ask a question to a model running on Ollama server. This is a sync/blocking call.
*
* @param ollamaModelType the ollama model to ask the question to
* @param promptText the prompt/question text
* @return OllamaResult - that includes response text and time taken for response
*/
public OllamaResult ask(String ollamaModelType, String promptText)
throws OllamaBaseException, IOException, InterruptedException {
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(ollamaModelType, promptText);
long startTime = System.currentTimeMillis();
HttpClient httpClient = HttpClient.newHttpClient();
URI uri = URI.create(this.host + "/api/generate");
HttpRequest request =
HttpRequest.newBuilder(uri)
.POST(
HttpRequest.BodyPublishers.ofString(
Utils.getObjectMapper().writeValueAsString(ollamaRequestModel)))
.header("Content-Type", "application/json")
.build();
HttpResponse<InputStream> response =
httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
int statusCode = response.statusCode();
InputStream responseBodyStream = response.body();
StringBuilder responseBuffer = new StringBuilder();
try (BufferedReader reader =
new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
String line;
while ((line = reader.readLine()) != null) {
if (statusCode == 404) {
OllamaErrorResponseModel ollamaResponseModel =
Utils.getObjectMapper().readValue(line, OllamaErrorResponseModel.class);
responseBuffer.append(ollamaResponseModel.getError());
} else { } else {
throw new OllamaBaseException(statusCode + " - " + responseBody); OllamaResponseModel ollamaResponseModel =
Utils.getObjectMapper().readValue(line, OllamaResponseModel.class);
if (!ollamaResponseModel.getDone()) {
responseBuffer.append(ollamaResponseModel.getResponse());
}
} }
}
} }
if (statusCode != 200) {
throw new OllamaBaseException(responseBuffer.toString());
} else {
long endTime = System.currentTimeMillis();
return new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode);
}
}
/** /**
* Ask a question to a model running on Ollama server. This is a sync/blocking call. * Ask a question to a model running on Ollama server and get a callback handle that can be used
* * to check for status and get the response from the model later. This would be an
* @param ollamaModelType the ollama model to ask the question to * async/non-blocking call.
* @param promptText the prompt/question text *
* @return OllamaResult - that includes response text and time taken for response * @param ollamaModelType the ollama model to ask the question to
*/ * @param promptText the prompt/question text
public OllamaResult ask(String ollamaModelType, String promptText) throws OllamaBaseException, IOException, InterruptedException { * @return the ollama async result callback handle
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(ollamaModelType, promptText); */
long startTime = System.currentTimeMillis(); public OllamaAsyncResultCallback askAsync(String ollamaModelType, String promptText) {
HttpClient httpClient = HttpClient.newHttpClient(); OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(ollamaModelType, promptText);
URI uri = URI.create(this.host + "/api/generate"); HttpClient httpClient = HttpClient.newHttpClient();
HttpRequest request = HttpRequest.newBuilder(uri).POST(HttpRequest.BodyPublishers.ofString(Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))).header("Content-Type", "application/json").build(); URI uri = URI.create(this.host + "/api/generate");
HttpResponse<InputStream> response = httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream()); OllamaAsyncResultCallback ollamaAsyncResultCallback =
int statusCode = response.statusCode(); new OllamaAsyncResultCallback(httpClient, uri, ollamaRequestModel);
InputStream responseBodyStream = response.body(); ollamaAsyncResultCallback.start();
StringBuilder responseBuffer = new StringBuilder(); return ollamaAsyncResultCallback;
try (BufferedReader reader = new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { }
String line;
while ((line = reader.readLine()) != null) {
OllamaResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaResponseModel.class);
if (!ollamaResponseModel.getDone()) {
responseBuffer.append(ollamaResponseModel.getResponse());
}
}
}
if (statusCode != 200) {
throw new OllamaBaseException(statusCode + " - " + responseBuffer);
} else {
long endTime = System.currentTimeMillis();
return new OllamaResult(responseBuffer.toString().trim(), endTime - startTime);
}
}
/**
* Ask a question to a model running on Ollama server and get a callback handle that can be used to check for status and get the response from the model later.
* This would be an async/non-blocking call.
*
* @param ollamaModelType the ollama model to ask the question to
* @param promptText the prompt/question text
* @return the ollama async result callback handle
*/
public OllamaAsyncResultCallback askAsync(String ollamaModelType, String promptText) {
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(ollamaModelType, promptText);
HttpClient httpClient = HttpClient.newHttpClient();
URI uri = URI.create(this.host + "/api/generate");
OllamaAsyncResultCallback ollamaAsyncResultCallback = new OllamaAsyncResultCallback(httpClient, uri, ollamaRequestModel);
ollamaAsyncResultCallback.start();
return ollamaAsyncResultCallback;
}
} }

View File

@ -2,7 +2,6 @@ package io.github.amithkoujalgi.ollama4j.core.models;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils; import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
import java.io.BufferedReader; import java.io.BufferedReader;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
@ -17,79 +16,129 @@ import java.util.Queue;
@SuppressWarnings("unused") @SuppressWarnings("unused")
public class OllamaAsyncResultCallback extends Thread { public class OllamaAsyncResultCallback extends Thread {
private final HttpClient client; private final HttpClient client;
private final URI uri; private final URI uri;
private final OllamaRequestModel ollamaRequestModel; private final OllamaRequestModel ollamaRequestModel;
private final Queue<String> queue = new LinkedList<>(); private final Queue<String> queue = new LinkedList<>();
private String result; private String result;
private boolean isDone; private boolean isDone;
private long responseTime = 0; private boolean succeeded;
public OllamaAsyncResultCallback(HttpClient client, URI uri, OllamaRequestModel ollamaRequestModel) { private int httpStatusCode;
this.client = client; private long responseTime = 0;
this.ollamaRequestModel = ollamaRequestModel;
this.uri = uri;
this.isDone = false;
this.result = "";
this.queue.add("");
}
@Override public OllamaAsyncResultCallback(
public void run() { HttpClient client, URI uri, OllamaRequestModel ollamaRequestModel) {
try { this.client = client;
long startTime = System.currentTimeMillis(); this.ollamaRequestModel = ollamaRequestModel;
HttpRequest request = HttpRequest.newBuilder(uri).POST(HttpRequest.BodyPublishers.ofString(Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))).header("Content-Type", "application/json").build(); this.uri = uri;
HttpResponse<InputStream> response = client.send(request, HttpResponse.BodyHandlers.ofInputStream()); this.isDone = false;
int statusCode = response.statusCode(); this.result = "";
this.queue.add("");
}
InputStream responseBodyStream = response.body(); @Override
String responseString = ""; public void run() {
try (BufferedReader reader = new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { try {
String line; long startTime = System.currentTimeMillis();
StringBuilder responseBuffer = new StringBuilder(); HttpRequest request =
while ((line = reader.readLine()) != null) { HttpRequest.newBuilder(uri)
OllamaResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaResponseModel.class); .POST(
queue.add(ollamaResponseModel.getResponse()); HttpRequest.BodyPublishers.ofString(
if (!ollamaResponseModel.getDone()) { Utils.getObjectMapper().writeValueAsString(ollamaRequestModel)))
responseBuffer.append(ollamaResponseModel.getResponse()); .header("Content-Type", "application/json")
} .build();
} HttpResponse<InputStream> response =
reader.close(); client.send(request, HttpResponse.BodyHandlers.ofInputStream());
this.isDone = true; int statusCode = response.statusCode();
this.result = responseBuffer.toString(); this.httpStatusCode = statusCode;
long endTime = System.currentTimeMillis();
responseTime = endTime - startTime; InputStream responseBodyStream = response.body();
try (BufferedReader reader =
new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
String line;
StringBuilder responseBuffer = new StringBuilder();
while ((line = reader.readLine()) != null) {
if (statusCode == 404) {
OllamaErrorResponseModel ollamaResponseModel =
Utils.getObjectMapper().readValue(line, OllamaErrorResponseModel.class);
queue.add(ollamaResponseModel.getError());
responseBuffer.append(ollamaResponseModel.getError());
} else {
OllamaResponseModel ollamaResponseModel =
Utils.getObjectMapper().readValue(line, OllamaResponseModel.class);
queue.add(ollamaResponseModel.getResponse());
if (!ollamaResponseModel.getDone()) {
responseBuffer.append(ollamaResponseModel.getResponse());
} }
if (statusCode != 200) { }
throw new OllamaBaseException(statusCode + " - " + responseString);
}
} catch (IOException | InterruptedException | OllamaBaseException e) {
this.isDone = true;
this.result = "FAILED! " + e.getMessage();
} }
} reader.close();
public boolean isComplete() { this.isDone = true;
return isDone; this.succeeded = true;
this.result = responseBuffer.toString();
long endTime = System.currentTimeMillis();
responseTime = endTime - startTime;
}
if (statusCode != 200) {
throw new OllamaBaseException(this.result);
}
} catch (IOException | InterruptedException | OllamaBaseException e) {
this.isDone = true;
this.succeeded = false;
this.result = "[FAILED] " + e.getMessage();
} }
}
/** /**
* Returns the final response when the execution completes. Does not return intermediate results. * Returns the status of the thread. This does not indicate that the request was successful or a
* @return response text * failure, rather it is just a status flag to indicate if the thread is active or ended.
*/ *
public String getResponse() { * @return boolean - status
return result; */
} public boolean isComplete() {
return isDone;
}
public Queue<String> getStream() { /**
return queue; * Returns the HTTP response status code for the request that was made to Ollama server.
} *
* @return int - the status code for the request
*/
public int getHttpStatusCode() {
return httpStatusCode;
}
/** /**
* Returns the response time in seconds. * Returns the status of the request. Indicates if the request was successful or a failure. If the
* @return response time in seconds * request was a failure, the `getResponse()` method will return the error message.
*/ *
public long getResponseTime() { * @return boolean - status
return responseTime; */
} public boolean isSucceeded() {
return succeeded;
}
/**
* Returns the final response when the execution completes. Does not return intermediate results.
*
* @return String - response text
*/
public String getResponse() {
return result;
}
public Queue<String> getStream() {
return queue;
}
/**
* Returns the response time in milliseconds.
*
* @return long - response time in milliseconds.
*/
public long getResponseTime() {
return responseTime;
}
} }

View File

@ -0,0 +1,18 @@
package io.github.amithkoujalgi.ollama4j.core.models;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.List;
@JsonIgnoreProperties(ignoreUnknown = true)
public class OllamaErrorResponseModel {
private String error;
public String getError() {
return error;
}
public void setError(String error) {
this.error = error;
}
}

View File

@ -1,20 +1,57 @@
package io.github.amithkoujalgi.ollama4j.core.models; package io.github.amithkoujalgi.ollama4j.core.models;
import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper;
import com.fasterxml.jackson.core.JsonProcessingException;
/** The type Ollama result. */
@SuppressWarnings("unused") @SuppressWarnings("unused")
public class OllamaResult { public class OllamaResult {
private String response; private final String response;
private long responseTime = 0;
public OllamaResult(String response, long responseTime) { private int httpStatusCode;
this.response = response;
this.responseTime = responseTime;
}
public String getResponse() { private long responseTime = 0;
return response;
}
public long getResponseTime() { public OllamaResult(String response, long responseTime, int httpStatusCode) {
return responseTime; this.response = response;
this.responseTime = responseTime;
this.httpStatusCode = httpStatusCode;
}
/**
* Get the response text
*
* @return String - response text
*/
public String getResponse() {
return response;
}
/**
* Get the response time in milliseconds.
*
* @return long - response time in milliseconds
*/
public long getResponseTime() {
return responseTime;
}
/**
* Get the response status code.
*
* @return int - response status code
*/
public int getHttpStatusCode() {
return httpStatusCode;
}
@Override
public String toString() {
try {
return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
} }
}
} }

View File

@ -1,121 +1,121 @@
package io.github.amithkoujalgi.ollama4j.unittests; package io.github.amithkoujalgi.ollama4j.unittests;
import static org.mockito.Mockito.*;
import io.github.amithkoujalgi.ollama4j.core.OllamaAPI; import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.ModelDetail; import io.github.amithkoujalgi.ollama4j.core.models.ModelDetail;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaAsyncResultCallback; import io.github.amithkoujalgi.ollama4j.core.models.OllamaAsyncResultCallback;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult; import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
import io.github.amithkoujalgi.ollama4j.core.types.OllamaModelType; import io.github.amithkoujalgi.ollama4j.core.types.OllamaModelType;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import java.io.IOException; import java.io.IOException;
import java.net.URISyntaxException; import java.net.URISyntaxException;
import java.util.ArrayList; import java.util.ArrayList;
import org.junit.jupiter.api.Test;
import static org.mockito.Mockito.*; import org.mockito.Mockito;
public class TestMockedAPIs { public class TestMockedAPIs {
@Test @Test
public void testMockPullModel() { public void testMockPullModel() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2; String model = OllamaModelType.LLAMA2;
try { try {
doNothing().when(ollamaAPI).pullModel(model); doNothing().when(ollamaAPI).pullModel(model);
ollamaAPI.pullModel(model); ollamaAPI.pullModel(model);
verify(ollamaAPI, times(1)).pullModel(model); verify(ollamaAPI, times(1)).pullModel(model);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
}
} }
}
@Test @Test
public void testListModels() { public void testListModels() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
try { try {
when(ollamaAPI.listModels()).thenReturn(new ArrayList<>()); when(ollamaAPI.listModels()).thenReturn(new ArrayList<>());
ollamaAPI.listModels(); ollamaAPI.listModels();
verify(ollamaAPI, times(1)).listModels(); verify(ollamaAPI, times(1)).listModels();
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
}
} }
}
@Test @Test
public void testCreateModel() { public void testCreateModel() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2; String model = OllamaModelType.LLAMA2;
String modelFilePath = "/somemodel"; String modelFilePath = "/somemodel";
try { try {
doNothing().when(ollamaAPI).createModel(model, modelFilePath); doNothing().when(ollamaAPI).createModel(model, modelFilePath);
ollamaAPI.createModel(model, modelFilePath); ollamaAPI.createModel(model, modelFilePath);
verify(ollamaAPI, times(1)).createModel(model, modelFilePath); verify(ollamaAPI, times(1)).createModel(model, modelFilePath);
} catch (IOException | OllamaBaseException | InterruptedException e) { } catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
}
} }
}
@Test @Test
public void testDeleteModel() { public void testDeleteModel() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2; String model = OllamaModelType.LLAMA2;
try { try {
doNothing().when(ollamaAPI).deleteModel(model, true); doNothing().when(ollamaAPI).deleteModel(model, true);
ollamaAPI.deleteModel(model, true); ollamaAPI.deleteModel(model, true);
verify(ollamaAPI, times(1)).deleteModel(model, true); verify(ollamaAPI, times(1)).deleteModel(model, true);
} catch (IOException | OllamaBaseException | InterruptedException e) { } catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
}
} }
}
@Test @Test
public void testGetModelDetails() { public void testGetModelDetails() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2; String model = OllamaModelType.LLAMA2;
try { try {
when(ollamaAPI.getModelDetails(model)).thenReturn(new ModelDetail()); when(ollamaAPI.getModelDetails(model)).thenReturn(new ModelDetail());
ollamaAPI.getModelDetails(model); ollamaAPI.getModelDetails(model);
verify(ollamaAPI, times(1)).getModelDetails(model); verify(ollamaAPI, times(1)).getModelDetails(model);
} catch (IOException | OllamaBaseException | InterruptedException e) { } catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
}
} }
}
@Test @Test
public void testGenerateEmbeddings() { public void testGenerateEmbeddings() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2; String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text"; String prompt = "some prompt text";
try { try {
when(ollamaAPI.generateEmbeddings(model, prompt)).thenReturn(new ArrayList<>()); when(ollamaAPI.generateEmbeddings(model, prompt)).thenReturn(new ArrayList<>());
ollamaAPI.generateEmbeddings(model, prompt); ollamaAPI.generateEmbeddings(model, prompt);
verify(ollamaAPI, times(1)).generateEmbeddings(model, prompt); verify(ollamaAPI, times(1)).generateEmbeddings(model, prompt);
} catch (IOException | OllamaBaseException | InterruptedException e) { } catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
}
} }
}
@Test @Test
public void testAsk() { public void testAsk() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2; String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text"; String prompt = "some prompt text";
try { try {
when(ollamaAPI.ask(model, prompt)).thenReturn(new OllamaResult("", 0)); when(ollamaAPI.ask(model, prompt)).thenReturn(new OllamaResult("", 0, 200));
ollamaAPI.ask(model, prompt); ollamaAPI.ask(model, prompt);
verify(ollamaAPI, times(1)).ask(model, prompt); verify(ollamaAPI, times(1)).ask(model, prompt);
} catch (IOException | OllamaBaseException | InterruptedException e) { } catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
}
} }
}
@Test @Test
public void testAskAsync() { public void testAskAsync() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2; String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text"; String prompt = "some prompt text";
when(ollamaAPI.askAsync(model, prompt)).thenReturn(new OllamaAsyncResultCallback(null, null, null)); when(ollamaAPI.askAsync(model, prompt))
ollamaAPI.askAsync(model, prompt); .thenReturn(new OllamaAsyncResultCallback(null, null, null));
verify(ollamaAPI, times(1)).askAsync(model, prompt); ollamaAPI.askAsync(model, prompt);
} verify(ollamaAPI, times(1)).askAsync(model, prompt);
}
} }