From 1fdd89f50f4d92c4adb7a60e4c0d7d5aa34d1f6f Mon Sep 17 00:00:00 2001 From: Amith Koujalgi Date: Tue, 19 Dec 2023 00:32:58 +0530 Subject: [PATCH] clean up --- README.md | 238 ++++++++--------- .../ollama4j/core/OllamaAPI.java | 247 +++++++++++------- .../ollama4j/core/models/ModelDetail.java | 6 + .../CustomModelFileContentsRequest.java | 23 ++ .../request/CustomModelFilePathRequest.java | 23 ++ .../request/ModelEmbeddingsRequest.java | 23 ++ .../core/models/request/ModelRequest.java | 22 ++ .../ollama4j/unittests/TestMockedAPIs.java | 8 +- 8 files changed, 376 insertions(+), 214 deletions(-) create mode 100644 src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/CustomModelFileContentsRequest.java create mode 100644 src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/CustomModelFilePathRequest.java create mode 100644 src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/ModelEmbeddingsRequest.java create mode 100644 src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/ModelRequest.java diff --git a/README.md b/README.md index a561b55..3ac83c6 100644 --- a/README.md +++ b/README.md @@ -44,13 +44,13 @@ for [Ollama](https://github.com/jmorganca/ollama/blob/main/docs/api.md) APIs. 
[![][ollama-shield]][ollama] Or [![][ollama-docker-shield]][ollama-docker] [ollama]: https://ollama.ai/ + [ollama-shield]: https://img.shields.io/badge/Ollama-Local_Installation-blue.svg?style=for-the-badge&labelColor=gray [ollama-docker]: https://hub.docker.com/r/ollama/ollama + [ollama-docker-shield]: https://img.shields.io/badge/Ollama-Docker-blue.svg?style=for-the-badge&labelColor=gray - - #### Installation In your Maven project, add this dependency available in @@ -59,9 +59,9 @@ the [Central Repository](https://s01.oss.sonatype.org/#nexus-search;quick~ollama ```xml - io.github.amithkoujalgi - ollama4j - 1.0-SNAPSHOT + io.github.amithkoujalgi + ollama4j + 1.0-SNAPSHOT ``` @@ -71,10 +71,10 @@ your `pom.xml`: ```xml - - ollama4j-from-ossrh - https://s01.oss.sonatype.org/content/repositories/snapshots - + + ollama4j-from-ossrh + https://s01.oss.sonatype.org/content/repositories/snapshots + ``` @@ -113,13 +113,13 @@ Instantiate `OllamaAPI` ```java public class Main { - public static void main(String[] args) { - String host = "http://localhost:11434/"; - OllamaAPI ollamaAPI = new OllamaAPI(host); + public static void main(String[] args) { + String host = "http://localhost:11434/"; + OllamaAPI ollamaAPI = new OllamaAPI(host); - // set verbose - true/false - ollamaAPI.setVerbose(true); - } + // set verbose - true/false + ollamaAPI.setVerbose(true); + } } ``` @@ -128,11 +128,11 @@ public class Main { ```java public class Main { - public static void main(String[] args) { - String host = "http://localhost:11434/"; - OllamaAPI ollamaAPI = new OllamaAPI(host); - ollamaAPI.pullModel(OllamaModelType.LLAMA2); - } + public static void main(String[] args) { + String host = "http://localhost:11434/"; + OllamaAPI ollamaAPI = new OllamaAPI(host); + ollamaAPI.pullModel(OllamaModelType.LLAMA2); + } } ``` @@ -143,12 +143,12 @@ _Find the list of available models from Ollama [here](https://ollama.ai/library) ```java public class Main { - public static void main(String[] args) { - 
String host = "http://localhost:11434/"; - OllamaAPI ollamaAPI = new OllamaAPI(host); - List models = ollamaAPI.listModels(); - models.forEach(model -> System.out.println(model.getName())); - } + public static void main(String[] args) { + String host = "http://localhost:11434/"; + OllamaAPI ollamaAPI = new OllamaAPI(host); + List models = ollamaAPI.listModels(); + models.forEach(model -> System.out.println(model.getName())); + } } ``` @@ -164,12 +164,12 @@ sqlcoder:latest ```java public class Main { - public static void main(String[] args) { - String host = "http://localhost:11434/"; - OllamaAPI ollamaAPI = new OllamaAPI(host); - ModelDetail modelDetails = ollamaAPI.getModelDetails(OllamaModelType.LLAMA2); - System.out.println(modelDetails); - } + public static void main(String[] args) { + String host = "http://localhost:11434/"; + OllamaAPI ollamaAPI = new OllamaAPI(host); + ModelDetail modelDetails = ollamaAPI.getModelDetails(OllamaModelType.LLAMA2); + System.out.println(modelDetails); + } } ``` @@ -189,11 +189,11 @@ Response: ```java public class Main { - public static void main(String[] args) { - String host = "http://localhost:11434/"; - OllamaAPI ollamaAPI = new OllamaAPI(host); - ollamaAPI.createModel("mycustommodel", "/path/to/modelfile/on/ollama-server"); - } + public static void main(String[] args) { + String host = "http://localhost:11434/"; + OllamaAPI ollamaAPI = new OllamaAPI(host); + ollamaAPI.createModelWithFilePath("mycustommodel", "/path/to/modelfile/on/ollama-server"); + } } ``` @@ -202,12 +202,12 @@ public class Main { ```java public class Main { - public static void main(String[] args) { - String host = "http://localhost:11434/"; - OllamaAPI ollamaAPI = new OllamaAPI(host); - ollamaAPI.setVerbose(false); - ollamaAPI.deleteModel("mycustommodel", 
true); + } } ``` @@ -216,13 +216,13 @@ public class Main { ```java public class Main { - public static void main(String[] args) { - String host = "http://localhost:11434/"; - OllamaAPI ollamaAPI = new OllamaAPI(host); - List embeddings = ollamaAPI.generateEmbeddings(OllamaModelType.LLAMA2, - "Here is an article about llamas..."); - embeddings.forEach(System.out::println); - } + public static void main(String[] args) { + String host = "http://localhost:11434/"; + OllamaAPI ollamaAPI = new OllamaAPI(host); + List embeddings = ollamaAPI.generateEmbeddings(OllamaModelType.LLAMA2, + "Here is an article about llamas..."); + embeddings.forEach(System.out::println); + } } ``` @@ -233,12 +233,12 @@ public class Main { ```java public class Main { - public static void main(String[] args) { - String host = "http://localhost:11434/"; - OllamaAPI ollamaAPI = new OllamaAPI(host); - String response = ollamaAPI.ask(OllamaModelType.LLAMA2, "Who are you?"); - System.out.println(response); - } + public static void main(String[] args) { + String host = "http://localhost:11434/"; + OllamaAPI ollamaAPI = new OllamaAPI(host); + String response = ollamaAPI.ask(OllamaModelType.LLAMA2, "Who are you?"); + System.out.println(response); + } } ``` @@ -247,20 +247,20 @@ public class Main { ```java public class Main { - public static void main(String[] args) { - String host = "http://localhost:11434/"; - OllamaAPI ollamaAPI = new OllamaAPI(host); - OllamaAsyncResultCallback ollamaAsyncResultCallback = ollamaAPI.askAsync(OllamaModelType.LLAMA2, - "Who are you?"); - while (true) { - if (ollamaAsyncResultCallback.isComplete()) { - System.out.println(ollamaAsyncResultCallback.getResponse()); - break; - } - // introduce sleep to check for status with a time interval - // Thread.sleep(1000); + public static void main(String[] args) { + String host = "http://localhost:11434/"; + OllamaAPI ollamaAPI = new OllamaAPI(host); + OllamaAsyncResultCallback ollamaAsyncResultCallback = 
ollamaAPI.askAsync(OllamaModelType.LLAMA2, + "Who are you?"); + while (true) { + if (ollamaAsyncResultCallback.isComplete()) { + System.out.println(ollamaAsyncResultCallback.getResponse()); + break; + } + // introduce sleep to check for status with a time interval + // Thread.sleep(1000); + } } - } } ``` @@ -280,14 +280,14 @@ You'd then get a response from the model: ```java public class Main { - public static void main(String[] args) { - String host = "http://localhost:11434/"; - OllamaAPI ollamaAPI = new OllamaAPI(host); + public static void main(String[] args) { + String host = "http://localhost:11434/"; + OllamaAPI ollamaAPI = new OllamaAPI(host); - String prompt = "List all cricket world cup teams of 2019."; - String response = ollamaAPI.ask(OllamaModelType.LLAMA2, prompt); - System.out.println(response); - } + String prompt = "List all cricket world cup teams of 2019."; + String response = ollamaAPI.ask(OllamaModelType.LLAMA2, prompt); + System.out.println(response); + } } ``` @@ -316,15 +316,15 @@ You'd then get a response from the model: ```java public class Main { - public static void main(String[] args) { - String host = "http://localhost:11434/"; - OllamaAPI ollamaAPI = new OllamaAPI(host); + public static void main(String[] args) { + String host = "http://localhost:11434/"; + OllamaAPI ollamaAPI = new OllamaAPI(host); - String prompt = SamplePrompts.getSampleDatabasePromptWithQuestion( - "List all customer names who have bought one or more products"); - String response = ollamaAPI.ask(OllamaModelType.SQLCODER, prompt); - System.out.println(response); - } + String prompt = SamplePrompts.getSampleDatabasePromptWithQuestion( + "List all customer names who have bought one or more products"); + String response = ollamaAPI.ask(OllamaModelType.SQLCODER, prompt); + System.out.println(response); + } } ``` @@ -351,17 +351,17 @@ With Files: ```java public class Main { - public static void main(String[] args) { - String host = "http://localhost:11434/"; - OllamaAPI 
ollamaAPI = new OllamaAPI(host); - ollamaAPI.setRequestTimeoutSeconds(10); + public static void main(String[] args) { + String host = "http://localhost:11434/"; + OllamaAPI ollamaAPI = new OllamaAPI(host); + ollamaAPI.setRequestTimeoutSeconds(10); - OllamaResult response = ollamaAPI.askWithImageFiles(OllamaModelType.LLAVA, - "What's in this image?", - List.of( - new File("/path/to/image"))); - System.out.println(response); - } + OllamaResult response = ollamaAPI.askWithImageFiles(OllamaModelType.LLAVA, + "What's in this image?", + List.of( + new File("/path/to/image"))); + System.out.println(response); + } } ``` @@ -370,17 +370,17 @@ With URLs: ```java public class Main { - public static void main(String[] args) { - String host = "http://localhost:11434/"; - OllamaAPI ollamaAPI = new OllamaAPI(host); - ollamaAPI.setRequestTimeoutSeconds(10); + public static void main(String[] args) { + String host = "http://localhost:11434/"; + OllamaAPI ollamaAPI = new OllamaAPI(host); + ollamaAPI.setRequestTimeoutSeconds(10); - OllamaResult response = ollamaAPI.askWithImageURLs(OllamaModelType.LLAVA, - "What's in this image?", - List.of( - "https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg")); - System.out.println(response); - } + OllamaResult response = ollamaAPI.askWithImageURLs(OllamaModelType.LLAVA, + "What's in this image?", + List.of( + "https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg")); + System.out.println(response); + } } ``` @@ -398,21 +398,21 @@ The dog seems to be enjoying its time outdoors, perhaps on a lake. 
@SuppressWarnings("ALL") public class Main { - public static void main(String[] args) { - String host = "http://localhost:11434/"; - OllamaAPI ollamaAPI = new OllamaAPI(host); + public static void main(String[] args) { + String host = "http://localhost:11434/"; + OllamaAPI ollamaAPI = new OllamaAPI(host); - String prompt = "List all cricket world cup teams of 2019."; - OllamaAsyncResultCallback callback = ollamaAPI.askAsync(OllamaModelType.LLAMA2, prompt); - while (!callback.isComplete() || !callback.getStream().isEmpty()) { - // poll for data from the response stream - String response = callback.getStream().poll(); - if (response != null) { - System.out.print(response); - } - Thread.sleep(1000); + String prompt = "List all cricket world cup teams of 2019."; + OllamaAsyncResultCallback callback = ollamaAPI.askAsync(OllamaModelType.LLAMA2, prompt); + while (!callback.isComplete() || !callback.getStream().isEmpty()) { + // poll for data from the response stream + String response = callback.getStream().poll(); + if (response != null) { + System.out.print(response); + } + Thread.sleep(1000); + } } - } } ``` @@ -452,8 +452,8 @@ make it - [x] Fix deprecated HTTP client code - [x] Setup logging - [x] Use lombok -- [ ] Update request body creation with Java objects -- [ ] Async APIs for images +- [x] Update request body creation with Java objects +- [ ] Async APIs for images - [ ] Add additional params for `ask` APIs such as: - `options`: additional model parameters for the Modelfile such as `temperature` - `system`: system prompt to (overrides what is defined in the Modelfile) diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java index 066460f..a64e7ef 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java @@ -2,6 +2,10 @@ package io.github.amithkoujalgi.ollama4j.core; import 
io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; import io.github.amithkoujalgi.ollama4j.core.models.*; +import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFileContentsRequest; +import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFilePathRequest; +import io.github.amithkoujalgi.ollama4j.core.models.request.ModelEmbeddingsRequest; +import io.github.amithkoujalgi.ollama4j.core.models.request.ModelRequest; import io.github.amithkoujalgi.ollama4j.core.utils.Utils; import java.io.BufferedReader; import java.io.ByteArrayOutputStream; @@ -17,7 +21,6 @@ import java.net.http.HttpRequest; import java.net.http.HttpResponse; import java.nio.charset.StandardCharsets; import java.nio.file.Files; -import java.nio.file.Path; import java.time.Duration; import java.util.ArrayList; import java.util.Base64; @@ -25,9 +28,7 @@ import java.util.List; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** - * The base Ollama API class. - */ +/** The base Ollama API class. 
*/ @SuppressWarnings("DuplicatedCode") public class OllamaAPI { @@ -71,15 +72,21 @@ public class OllamaAPI { throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { String url = this.host + "/api/tags"; HttpClient httpClient = HttpClient.newHttpClient(); - HttpRequest httpRequest = HttpRequest.newBuilder().uri(new URI(url)) - .header("Accept", "application/json").header("Content-type", "application/json") - .timeout(Duration.ofSeconds(requestTimeoutSeconds)).GET().build(); - HttpResponse response = httpClient.send(httpRequest, - HttpResponse.BodyHandlers.ofString()); + HttpRequest httpRequest = + HttpRequest.newBuilder() + .uri(new URI(url)) + .header("Accept", "application/json") + .header("Content-type", "application/json") + .timeout(Duration.ofSeconds(requestTimeoutSeconds)) + .GET() + .build(); + HttpResponse response = + httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); String responseString = response.body(); if (statusCode == 200) { - return Utils.getObjectMapper().readValue(responseString, ListModelsResponse.class) + return Utils.getObjectMapper() + .readValue(responseString, ListModelsResponse.class) .getModels(); } else { throw new OllamaBaseException(statusCode + " - " + responseString); @@ -90,28 +97,32 @@ public class OllamaAPI { * Pull a model on the Ollama server from the list of available models. 
* - * @param model the name of the model + * @param modelName the name of the model */ - public void pullModel(String model) + public void pullModel(String modelName) throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { String url = this.host + "/api/pull"; - String jsonData = String.format("{\"name\": \"%s\"}", model); - HttpRequest request = HttpRequest.newBuilder().uri(new URI(url)) - .POST(HttpRequest.BodyPublishers.ofString(jsonData)).header("Accept", "application/json") - .header("Content-type", "application/json") - .timeout(Duration.ofSeconds(requestTimeoutSeconds)).build(); + String jsonData = new ModelRequest(modelName).toString(); + HttpRequest request = + HttpRequest.newBuilder() + .uri(new URI(url)) + .POST(HttpRequest.BodyPublishers.ofString(jsonData)) + .header("Accept", "application/json") + .header("Content-type", "application/json") + .timeout(Duration.ofSeconds(requestTimeoutSeconds)) + .build(); HttpClient client = HttpClient.newHttpClient(); - HttpResponse response = client.send(request, - HttpResponse.BodyHandlers.ofInputStream()); + HttpResponse response = + client.send(request, HttpResponse.BodyHandlers.ofInputStream()); int statusCode = response.statusCode(); InputStream responseBodyStream = response.body(); String responseString = ""; - try (BufferedReader reader = new BufferedReader( - new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { + try (BufferedReader reader = + new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { String line; while ((line = reader.readLine()) != null) { - ModelPullResponse modelPullResponse = Utils.getObjectMapper() - .readValue(line, ModelPullResponse.class); + ModelPullResponse modelPullResponse = + Utils.getObjectMapper().readValue(line, ModelPullResponse.class); if (verbose) { logger.info(modelPullResponse.getStatus()); } @@ -131,11 +142,15 @@ public class OllamaAPI { public ModelDetail getModelDetails(String modelName) throws 
IOException, OllamaBaseException, InterruptedException { String url = this.host + "/api/show"; - String jsonData = String.format("{\"name\": \"%s\"}", modelName); - HttpRequest request = HttpRequest.newBuilder().uri(URI.create(url)) - .header("Accept", "application/json").header("Content-type", "application/json") - .timeout(Duration.ofSeconds(requestTimeoutSeconds)) - .POST(HttpRequest.BodyPublishers.ofString(jsonData)).build(); + String jsonData = new ModelRequest(modelName).toString(); + HttpRequest request = + HttpRequest.newBuilder() + .uri(URI.create(url)) + .header("Accept", "application/json") + .header("Content-type", "application/json") + .timeout(Duration.ofSeconds(requestTimeoutSeconds)) + .POST(HttpRequest.BodyPublishers.ofString(jsonData)) + .build(); HttpClient client = HttpClient.newHttpClient(); HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); @@ -151,18 +166,21 @@ public class OllamaAPI { * Create a custom model from a model file. Read more about custom model file creation here. * - * @param modelName the name of the custom model to be created. + * @param modelName the name of the custom model to be created. * @param modelFilePath the path to model file that exists on the Ollama server. 
*/ - public void createModel(String modelName, String modelFilePath) + public void createModelWithFilePath(String modelName, String modelFilePath) throws IOException, InterruptedException, OllamaBaseException { String url = this.host + "/api/create"; - String jsonData = String.format("{\"name\": \"%s\", \"path\": \"%s\"}", modelName, - modelFilePath); - HttpRequest request = HttpRequest.newBuilder().uri(URI.create(url)) - .header("Accept", "application/json").header("Content-Type", "application/json") - .timeout(Duration.ofSeconds(requestTimeoutSeconds)) - .POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build(); + String jsonData = new CustomModelFilePathRequest(modelName, modelFilePath).toString(); + HttpRequest request = + HttpRequest.newBuilder() + .uri(URI.create(url)) + .header("Accept", "application/json") + .header("Content-Type", "application/json") + .timeout(Duration.ofSeconds(requestTimeoutSeconds)) + .POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)) + .build(); HttpClient client = HttpClient.newHttpClient(); HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); @@ -180,21 +198,59 @@ public class OllamaAPI { } } + /** + * Create a custom model from the contents of a model file. Read more about custom model file creation here. + * + * @param modelName the name of the custom model to be created. + * @param modelFileContents the contents of the model file to create the custom model from. 
+ */ + public void createModelWithModelFileContents(String modelName, String modelFileContents) + throws IOException, InterruptedException, OllamaBaseException { + String url = this.host + "/api/create"; + String jsonData = new CustomModelFileContentsRequest(modelName, modelFileContents).toString(); + HttpRequest request = + HttpRequest.newBuilder() + .uri(URI.create(url)) + .header("Accept", "application/json") + .header("Content-Type", "application/json") + .timeout(Duration.ofSeconds(requestTimeoutSeconds)) + .POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)) + .build(); + HttpClient client = HttpClient.newHttpClient(); + HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); + int statusCode = response.statusCode(); + String responseString = response.body(); + if (statusCode != 200) { + throw new OllamaBaseException(statusCode + " - " + responseString); + } + if (responseString.contains("error")) { + throw new OllamaBaseException(responseString); + } + if (verbose) { + logger.info(responseString); + } + } + /** * Delete a model from Ollama server. * - * @param name the name of the model to be deleted. + * @param modelName the name of the model to be deleted. * @param ignoreIfNotPresent - ignore errors if the specified model is not present on Ollama - * server. + * server. 
*/ - public void deleteModel(String name, boolean ignoreIfNotPresent) + public void deleteModel(String modelName, boolean ignoreIfNotPresent) throws IOException, InterruptedException, OllamaBaseException { String url = this.host + "/api/delete"; - String jsonData = String.format("{\"name\": \"%s\"}", name); - HttpRequest request = HttpRequest.newBuilder().uri(URI.create(url)) - .method("DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)) - .header("Accept", "application/json").header("Content-type", "application/json") - .timeout(Duration.ofSeconds(requestTimeoutSeconds)).build(); + String jsonData = new ModelRequest(modelName).toString(); + HttpRequest request = + HttpRequest.newBuilder() + .uri(URI.create(url)) + .method("DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)) + .header("Accept", "application/json") + .header("Content-type", "application/json") + .timeout(Duration.ofSeconds(requestTimeoutSeconds)) + .build(); HttpClient client = HttpClient.newHttpClient(); HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); @@ -210,25 +266,29 @@ public class OllamaAPI { /** * Generate embeddings for a given text from a model * - * @param model name of model to generate embeddings from + * @param model name of model to generate embeddings from * @param prompt text to generate embeddings for * @return embeddings */ public List generateEmbeddings(String model, String prompt) throws IOException, InterruptedException, OllamaBaseException { String url = this.host + "/api/embeddings"; - String jsonData = String.format("{\"model\": \"%s\", \"prompt\": \"%s\"}", model, prompt); + String jsonData = new ModelEmbeddingsRequest(model, prompt).toString(); HttpClient httpClient = HttpClient.newHttpClient(); - HttpRequest request = HttpRequest.newBuilder().uri(URI.create(url)) - .header("Accept", "application/json").header("Content-type", 
"application/json") - .timeout(Duration.ofSeconds(requestTimeoutSeconds)) - .POST(HttpRequest.BodyPublishers.ofString(jsonData)).build(); + HttpRequest request = + HttpRequest.newBuilder() + .uri(URI.create(url)) + .header("Accept", "application/json") + .header("Content-type", "application/json") + .timeout(Duration.ofSeconds(requestTimeoutSeconds)) + .POST(HttpRequest.BodyPublishers.ofString(jsonData)) + .build(); HttpResponse response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); String responseBody = response.body(); if (statusCode == 200) { - EmbeddingResponse embeddingResponse = Utils.getObjectMapper() - .readValue(responseBody, EmbeddingResponse.class); + EmbeddingResponse embeddingResponse = + Utils.getObjectMapper().readValue(responseBody, EmbeddingResponse.class); return embeddingResponse.getEmbedding(); } else { throw new OllamaBaseException(statusCode + " - " + responseBody); @@ -238,7 +298,7 @@ public class OllamaAPI { /** * Ask a question to a model running on Ollama server. This is a sync/blocking call. * - * @param model the ollama model to ask the question to + * @param model the ollama model to ask the question to * @param promptText the prompt/question text * @return OllamaResult - that includes response text and time taken for response */ @@ -248,11 +308,30 @@ public class OllamaAPI { return askSync(ollamaRequestModel); } + /** + * Ask a question to a model running on Ollama server and get a callback handle that can be used + * to check for status and get the response from the model later. This would be an + * async/non-blocking call. 
+ * + * @param model the ollama model to ask the question to + * @param promptText the prompt/question text + * @return the ollama async result callback handle + */ + public OllamaAsyncResultCallback askAsync(String model, String promptText) { + OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, promptText); + HttpClient httpClient = HttpClient.newHttpClient(); + URI uri = URI.create(this.host + "/api/generate"); + OllamaAsyncResultCallback ollamaAsyncResultCallback = + new OllamaAsyncResultCallback(httpClient, uri, ollamaRequestModel, requestTimeoutSeconds); + ollamaAsyncResultCallback.start(); + return ollamaAsyncResultCallback; + } + /** * With one or more image files, ask a question to a model running on Ollama server. This is a * sync/blocking call. * - * @param model the ollama model to ask the question to + * @param model the ollama model to ask the question to * @param promptText the prompt/question text * @param imageFiles the list of image files to use for the question * @return OllamaResult - that includes response text and time taken for response @@ -271,9 +350,9 @@ public class OllamaAPI { * With one or more image URLs, ask a question to a model running on Ollama server. This is a * sync/blocking call. 
* - * @param model the ollama model to ask the question to + * @param model the ollama model to ask the question to * @param promptText the prompt/question text - * @param imageURLs the list of image URLs to use for the question + * @param imageURLs the list of image URLs to use for the question * @return OllamaResult - that includes response text and time taken for response */ public OllamaResult askWithImageURLs(String model, String promptText, List imageURLs) @@ -286,18 +365,19 @@ public class OllamaAPI { return askSync(ollamaRequestModel); } - public static String encodeFileToBase64(File file) throws IOException { + private static String encodeFileToBase64(File file) throws IOException { return Base64.getEncoder().encodeToString(Files.readAllBytes(file.toPath())); } - public static String encodeByteArrayToBase64(byte[] bytes) { + private static String encodeByteArrayToBase64(byte[] bytes) { return Base64.getEncoder().encodeToString(bytes); } - public static byte[] loadImageBytesFromUrl(String imageUrl) + private static byte[] loadImageBytesFromUrl(String imageUrl) throws IOException, URISyntaxException { URL url = new URI(imageUrl).toURL(); - try (InputStream in = url.openStream(); ByteArrayOutputStream out = new ByteArrayOutputStream()) { + try (InputStream in = url.openStream(); + ByteArrayOutputStream out = new ByteArrayOutputStream()) { byte[] buffer = new byte[1024]; int bytesRead; while ((bytesRead = in.read(buffer)) != -1) { @@ -307,50 +387,35 @@ public class OllamaAPI { } } - /** - * Ask a question to a model running on Ollama server and get a callback handle that can be used - * to check for status and get the response from the model later. This would be an - * async/non-blocking call. 
- * - * @param model the ollama model to ask the question to - * @param promptText the prompt/question text - * @return the ollama async result callback handle - */ - public OllamaAsyncResultCallback askAsync(String model, String promptText) { - OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, promptText); - HttpClient httpClient = HttpClient.newHttpClient(); - URI uri = URI.create(this.host + "/api/generate"); - OllamaAsyncResultCallback ollamaAsyncResultCallback = new OllamaAsyncResultCallback(httpClient, - uri, ollamaRequestModel, requestTimeoutSeconds); - ollamaAsyncResultCallback.start(); - return ollamaAsyncResultCallback; - } - private OllamaResult askSync(OllamaRequestModel ollamaRequestModel) throws OllamaBaseException, IOException, InterruptedException { long startTime = System.currentTimeMillis(); HttpClient httpClient = HttpClient.newHttpClient(); URI uri = URI.create(this.host + "/api/generate"); - HttpRequest request = HttpRequest.newBuilder(uri).POST(HttpRequest.BodyPublishers.ofString( - Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))) - .header("Content-Type", "application/json") - .timeout(Duration.ofSeconds(requestTimeoutSeconds)).build(); - HttpResponse response = httpClient.send(request, - HttpResponse.BodyHandlers.ofInputStream()); + HttpRequest request = + HttpRequest.newBuilder(uri) + .POST( + HttpRequest.BodyPublishers.ofString( + Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))) + .header("Content-Type", "application/json") + .timeout(Duration.ofSeconds(requestTimeoutSeconds)) + .build(); + HttpResponse response = + httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream()); int statusCode = response.statusCode(); InputStream responseBodyStream = response.body(); StringBuilder responseBuffer = new StringBuilder(); - try (BufferedReader reader = new BufferedReader( - new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { + try (BufferedReader reader = + new 
BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { String line; while ((line = reader.readLine()) != null) { if (statusCode == 404) { - OllamaErrorResponseModel ollamaResponseModel = Utils.getObjectMapper() - .readValue(line, OllamaErrorResponseModel.class); + OllamaErrorResponseModel ollamaResponseModel = + Utils.getObjectMapper().readValue(line, OllamaErrorResponseModel.class); responseBuffer.append(ollamaResponseModel.getError()); } else { - OllamaResponseModel ollamaResponseModel = Utils.getObjectMapper() - .readValue(line, OllamaResponseModel.class); + OllamaResponseModel ollamaResponseModel = + Utils.getObjectMapper().readValue(line, OllamaResponseModel.class); if (!ollamaResponseModel.isDone()) { responseBuffer.append(ollamaResponseModel.getResponse()); } diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/ModelDetail.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/ModelDetail.java index 20cf85e..fa557c6 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/ModelDetail.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/ModelDetail.java @@ -1,14 +1,20 @@ package io.github.amithkoujalgi.ollama4j.core.models; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.Map; import lombok.Data; @Data +@JsonIgnoreProperties(ignoreUnknown = true) public class ModelDetail { private String license; + @JsonProperty("modelfile") private String modelFile; + private String parameters; private String template; private String system; + private Map details; } diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/CustomModelFileContentsRequest.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/CustomModelFileContentsRequest.java new file mode 100644 index 0000000..9e606d3 --- /dev/null +++ 
b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/CustomModelFileContentsRequest.java @@ -0,0 +1,23 @@ +package io.github.amithkoujalgi.ollama4j.core.models.request; + +import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper; + +import com.fasterxml.jackson.core.JsonProcessingException; +import lombok.AllArgsConstructor; +import lombok.Data; + +@Data +@AllArgsConstructor +public class CustomModelFileContentsRequest { + private String name; + private String modelfile; + + @Override + public String toString() { + try { + return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } +} diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/CustomModelFilePathRequest.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/CustomModelFilePathRequest.java new file mode 100644 index 0000000..ea08dbf --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/CustomModelFilePathRequest.java @@ -0,0 +1,23 @@ +package io.github.amithkoujalgi.ollama4j.core.models.request; + +import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper; + +import com.fasterxml.jackson.core.JsonProcessingException; +import lombok.AllArgsConstructor; +import lombok.Data; + +@Data +@AllArgsConstructor +public class CustomModelFilePathRequest { + private String name; + private String path; + + @Override + public String toString() { + try { + return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } +} diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/ModelEmbeddingsRequest.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/ModelEmbeddingsRequest.java new file mode 100644 index 0000000..1455a94 --- /dev/null 
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/ModelEmbeddingsRequest.java @@ -0,0 +1,23 @@ +package io.github.amithkoujalgi.ollama4j.core.models.request; + +import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper; + +import com.fasterxml.jackson.core.JsonProcessingException; +import lombok.AllArgsConstructor; +import lombok.Data; + +@Data +@AllArgsConstructor +public class ModelEmbeddingsRequest { + private String model; + private String prompt; + + @Override + public String toString() { + try { + return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } +} diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/ModelRequest.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/ModelRequest.java new file mode 100644 index 0000000..d3fdec4 --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/ModelRequest.java @@ -0,0 +1,22 @@ +package io.github.amithkoujalgi.ollama4j.core.models.request; + +import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper; + +import com.fasterxml.jackson.core.JsonProcessingException; +import lombok.AllArgsConstructor; +import lombok.Data; + +@Data +@AllArgsConstructor +public class ModelRequest { + private String name; + + @Override + public String toString() { + try { + return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } +} diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/TestMockedAPIs.java b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/TestMockedAPIs.java index 3b5fafc..7c46e37 100644 --- a/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/TestMockedAPIs.java +++ 
b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/TestMockedAPIs.java @@ -44,11 +44,11 @@ class TestMockedAPIs { void testCreateModel() { OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); String model = OllamaModelType.LLAMA2; - String modelFilePath = "/somemodel"; + String modelFilePath = "FROM llama2\nSYSTEM You are mario from Super Mario Bros."; try { - doNothing().when(ollamaAPI).createModel(model, modelFilePath); - ollamaAPI.createModel(model, modelFilePath); - verify(ollamaAPI, times(1)).createModel(model, modelFilePath); + doNothing().when(ollamaAPI).createModelWithModelFileContents(model, modelFilePath); + ollamaAPI.createModelWithModelFileContents(model, modelFilePath); + verify(ollamaAPI, times(1)).createModelWithModelFileContents(model, modelFilePath); } catch (IOException | OllamaBaseException | InterruptedException e) { throw new RuntimeException(e); }