From 6f7f3496190dc45836e761cd15263892bc981dda Mon Sep 17 00:00:00 2001 From: Amith Koujalgi Date: Thu, 14 Dec 2023 15:22:13 +0530 Subject: [PATCH 1/7] Update README.md --- README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.md b/README.md index 5a6cfd3..bfc8793 100644 --- a/README.md +++ b/README.md @@ -94,6 +94,11 @@ Start the Ollama docker container: docker run -v ~/ollama:/root/.ollama -p 11434:11434 ollama/ollama ``` +With GPUs +``` +docker run -d --gpus=all -v ~/ollama:/root/.ollama -p 11434:11434 ollama/ollama +``` + Instantiate `OllamaAPI` ```java From 417423005a3ed3c9d85176c3c0dbd9616a7bd154 Mon Sep 17 00:00:00 2001 From: Amith Koujalgi Date: Thu, 14 Dec 2023 15:35:44 +0530 Subject: [PATCH 2/7] updated readme --- README.md | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index bfc8793..4241e89 100644 --- a/README.md +++ b/README.md @@ -91,12 +91,20 @@ For simplest way to get started, I prefer to use the Ollama docker setup. Start the Ollama docker container: ```shell -docker run -v ~/ollama:/root/.ollama -p 11434:11434 ollama/ollama +docker run -it \ + -v ~/ollama:/root/.ollama \ + -p 11434:11434 \ + ollama/ollama ``` With GPUs -``` -docker run -d --gpus=all -v ~/ollama:/root/.ollama -p 11434:11434 ollama/ollama + +```shell +docker run -it \ + --gpus=all \ + -v ~/ollama:/root/.ollama \ + -p 11434:11434 + ollama/ollama ``` Instantiate `OllamaAPI` From 792222c162178e8ae2f9514d592a720a1afab116 Mon Sep 17 00:00:00 2001 From: Amith Koujalgi Date: Thu, 14 Dec 2023 15:37:12 +0530 Subject: [PATCH 3/7] updated readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 4241e89..e44402d 100644 --- a/README.md +++ b/README.md @@ -103,7 +103,7 @@ With GPUs docker run -it \ --gpus=all \ -v ~/ollama:/root/.ollama \ - -p 11434:11434 + -p 11434:11434 \ ollama/ollama ``` From a3c59c32ef02993e7ff9e27e0b62363e980b0a2d Mon Sep 17 00:00:00 2001 From: Amith Koujalgi Date: Thu, 14 Dec 2023 15:38:23 +0530 Subject: [PATCH 4/7] updated readme --- README.md | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index e44402d..14ddb02 100644 --- a/README.md +++ b/README.md @@ -91,20 +91,13 @@ For simplest way to get started, I prefer to use the Ollama docker setup. Start the Ollama docker container: ```shell -docker run -it \ - -v ~/ollama:/root/.ollama \ - -p 11434:11434 \ - ollama/ollama +docker run -it -v ~/ollama:/root/.ollama -p 11434:11434 ollama/ollama ``` With GPUs ```shell -docker run -it \ - --gpus=all \ - -v ~/ollama:/root/.ollama \ - -p 11434:11434 \ - ollama/ollama +docker run -it --gpus=all -v ~/ollama:/root/.ollama -p 11434:11434 ollama/ollama ``` Instantiate `OllamaAPI` From f67f3b9eb558cf44a37b5706525b62449b73f425 Mon Sep 17 00:00:00 2001 From: Amith Koujalgi Date: Thu, 14 Dec 2023 16:40:15 +0530 Subject: [PATCH 5/7] updated readme --- README.md | 1 + .../models/OllamaAsyncResultCallback.java | 181 +++++++++++------- .../core/models/OllamaErrorResponseModel.java | 18 ++ .../ollama4j/core/models/OllamaResult.java | 47 +++-- 4 files changed, 170 insertions(+), 77 deletions(-) create mode 100644 src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaErrorResponseModel.java diff --git a/README.md b/README.md index 14ddb02..0a5ac9b 100644 --- a/README.md +++ b/README.md @@ -353,6 +353,7 @@ Find the full `Javadoc` (API specifications) [here](https://amithkoujalgi.github conversational memory - `stream`: Add support for streaming responses from the model - [x] Setup logging +- [ ] Use lombok - [ ] Add test cases - [ ] Handle exceptions better (maybe throw more appropriate exceptions) diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java index 0d92d5c..4683b61 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java @@ -2,7 +2,6 @@ package io.github.amithkoujalgi.ollama4j.core.models; import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; import io.github.amithkoujalgi.ollama4j.core.utils.Utils; - import java.io.BufferedReader; import java.io.IOException; import java.io.InputStream; @@ -17,79 +16,129 @@ import java.util.Queue; @SuppressWarnings("unused") public class OllamaAsyncResultCallback extends Thread { - private final HttpClient client; - private final URI uri; - private final OllamaRequestModel ollamaRequestModel; - private final Queue queue = new LinkedList<>(); - private String result; - private boolean isDone; - private long responseTime = 0; + private final HttpClient client; + private final URI uri; + private final OllamaRequestModel ollamaRequestModel; + private final Queue queue = new LinkedList<>(); + private String result; + private boolean isDone; + private boolean succeeded; - public OllamaAsyncResultCallback(HttpClient client, URI uri, OllamaRequestModel ollamaRequestModel) { - this.client = client; - this.ollamaRequestModel = ollamaRequestModel; - this.uri = uri; - this.isDone = false; - this.result = ""; - this.queue.add(""); - } + private int httpStatusCode; + private long responseTime = 0; - @Override - public void run() { - try { - long startTime = System.currentTimeMillis(); - HttpRequest request = HttpRequest.newBuilder(uri).POST(HttpRequest.BodyPublishers.ofString(Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))).header("Content-Type", "application/json").build(); - HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofInputStream()); - int statusCode = response.statusCode(); + public OllamaAsyncResultCallback( + HttpClient client, URI uri, OllamaRequestModel ollamaRequestModel) { + this.client = client; + this.ollamaRequestModel = ollamaRequestModel; + this.uri = uri; + this.isDone = false; + this.result = ""; + this.queue.add(""); + } - InputStream responseBodyStream = response.body(); - String responseString = ""; - try (BufferedReader reader = new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { - String line; - StringBuilder responseBuffer = new StringBuilder(); - while ((line = reader.readLine()) != null) { - OllamaResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaResponseModel.class); - queue.add(ollamaResponseModel.getResponse()); - if (!ollamaResponseModel.getDone()) { - responseBuffer.append(ollamaResponseModel.getResponse()); - } - } - reader.close(); - this.isDone = true; - this.result = responseBuffer.toString(); - long endTime = System.currentTimeMillis(); - responseTime = endTime - startTime; + @Override + public void run() { + try { + long startTime = System.currentTimeMillis(); + HttpRequest request = + HttpRequest.newBuilder(uri) + .POST( + HttpRequest.BodyPublishers.ofString( + Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))) + .header("Content-Type", "application/json") + .build(); + HttpResponse response = + client.send(request, HttpResponse.BodyHandlers.ofInputStream()); + int statusCode = response.statusCode(); + this.httpStatusCode = statusCode; + + InputStream responseBodyStream = response.body(); + try (BufferedReader reader = + new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { + String line; + StringBuilder responseBuffer = new StringBuilder(); + while ((line = reader.readLine()) != null) { + if (statusCode == 404) { + OllamaErrorResponseModel ollamaResponseModel = + Utils.getObjectMapper().readValue(line, OllamaErrorResponseModel.class); + queue.add(ollamaResponseModel.getError()); + responseBuffer.append(ollamaResponseModel.getError()); + } else { + OllamaResponseModel ollamaResponseModel = + Utils.getObjectMapper().readValue(line, OllamaResponseModel.class); + queue.add(ollamaResponseModel.getResponse()); + if (!ollamaResponseModel.getDone()) { + responseBuffer.append(ollamaResponseModel.getResponse()); } - if (statusCode != 200) { - throw new OllamaBaseException(statusCode + " - " + responseString); - } - } catch (IOException | InterruptedException | OllamaBaseException e) { - this.isDone = true; - this.result = "FAILED! " + e.getMessage(); + } } - } + reader.close(); - public boolean isComplete() { - return isDone; + this.isDone = true; + this.succeeded = true; + this.result = responseBuffer.toString(); + long endTime = System.currentTimeMillis(); + responseTime = endTime - startTime; + } + if (statusCode != 200) { + throw new OllamaBaseException(this.result); + } + } catch (IOException | InterruptedException | OllamaBaseException e) { + this.isDone = true; + this.succeeded = false; + this.result = "[FAILED] " + e.getMessage(); } + } - /** - * Returns the final response when the execution completes. Does not return intermediate results. - * @return response text - */ - public String getResponse() { - return result; - } + /** + * Returns the status of the thread. This does not indicate that the request was successful or a + * failure, rather it is just a status flag to indicate if the thread is active or ended. + * + * @return boolean - status + */ + public boolean isComplete() { + return isDone; + } - public Queue getStream() { - return queue; - } + /** + * Returns the HTTP response status code for the request that was made to Ollama server. + * + * @return int - the status code for the request + */ + public int getHttpStatusCode() { + return httpStatusCode; + } - /** - * Returns the response time in seconds. - * @return response time in seconds - */ - public long getResponseTime() { - return responseTime; - } + /** + * Returns the status of the request. Indicates if the request was successful or a failure. If the + * request was a failure, the `getResponse()` method will return the error message. + * + * @return boolean - status + */ + public boolean isSucceeded() { + return succeeded; + } + + /** + * Returns the final response when the execution completes. Does not return intermediate results. + * + * @return String - response text + */ + public String getResponse() { + return result; + } + + public Queue getStream() { + return queue; + } + + /** + * Returns the response time in milliseconds. + * + * @return long - response time in milliseconds. + */ + public long getResponseTime() { + return responseTime; + } } diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaErrorResponseModel.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaErrorResponseModel.java new file mode 100644 index 0000000..c5e829a --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaErrorResponseModel.java @@ -0,0 +1,18 @@ +package io.github.amithkoujalgi.ollama4j.core.models; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.List; + +@JsonIgnoreProperties(ignoreUnknown = true) +public class OllamaErrorResponseModel { + private String error; + + public String getError() { + return error; + } + + public void setError(String error) { + this.error = error; + } +} diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaResult.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaResult.java index ac5f67b..9bdb246 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaResult.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaResult.java @@ -1,20 +1,45 @@ package io.github.amithkoujalgi.ollama4j.core.models; +import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper; + +import com.fasterxml.jackson.core.JsonProcessingException; + +/** The type Ollama result. */ @SuppressWarnings("unused") public class OllamaResult { - private String response; - private long responseTime = 0; + private final String response; - public OllamaResult(String response, long responseTime) { - this.response = response; - this.responseTime = responseTime; - } + private long responseTime = 0; - public String getResponse() { - return response; - } + public OllamaResult(String response, long responseTime) { + this.response = response; + this.responseTime = responseTime; + } - public long getResponseTime() { - return responseTime; + /** + * Get the response text + * + * @return String - response text + */ + public String getResponse() { + return response; + } + + /** + * Get the response time in milliseconds. + * + * @return long - response time in milliseconds + */ + public long getResponseTime() { + return responseTime; + } + + @Override + public String toString() { + try { + return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); } + } } From 4e4a5d2996780977490e3d4d49fc3a01f222484b Mon Sep 17 00:00:00 2001 From: Amith Koujalgi Date: Thu, 14 Dec 2023 16:45:12 +0530 Subject: [PATCH 6/7] Fixes to ask API --- .../ollama4j/core/OllamaAPI.java | 483 ++++++++++-------- .../ollama4j/core/models/OllamaResult.java | 14 +- 2 files changed, 289 insertions(+), 208 deletions(-) diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java index 3e3e5db..8a9b10e 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java @@ -3,9 +3,6 @@ package io.github.amithkoujalgi.ollama4j.core; import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; import io.github.amithkoujalgi.ollama4j.core.models.*; import io.github.amithkoujalgi.ollama4j.core.utils.Utils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import java.io.BufferedReader; import java.io.IOException; import java.io.InputStream; @@ -17,226 +14,298 @@ import java.net.http.HttpRequest; import java.net.http.HttpResponse; import java.nio.charset.StandardCharsets; import java.util.List; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; -/** - * The base Ollama API class. - */ +/** The base Ollama API class. */ @SuppressWarnings("DuplicatedCode") public class OllamaAPI { - private static final Logger logger = LoggerFactory.getLogger(OllamaAPI.class); - private final String host; - private boolean verbose = false; + private static final Logger logger = LoggerFactory.getLogger(OllamaAPI.class); + private final String host; + private boolean verbose = false; - /** - * Instantiates the Ollama API. - * - * @param host the host address of Ollama server - */ - public OllamaAPI(String host) { - if (host.endsWith("/")) { - this.host = host.substring(0, host.length() - 1); - } else { - this.host = host; - } + /** + * Instantiates the Ollama API. + * + * @param host the host address of Ollama server + */ + public OllamaAPI(String host) { + if (host.endsWith("/")) { + this.host = host.substring(0, host.length() - 1); + } else { + this.host = host; } + } - /** - * Set/unset logging of responses - * @param verbose true/false - */ - public void setVerbose(boolean verbose) { - this.verbose = verbose; + /** + * Set/unset logging of responses + * + * @param verbose true/false + */ + public void setVerbose(boolean verbose) { + this.verbose = verbose; + } + + /** + * List available models from Ollama server. + * + * @return the list + */ + public List listModels() + throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { + String url = this.host + "/api/tags"; + HttpClient httpClient = HttpClient.newHttpClient(); + HttpRequest httpRequest = + HttpRequest.newBuilder() + .uri(new URI(url)) + .header("Accept", "application/json") + .header("Content-type", "application/json") + .GET() + .build(); + HttpResponse response = + httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString()); + int statusCode = response.statusCode(); + String responseString = response.body(); + if (statusCode == 200) { + return Utils.getObjectMapper() + .readValue(responseString, ListModelsResponse.class) + .getModels(); + } else { + throw new OllamaBaseException(statusCode + " - " + responseString); } + } - /** - * List available models from Ollama server. - * - * @return the list - */ - public List listModels() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { - String url = this.host + "/api/tags"; - HttpClient httpClient = HttpClient.newHttpClient(); - HttpRequest httpRequest = HttpRequest.newBuilder().uri(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build(); - HttpResponse response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString()); - int statusCode = response.statusCode(); - String responseString = response.body(); - if (statusCode == 200) { - return Utils.getObjectMapper().readValue(responseString, ListModelsResponse.class).getModels(); - } else { - throw new OllamaBaseException(statusCode + " - " + responseString); - } - } - - /** - * Pull a model on the Ollama server from the list of available models. - * - * @param model the name of the model - */ - public void pullModel(String model) throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { - String url = this.host + "/api/pull"; - String jsonData = String.format("{\"name\": \"%s\"}", model); - HttpRequest request = HttpRequest.newBuilder().uri(new URI(url)).POST(HttpRequest.BodyPublishers.ofString(jsonData)).header("Accept", "application/json").header("Content-type", "application/json").build(); - HttpClient client = HttpClient.newHttpClient(); - HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofInputStream()); - int statusCode = response.statusCode(); - InputStream responseBodyStream = response.body(); - String responseString = ""; - try (BufferedReader reader = new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { - String line; - while ((line = reader.readLine()) != null) { - ModelPullResponse modelPullResponse = Utils.getObjectMapper().readValue(line, ModelPullResponse.class); - if (verbose) { - logger.info(modelPullResponse.getStatus()); - } - } - } - if (statusCode != 200) { - throw new OllamaBaseException(statusCode + " - " + responseString); - } - } - - /** - * Gets model details from the Ollama server. - * - * @param modelName the model - * @return the model details - */ - public ModelDetail getModelDetails(String modelName) throws IOException, OllamaBaseException, InterruptedException { - String url = this.host + "/api/show"; - String jsonData = String.format("{\"name\": \"%s\"}", modelName); - HttpRequest request = HttpRequest.newBuilder().uri(URI.create(url)).header("Accept", "application/json").header("Content-type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build(); - HttpClient client = HttpClient.newHttpClient(); - HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); - int statusCode = response.statusCode(); - String responseBody = response.body(); - if (statusCode == 200) { - return Utils.getObjectMapper().readValue(responseBody, ModelDetail.class); - } else { - throw new OllamaBaseException(statusCode + " - " + responseBody); - } - } - - /** - * Create a custom model from a model file. - * Read more about custom model file creation here. - * - * @param modelName the name of the custom model to be created. - * @param modelFilePath the path to model file that exists on the Ollama server. - */ - public void createModel(String modelName, String modelFilePath) throws IOException, InterruptedException, OllamaBaseException { - String url = this.host + "/api/create"; - String jsonData = String.format("{\"name\": \"%s\", \"path\": \"%s\"}", modelName, modelFilePath); - HttpRequest request = HttpRequest.newBuilder().uri(URI.create(url)).header("Accept", "application/json").header("Content-Type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build(); - HttpClient client = HttpClient.newHttpClient(); - HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); - int statusCode = response.statusCode(); - String responseString = response.body(); - if (statusCode != 200) { - throw new OllamaBaseException(statusCode + " - " + responseString); - } - // FIXME: Ollama API returns HTTP status code 200 for model creation failure cases. Correct this if the issue is fixed in the Ollama API server. - if (responseString.contains("error")) { - throw new OllamaBaseException(responseString); - } + /** + * Pull a model on the Ollama server from the list of available models. + * + * @param model the name of the model + */ + public void pullModel(String model) + throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + String url = this.host + "/api/pull"; + String jsonData = String.format("{\"name\": \"%s\"}", model); + HttpRequest request = + HttpRequest.newBuilder() + .uri(new URI(url)) + .POST(HttpRequest.BodyPublishers.ofString(jsonData)) + .header("Accept", "application/json") + .header("Content-type", "application/json") + .build(); + HttpClient client = HttpClient.newHttpClient(); + HttpResponse response = + client.send(request, HttpResponse.BodyHandlers.ofInputStream()); + int statusCode = response.statusCode(); + InputStream responseBodyStream = response.body(); + String responseString = ""; + try (BufferedReader reader = + new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { + String line; + while ((line = reader.readLine()) != null) { + ModelPullResponse modelPullResponse = + Utils.getObjectMapper().readValue(line, ModelPullResponse.class); if (verbose) { - logger.info(responseString); + logger.info(modelPullResponse.getStatus()); } + } } - - /** - * Delete a model from Ollama server. - * - * @param name the name of the model to be deleted. - * @param ignoreIfNotPresent - ignore errors if the specified model is not present on Ollama server. - */ - public void deleteModel(String name, boolean ignoreIfNotPresent) throws IOException, InterruptedException, OllamaBaseException { - String url = this.host + "/api/delete"; - String jsonData = String.format("{\"name\": \"%s\"}", name); - HttpRequest request = HttpRequest.newBuilder().uri(URI.create(url)).method("DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).header("Accept", "application/json").header("Content-type", "application/json").build(); - HttpClient client = HttpClient.newHttpClient(); - HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); - int statusCode = response.statusCode(); - String responseBody = response.body(); - if (statusCode == 404 && responseBody.contains("model") && responseBody.contains("not found")) { - return; - } - if (statusCode != 200) { - throw new OllamaBaseException(statusCode + " - " + responseBody); - } + if (statusCode != 200) { + throw new OllamaBaseException(statusCode + " - " + responseString); } + } - /** - * Generate embeddings for a given text from a model - * - * @param model name of model to generate embeddings from - * @param prompt text to generate embeddings for - * @return embeddings - */ - public List generateEmbeddings(String model, String prompt) throws IOException, InterruptedException, OllamaBaseException { - String url = this.host + "/api/embeddings"; - String jsonData = String.format("{\"model\": \"%s\", \"prompt\": \"%s\"}", model, prompt); - HttpClient httpClient = HttpClient.newHttpClient(); - HttpRequest request = HttpRequest.newBuilder().uri(URI.create(url)).header("Accept", "application/json").header("Content-type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build(); - HttpResponse response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); - int statusCode = response.statusCode(); - String responseBody = response.body(); - if (statusCode == 200) { - EmbeddingResponse embeddingResponse = Utils.getObjectMapper().readValue(responseBody, EmbeddingResponse.class); - return embeddingResponse.getEmbedding(); + /** + * Gets model details from the Ollama server. + * + * @param modelName the model + * @return the model details + */ + public ModelDetail getModelDetails(String modelName) + throws IOException, OllamaBaseException, InterruptedException { + String url = this.host + "/api/show"; + String jsonData = String.format("{\"name\": \"%s\"}", modelName); + HttpRequest request = + HttpRequest.newBuilder() + .uri(URI.create(url)) + .header("Accept", "application/json") + .header("Content-type", "application/json") + .POST(HttpRequest.BodyPublishers.ofString(jsonData)) + .build(); + HttpClient client = HttpClient.newHttpClient(); + HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); + int statusCode = response.statusCode(); + String responseBody = response.body(); + if (statusCode == 200) { + return Utils.getObjectMapper().readValue(responseBody, ModelDetail.class); + } else { + throw new OllamaBaseException(statusCode + " - " + responseBody); + } + } + + /** + * Create a custom model from a model file. Read more about custom model file creation here. + * + * @param modelName the name of the custom model to be created. + * @param modelFilePath the path to model file that exists on the Ollama server. + */ + public void createModel(String modelName, String modelFilePath) + throws IOException, InterruptedException, OllamaBaseException { + String url = this.host + "/api/create"; + String jsonData = + String.format("{\"name\": \"%s\", \"path\": \"%s\"}", modelName, modelFilePath); + HttpRequest request = + HttpRequest.newBuilder() + .uri(URI.create(url)) + .header("Accept", "application/json") + .header("Content-Type", "application/json") + .POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)) + .build(); + HttpClient client = HttpClient.newHttpClient(); + HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); + int statusCode = response.statusCode(); + String responseString = response.body(); + if (statusCode != 200) { + throw new OllamaBaseException(statusCode + " - " + responseString); + } + // FIXME: Ollama API returns HTTP status code 200 for model creation failure cases. Correct this + // if the issue is fixed in the Ollama API server. + if (responseString.contains("error")) { + throw new OllamaBaseException(responseString); + } + if (verbose) { + logger.info(responseString); + } + } + + /** + * Delete a model from Ollama server. + * + * @param name the name of the model to be deleted. + * @param ignoreIfNotPresent - ignore errors if the specified model is not present on Ollama + * server. + */ + public void deleteModel(String name, boolean ignoreIfNotPresent) + throws IOException, InterruptedException, OllamaBaseException { + String url = this.host + "/api/delete"; + String jsonData = String.format("{\"name\": \"%s\"}", name); + HttpRequest request = + HttpRequest.newBuilder() + .uri(URI.create(url)) + .method("DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)) + .header("Accept", "application/json") + .header("Content-type", "application/json") + .build(); + HttpClient client = HttpClient.newHttpClient(); + HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); + int statusCode = response.statusCode(); + String responseBody = response.body(); + if (statusCode == 404 && responseBody.contains("model") && responseBody.contains("not found")) { + return; + } + if (statusCode != 200) { + throw new OllamaBaseException(statusCode + " - " + responseBody); + } + } + + /** + * Generate embeddings for a given text from a model + * + * @param model name of model to generate embeddings from + * @param prompt text to generate embeddings for + * @return embeddings + */ + public List generateEmbeddings(String model, String prompt) + throws IOException, InterruptedException, OllamaBaseException { + String url = this.host + "/api/embeddings"; + String jsonData = String.format("{\"model\": \"%s\", \"prompt\": \"%s\"}", model, prompt); + HttpClient httpClient = HttpClient.newHttpClient(); + HttpRequest request = + HttpRequest.newBuilder() + .uri(URI.create(url)) + .header("Accept", "application/json") + .header("Content-type", "application/json") + .POST(HttpRequest.BodyPublishers.ofString(jsonData)) + .build(); + HttpResponse response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); + int statusCode = response.statusCode(); + String responseBody = response.body(); + if (statusCode == 200) { + EmbeddingResponse embeddingResponse = + Utils.getObjectMapper().readValue(responseBody, EmbeddingResponse.class); + return embeddingResponse.getEmbedding(); + } else { + throw new OllamaBaseException(statusCode + " - " + responseBody); + } + } + + /** + * Ask a question to a model running on Ollama server. This is a sync/blocking call. + * + * @param ollamaModelType the ollama model to ask the question to + * @param promptText the prompt/question text + * @return OllamaResult - that includes response text and time taken for response + */ + public OllamaResult ask(String ollamaModelType, String promptText) + throws OllamaBaseException, IOException, InterruptedException { + OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(ollamaModelType, promptText); + long startTime = System.currentTimeMillis(); + HttpClient httpClient = HttpClient.newHttpClient(); + URI uri = URI.create(this.host + "/api/generate"); + HttpRequest request = + HttpRequest.newBuilder(uri) + .POST( + HttpRequest.BodyPublishers.ofString( + Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))) + .header("Content-Type", "application/json") + .build(); + HttpResponse response = + httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream()); + int statusCode = response.statusCode(); + InputStream responseBodyStream = response.body(); + StringBuilder responseBuffer = new StringBuilder(); + try (BufferedReader reader = + new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { + String line; + while ((line = reader.readLine()) != null) { + if (statusCode == 404) { + OllamaErrorResponseModel ollamaResponseModel = + Utils.getObjectMapper().readValue(line, OllamaErrorResponseModel.class); + responseBuffer.append(ollamaResponseModel.getError()); } else { - throw new OllamaBaseException(statusCode + " - " + responseBody); + OllamaResponseModel ollamaResponseModel = + Utils.getObjectMapper().readValue(line, OllamaResponseModel.class); + if (!ollamaResponseModel.getDone()) { + responseBuffer.append(ollamaResponseModel.getResponse()); + } } + } } + if (statusCode != 200) { + throw new OllamaBaseException(responseBuffer.toString()); + } else { + long endTime = System.currentTimeMillis(); + return new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode); + } + } - /** - * Ask a question to a model running on Ollama server. This is a sync/blocking call. - * - * @param ollamaModelType the ollama model to ask the question to - * @param promptText the prompt/question text - * @return OllamaResult - that includes response text and time taken for response - */ - public OllamaResult ask(String ollamaModelType, String promptText) throws OllamaBaseException, IOException, InterruptedException { - OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(ollamaModelType, promptText); - long startTime = System.currentTimeMillis(); - HttpClient httpClient = HttpClient.newHttpClient(); - URI uri = URI.create(this.host + "/api/generate"); - HttpRequest request = HttpRequest.newBuilder(uri).POST(HttpRequest.BodyPublishers.ofString(Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))).header("Content-Type", "application/json").build(); - HttpResponse response = httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream()); - int statusCode = response.statusCode(); - InputStream responseBodyStream = response.body(); - StringBuilder responseBuffer = new StringBuilder(); - try (BufferedReader reader = new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { - String line; - while ((line = reader.readLine()) != null) { - OllamaResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaResponseModel.class); - if (!ollamaResponseModel.getDone()) { - responseBuffer.append(ollamaResponseModel.getResponse()); - } - } - } - if (statusCode != 200) { - throw new OllamaBaseException(statusCode + " - " + responseBuffer); - } else { - long endTime = System.currentTimeMillis(); - return new OllamaResult(responseBuffer.toString().trim(), endTime - startTime); - } - } - - /** - * Ask a question to a model running on Ollama server and get a callback handle that can be used to check for status and get the response from the model later. - * This would be an async/non-blocking call. - * - * @param ollamaModelType the ollama model to ask the question to - * @param promptText the prompt/question text - * @return the ollama async result callback handle - */ - public OllamaAsyncResultCallback askAsync(String ollamaModelType, String promptText) { - OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(ollamaModelType, promptText); - HttpClient httpClient = HttpClient.newHttpClient(); - URI uri = URI.create(this.host + "/api/generate"); - OllamaAsyncResultCallback ollamaAsyncResultCallback = new OllamaAsyncResultCallback(httpClient, uri, ollamaRequestModel); - ollamaAsyncResultCallback.start(); - return ollamaAsyncResultCallback; - } + /** + * Ask a question to a model running on Ollama server and get a callback handle that can be used + * to check for status and get the response from the model later. This would be an + * async/non-blocking call. + * + * @param ollamaModelType the ollama model to ask the question to + * @param promptText the prompt/question text + * @return the ollama async result callback handle + */ + public OllamaAsyncResultCallback askAsync(String ollamaModelType, String promptText) { + OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(ollamaModelType, promptText); + HttpClient httpClient = HttpClient.newHttpClient(); + URI uri = URI.create(this.host + "/api/generate"); + OllamaAsyncResultCallback ollamaAsyncResultCallback = + new OllamaAsyncResultCallback(httpClient, uri, ollamaRequestModel); + ollamaAsyncResultCallback.start(); + return ollamaAsyncResultCallback; + } } diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaResult.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaResult.java index 9bdb246..59d5630 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaResult.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaResult.java @@ -9,11 +9,14 @@ import com.fasterxml.jackson.core.JsonProcessingException; public class OllamaResult { private final String response; + private int httpStatusCode; + private long responseTime = 0; - public OllamaResult(String response, long responseTime) { + public OllamaResult(String response, long responseTime, int httpStatusCode) { this.response = response; this.responseTime = responseTime; + this.httpStatusCode = httpStatusCode; } /** @@ -34,6 +37,15 @@ public class OllamaResult { return responseTime; } + /** + * Get the response status code. + * + * @return int - response status code + */ + public int getHttpStatusCode() { + return httpStatusCode; + } + @Override public String toString() { try { From d52427fb6866ec9242f51e6dc4390910f52901de Mon Sep 17 00:00:00 2001 From: Amith Koujalgi Date: Thu, 14 Dec 2023 16:47:38 +0530 Subject: [PATCH 7/7] Fixes to tests --- .../ollama4j/TestMockedAPIs.java | 186 +++++++++--------- 1 file changed, 93 insertions(+), 93 deletions(-) diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/TestMockedAPIs.java b/src/test/java/io/github/amithkoujalgi/ollama4j/TestMockedAPIs.java index f9a2f09..6d43cbf 100644 --- a/src/test/java/io/github/amithkoujalgi/ollama4j/TestMockedAPIs.java +++ b/src/test/java/io/github/amithkoujalgi/ollama4j/TestMockedAPIs.java @@ -1,121 +1,121 @@ package io.github.amithkoujalgi.ollama4j; +import static org.mockito.Mockito.*; + import io.github.amithkoujalgi.ollama4j.core.OllamaAPI; import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; import io.github.amithkoujalgi.ollama4j.core.models.ModelDetail; import io.github.amithkoujalgi.ollama4j.core.models.OllamaAsyncResultCallback; import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult; import io.github.amithkoujalgi.ollama4j.core.types.OllamaModelType; -import org.junit.jupiter.api.Test; -import org.mockito.Mockito; - import java.io.IOException; import java.net.URISyntaxException; import java.util.ArrayList; - -import static org.mockito.Mockito.*; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; public class TestMockedAPIs { - @Test - public void testMockPullModel() { - OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); - String model = OllamaModelType.LLAMA2; - try { - doNothing().when(ollamaAPI).pullModel(model); - ollamaAPI.pullModel(model); - verify(ollamaAPI, times(1)).pullModel(model); - } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { - throw new RuntimeException(e); - } + @Test + public void testMockPullModel() { + OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); + String model = OllamaModelType.LLAMA2; + try { + doNothing().when(ollamaAPI).pullModel(model); + ollamaAPI.pullModel(model); + verify(ollamaAPI, times(1)).pullModel(model); + } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { + throw new RuntimeException(e); } + } - @Test - public void testListModels() { - OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); - try { - when(ollamaAPI.listModels()).thenReturn(new ArrayList<>()); - ollamaAPI.listModels(); - verify(ollamaAPI, times(1)).listModels(); - } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { - throw new RuntimeException(e); - } + @Test + public void testListModels() { + OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); + try { + when(ollamaAPI.listModels()).thenReturn(new ArrayList<>()); + ollamaAPI.listModels(); + verify(ollamaAPI, times(1)).listModels(); + } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { + throw new RuntimeException(e); } + } - @Test - public void testCreateModel() { - OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); - String model = OllamaModelType.LLAMA2; - String modelFilePath = "/somemodel"; - try { - doNothing().when(ollamaAPI).createModel(model, modelFilePath); - ollamaAPI.createModel(model, modelFilePath); - verify(ollamaAPI, times(1)).createModel(model, modelFilePath); - } catch (IOException | OllamaBaseException | InterruptedException e) { - throw new RuntimeException(e); - } + @Test + public void testCreateModel() { + OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); + String model = OllamaModelType.LLAMA2; + String modelFilePath = "/somemodel"; + try { + doNothing().when(ollamaAPI).createModel(model, modelFilePath); + ollamaAPI.createModel(model, modelFilePath); + verify(ollamaAPI, times(1)).createModel(model, modelFilePath); + } catch (IOException | OllamaBaseException | InterruptedException e) { + throw new RuntimeException(e); } + } - @Test - public void testDeleteModel() { - OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); - String model = OllamaModelType.LLAMA2; - try { - doNothing().when(ollamaAPI).deleteModel(model, true); - ollamaAPI.deleteModel(model, true); - verify(ollamaAPI, times(1)).deleteModel(model, true); - } catch (IOException | OllamaBaseException | InterruptedException e) { - throw new RuntimeException(e); - } + @Test + public void testDeleteModel() { + OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); + String model = OllamaModelType.LLAMA2; + try { + doNothing().when(ollamaAPI).deleteModel(model, true); + ollamaAPI.deleteModel(model, true); + verify(ollamaAPI, times(1)).deleteModel(model, true); + } catch (IOException | OllamaBaseException | InterruptedException e) { + throw new RuntimeException(e); } + } - @Test - public void testGetModelDetails() { - OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); - String model = OllamaModelType.LLAMA2; - try { - when(ollamaAPI.getModelDetails(model)).thenReturn(new ModelDetail()); - ollamaAPI.getModelDetails(model); - verify(ollamaAPI, times(1)).getModelDetails(model); - } catch (IOException | OllamaBaseException | InterruptedException e) { - throw new RuntimeException(e); - } + @Test + public void testGetModelDetails() { + OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); + String model = OllamaModelType.LLAMA2; + try { + when(ollamaAPI.getModelDetails(model)).thenReturn(new ModelDetail()); + ollamaAPI.getModelDetails(model); + verify(ollamaAPI, times(1)).getModelDetails(model); + } catch (IOException | OllamaBaseException | InterruptedException e) { + throw new RuntimeException(e); } + } - @Test - public void testGenerateEmbeddings() { - OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); - String model = OllamaModelType.LLAMA2; - String prompt = "some prompt text"; - try { - when(ollamaAPI.generateEmbeddings(model, prompt)).thenReturn(new ArrayList<>()); - ollamaAPI.generateEmbeddings(model, prompt); - verify(ollamaAPI, times(1)).generateEmbeddings(model, prompt); - } catch (IOException | OllamaBaseException | InterruptedException e) { - throw new RuntimeException(e); - } + @Test + public void testGenerateEmbeddings() { + OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); + String model = OllamaModelType.LLAMA2; + String prompt = "some prompt text"; + try { + when(ollamaAPI.generateEmbeddings(model, prompt)).thenReturn(new ArrayList<>()); + ollamaAPI.generateEmbeddings(model, prompt); + verify(ollamaAPI, times(1)).generateEmbeddings(model, prompt); + } catch (IOException | OllamaBaseException | InterruptedException e) { + throw new RuntimeException(e); } + } - @Test - public void testAsk() { - OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); - String model = OllamaModelType.LLAMA2; - String prompt = "some prompt text"; - try { - when(ollamaAPI.ask(model, prompt)).thenReturn(new OllamaResult("", 0)); - ollamaAPI.ask(model, prompt); - verify(ollamaAPI, times(1)).ask(model, prompt); - } catch (IOException | OllamaBaseException | InterruptedException e) { - throw new RuntimeException(e); - } + @Test + public void testAsk() { + OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); + String model = OllamaModelType.LLAMA2; + String prompt = "some prompt text"; + try { + when(ollamaAPI.ask(model, prompt)).thenReturn(new OllamaResult("", 0, 200)); + ollamaAPI.ask(model, prompt); + verify(ollamaAPI, times(1)).ask(model, prompt); + } catch (IOException | OllamaBaseException | InterruptedException e) { + throw new RuntimeException(e); } + } - @Test - public void testAskAsync() { - OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); - String model = OllamaModelType.LLAMA2; - String prompt = "some prompt text"; - when(ollamaAPI.askAsync(model, prompt)).thenReturn(new OllamaAsyncResultCallback(null, null, null)); - ollamaAPI.askAsync(model, prompt); - verify(ollamaAPI, times(1)).askAsync(model, prompt); - } + @Test + public void testAskAsync() { + OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); + String model = OllamaModelType.LLAMA2; + String prompt = "some prompt text"; + when(ollamaAPI.askAsync(model, prompt)) + .thenReturn(new OllamaAsyncResultCallback(null, null, null)); + ollamaAPI.askAsync(model, prompt); + verify(ollamaAPI, times(1)).askAsync(model, prompt); + } }