diff --git a/README.md b/README.md index fc11d0c..b20a11e 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,8 @@ ollama4j-icon -A Java library (wrapper/binding) for [Ollama](https://github.com/jmorganca/ollama/blob/main/docs/api.md) APIs. +A Java library (wrapper/binding) +for [Ollama](https://github.com/jmorganca/ollama/blob/main/docs/api.md) APIs. ```mermaid flowchart LR @@ -38,7 +39,8 @@ A Java library (wrapper/binding) for [Ollama](https://github.com/jmorganca/ollam #### Requirements -- Ollama (Either [natively](https://ollama.ai/download) setup or via [Docker](https://hub.docker.com/r/ollama/ollama)) +- Ollama (Either [natively](https://ollama.ai/download) setup or + via [Docker](https://hub.docker.com/r/ollama/ollama)) - Java 11 or above #### Installation @@ -49,21 +51,22 @@ the [Central Repository](https://s01.oss.sonatype.org/#nexus-search;quick~ollama ```xml - io.github.amithkoujalgi - ollama4j - 1.0-SNAPSHOT + io.github.amithkoujalgi + ollama4j + 1.0-SNAPSHOT ``` -You might want to include the Maven repository to pull the ollama4j library from. Include this in your `pom.xml`: +You might want to include the Maven repository to pull the ollama4j library from. Include this in +your `pom.xml`: ```xml - - ollama4j-from-ossrh - https://s01.oss.sonatype.org/content/repositories/snapshots - + + ollama4j-from-ossrh + https://s01.oss.sonatype.org/content/repositories/snapshots + ``` @@ -104,13 +107,14 @@ Instantiate `OllamaAPI` ```java public class Main { - public static void main(String[] args) { - String host = "http://localhost:11434/"; - OllamaAPI ollamaAPI = new OllamaAPI(host); - // set verbose - true/false - ollamaAPI.setVerbose(true); - } + public static void main(String[] args) { + String host = "http://localhost:11434/"; + OllamaAPI ollamaAPI = new OllamaAPI(host); + + // set verbose - true/false + ollamaAPI.setVerbose(true); + } } ``` @@ -118,11 +122,12 @@ public class Main { ```java public class Main { - public static void main(String[] args) { - String host = "http://localhost:11434/"; - OllamaAPI ollamaAPI = new OllamaAPI(host); - ollamaAPI.pullModel(OllamaModelType.LLAMA2); - } + + public static void main(String[] args) { + String host = "http://localhost:11434/"; + OllamaAPI ollamaAPI = new OllamaAPI(host); + ollamaAPI.pullModel(OllamaModelType.LLAMA2); + } } ``` @@ -132,12 +137,13 @@ _Find the list of available models from Ollama [here](https://ollama.ai/library) ```java public class Main { - public static void main(String[] args) { - String host = "http://localhost:11434/"; - OllamaAPI ollamaAPI = new OllamaAPI(host); - List models = ollamaAPI.listModels(); - models.forEach(model -> System.out.println(model.getName())); - } + + public static void main(String[] args) { + String host = "http://localhost:11434/"; + OllamaAPI ollamaAPI = new OllamaAPI(host); + List models = ollamaAPI.listModels(); + models.forEach(model -> System.out.println(model.getName())); + } } ``` @@ -152,12 +158,13 @@ sqlcoder:latest ```java public class Main { - public static void main(String[] args) { - String host = "http://localhost:11434/"; - OllamaAPI ollamaAPI = new OllamaAPI(host); - ModelDetail modelDetails = ollamaAPI.getModelDetails(OllamaModelType.LLAMA2); - System.out.println(modelDetails); - } + + public static void main(String[] args) { + String host = "http://localhost:11434/"; + OllamaAPI ollamaAPI = new OllamaAPI(host); + ModelDetail modelDetails = ollamaAPI.getModelDetails(OllamaModelType.LLAMA2); + System.out.println(modelDetails); + } } ``` @@ -176,11 +183,12 @@ Response: ```java public class Main { - public static void main(String[] args) { - String host = "http://localhost:11434/"; - OllamaAPI ollamaAPI = new OllamaAPI(host); - ollamaAPI.createModel("mycustommodel", "/path/to/modelfile/on/ollama-server"); - } + + public static void main(String[] args) { + String host = "http://localhost:11434/"; + OllamaAPI ollamaAPI = new OllamaAPI(host); + ollamaAPI.createModel("mycustommodel", "/path/to/modelfile/on/ollama-server"); + } } ``` @@ -188,12 +196,13 @@ public class Main { ```java public class Main { - public static void main(String[] args) { - String host = "http://localhost:11434/"; - OllamaAPI ollamaAPI = new OllamaAPI(host); - ollamaAPI.setVerbose(false); - ollamaAPI.deleteModel("mycustommodel", true); - } + + public static void main(String[] args) { + String host = "http://localhost:11434/"; + OllamaAPI ollamaAPI = new OllamaAPI(host); + ollamaAPI.setVerbose(false); + ollamaAPI.deleteModel("mycustommodel", true); + } } ``` @@ -201,12 +210,14 @@ public class Main { ```java public class Main { - public static void main(String[] args) { - String host = "http://localhost:11434/"; - OllamaAPI ollamaAPI = new OllamaAPI(host); - List embeddings = ollamaAPI.generateEmbeddings(OllamaModelType.LLAMA2, "Here is an article about llamas..."); - embeddings.forEach(System.out::println); - } + + public static void main(String[] args) { + String host = "http://localhost:11434/"; + OllamaAPI ollamaAPI = new OllamaAPI(host); + List embeddings = ollamaAPI.generateEmbeddings(OllamaModelType.LLAMA2, + "Here is an article about llamas..."); + embeddings.forEach(System.out::println); + } } ``` @@ -216,12 +227,13 @@ public class Main { ```java public class Main { - public static void main(String[] args) { - String host = "http://localhost:11434/"; - OllamaAPI ollamaAPI = new OllamaAPI(host); - String response = ollamaAPI.ask(OllamaModelType.LLAMA2, "Who are you?"); - System.out.println(response); - } + + public static void main(String[] args) { + String host = "http://localhost:11434/"; + OllamaAPI ollamaAPI = new OllamaAPI(host); + String response = ollamaAPI.ask(OllamaModelType.LLAMA2, "Who are you?"); + System.out.println(response); + } } ``` @@ -229,27 +241,32 @@ public class Main { ```java public class Main { - public static void main(String[] args) { - String host = "http://localhost:11434/"; - OllamaAPI ollamaAPI = new OllamaAPI(host); - OllamaAsyncResultCallback ollamaAsyncResultCallback = ollamaAPI.askAsync(OllamaModelType.LLAMA2, "Who are you?"); - while (true) { - if (ollamaAsyncResultCallback.isComplete()) { - System.out.println(ollamaAsyncResultCallback.getResponse()); - break; - } - // introduce sleep to check for status with a time interval - // Thread.sleep(1000); - } + + public static void main(String[] args) { + String host = "http://localhost:11434/"; + OllamaAPI ollamaAPI = new OllamaAPI(host); + OllamaAsyncResultCallback ollamaAsyncResultCallback = ollamaAPI.askAsync(OllamaModelType.LLAMA2, + "Who are you?"); + while (true) { + if (ollamaAsyncResultCallback.isComplete()) { + System.out.println(ollamaAsyncResultCallback.getResponse()); + break; + } + // introduce sleep to check for status with a time interval + // Thread.sleep(1000); } + } } ``` You'd then get a response from the model: -> I am LLaMA, an AI assistant developed by Meta AI that can understand and respond to human input in a conversational -> manner. I am trained on a massive dataset of text from the internet and can generate human-like responses to a wide -> range of topics and questions. I can be used to create chatbots, virtual assistants, and other applications that +> I am LLaMA, an AI assistant developed by Meta AI that can understand and respond to human input in +> a conversational +> manner. I am trained on a massive dataset of text from the internet and can generate human-like +> responses to a wide +> range of topics and questions. I can be used to create chatbots, virtual assistants, and other +> applications that > require > natural language understanding and generation capabilities. @@ -257,20 +274,22 @@ You'd then get a response from the model: ```java public class Main { - public static void main(String[] args) { - String host = "http://localhost:11434/"; - OllamaAPI ollamaAPI = new OllamaAPI(host); - String prompt = "List all cricket world cup teams of 2019."; - String response = ollamaAPI.ask(OllamaModelType.LLAMA2, prompt); - System.out.println(response); - } + public static void main(String[] args) { + String host = "http://localhost:11434/"; + OllamaAPI ollamaAPI = new OllamaAPI(host); + + String prompt = "List all cricket world cup teams of 2019."; + String response = ollamaAPI.ask(OllamaModelType.LLAMA2, prompt); + System.out.println(response); + } } ``` You'd then get a response from the model: -> The 2019 ICC Cricket World Cup was held in England and Wales from May 30 to July 14, 2019. The following teams +> The 2019 ICC Cricket World Cup was held in England and Wales from May 30 to July 14, 2019. The +> following teams > participated in the tournament: > > 1. Australia @@ -283,21 +302,24 @@ You'd then get a response from the model: > 8. Sri Lanka > 9. West Indies > -> These teams competed in a round-robin format, with the top four teams advancing to the semi-finals. The tournament was +> These teams competed in a round-robin format, with the top four teams advancing to the +> semi-finals. The tournament was > won by the England cricket team, who defeated New Zealand in the final. #### Try asking for a Database query for your data schema: ```java public class Main { - public static void main(String[] args) { - String host = "http://localhost:11434/"; - OllamaAPI ollamaAPI = new OllamaAPI(host); - String prompt = SamplePrompts.getSampleDatabasePromptWithQuestion("List all customer names who have bought one or more products"); - String response = ollamaAPI.ask(OllamaModelType.SQLCODER, prompt); - System.out.println(response); - } + public static void main(String[] args) { + String host = "http://localhost:11434/"; + OllamaAPI ollamaAPI = new OllamaAPI(host); + + String prompt = SamplePrompts.getSampleDatabasePromptWithQuestion( + "List all customer names who have bought one or more products"); + String response = ollamaAPI.ask(OllamaModelType.SQLCODER, prompt); + System.out.println(response); + } } ``` @@ -315,27 +337,75 @@ FROM sales GROUP BY customers.name; ``` +#### Try asking some questions by passing images 🖼️: + +With Files: + +```java +public class Main { + + public static void main(String[] args) { + String host = "http://localhost:11434/"; + OllamaAPI ollamaAPI = new OllamaAPI(host); + ollamaAPI.setRequestTimeoutSeconds(10); + + OllamaResult response = ollamaAPI.askWithImageURLs(OllamaModelType.LLAVA, + "What's in this image?", + List.of( + "/path/to/image")); + System.out.println(response); + } +} +``` + +With URLs: + +```java +public class Main { + + public static void main(String[] args) { + String host = "http://localhost:11434/"; + OllamaAPI ollamaAPI = new OllamaAPI(host); + ollamaAPI.setRequestTimeoutSeconds(10); + + OllamaResult response = ollamaAPI.askWithImageURLs(OllamaModelType.LLAVA, + "What's in this image?", + List.of( + "https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg")); + System.out.println(response); + } +} +``` + +You'd then get a response from the model: + +``` +This image features a white boat with brown cushions, where a dog is sitting on the back of the boat. +The dog seems to be enjoying its time outdoors, perhaps on a lake. +``` + #### Async API with streaming response ```java @SuppressWarnings("ALL") public class Main { - public static void main(String[] args) { - String host = "http://localhost:11434/"; - OllamaAPI ollamaAPI = new OllamaAPI(host); - String prompt = "List all cricket world cup teams of 2019."; - OllamaAsyncResultCallback callback = ollamaAPI.askAsync(OllamaModelType.LLAMA2, prompt); - while (!callback.isComplete() || !callback.getStream().isEmpty()) { - // poll for data from the response stream - String response = callback.getStream().poll(); - if (response != null) { - System.out.print(response); - } - Thread.sleep(1000); - } + public static void main(String[] args) { + String host = "http://localhost:11434/"; + OllamaAPI ollamaAPI = new OllamaAPI(host); + + String prompt = "List all cricket world cup teams of 2019."; + OllamaAsyncResultCallback callback = ollamaAPI.askAsync(OllamaModelType.LLAMA2, prompt); + while (!callback.isComplete() || !callback.getStream().isEmpty()) { + // poll for data from the response stream + String response = callback.getStream().poll(); + if (response != null) { + System.out.print(response); + } + Thread.sleep(1000); } + } } ``` @@ -369,14 +439,16 @@ make it ### Areas of improvement -- [x] Use Java-naming conventions for attributes in the request/response models instead of the snake-case conventions. ( +- [x] Use Java-naming conventions for attributes in the request/response models instead of the + snake-case conventions. ( possibly with Jackson-mapper's `@JsonProperty`) - [x] Fix deprecated HTTP client code - [ ] Add additional params for `ask` APIs such as: - `options`: additional model parameters for the Modelfile such as `temperature` - `system`: system prompt to (overrides what is defined in the Modelfile) - `template`: the full prompt or prompt template (overrides what is defined in the Modelfile) - - `context`: the context parameter returned from a previous request, which can be used to keep a short + - `context`: the context parameter returned from a previous request, which can be used to keep a + short conversational memory - `stream`: Add support for streaming responses from the model - [x] Setup logging @@ -386,9 +458,11 @@ make it ### Get Involved -Contributions are most welcome! Whether it's reporting a bug, proposing an enhancement, or helping with code - any sort +Contributions are most welcome! Whether it's reporting a bug, proposing an enhancement, or helping +with code - any sort of contribution is much appreciated. ### Credits -The nomenclature and the icon have been adopted from the incredible [Ollama](https://ollama.ai/) project. +The nomenclature and the icon have been adopted from the incredible [Ollama](https://ollama.ai/) +project. diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java index dc5dc65..83727f2 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java @@ -4,28 +4,37 @@ import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; import io.github.amithkoujalgi.ollama4j.core.models.*; import io.github.amithkoujalgi.ollama4j.core.utils.Utils; import java.io.BufferedReader; +import java.io.ByteArrayOutputStream; +import java.io.File; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.net.URI; import java.net.URISyntaxException; +import java.net.URL; import java.net.http.HttpClient; import java.net.http.HttpRequest; import java.net.http.HttpResponse; import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; import java.time.Duration; +import java.util.ArrayList; +import java.util.Base64; import java.util.List; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** The base Ollama API class. */ +/** + * The base Ollama API class. + */ @SuppressWarnings("DuplicatedCode") public class OllamaAPI { private static final Logger logger = LoggerFactory.getLogger(OllamaAPI.class); private final String host; - private final long requestTimeoutSeconds = 3; - private boolean verbose = false; + private long requestTimeoutSeconds = 3; + private boolean verbose = true; /** * Instantiates the Ollama API. @@ -40,6 +49,10 @@ public class OllamaAPI { } } + public void setRequestTimeoutSeconds(long requestTimeoutSeconds) { + this.requestTimeoutSeconds = requestTimeoutSeconds; + } + /** * Set/unset logging of responses * @@ -58,21 +71,15 @@ public class OllamaAPI { throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { String url = this.host + "/api/tags"; HttpClient httpClient = HttpClient.newHttpClient(); - HttpRequest httpRequest = - HttpRequest.newBuilder() - .uri(new URI(url)) - .header("Accept", "application/json") - .header("Content-type", "application/json") - .timeout(Duration.ofSeconds(requestTimeoutSeconds)) - .GET() - .build(); - HttpResponse response = - httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString()); + HttpRequest httpRequest = HttpRequest.newBuilder().uri(new URI(url)) + .header("Accept", "application/json").header("Content-type", "application/json") + .timeout(Duration.ofSeconds(requestTimeoutSeconds)).GET().build(); + HttpResponse response = httpClient.send(httpRequest, + HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); String responseString = response.body(); if (statusCode == 200) { - return Utils.getObjectMapper() - .readValue(responseString, ListModelsResponse.class) + return Utils.getObjectMapper().readValue(responseString, ListModelsResponse.class) .getModels(); } else { throw new OllamaBaseException(statusCode + " - " + responseString); @@ -89,26 +96,22 @@ public class OllamaAPI { throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { String url = this.host + "/api/pull"; String jsonData = String.format("{\"name\": \"%s\"}", model); - HttpRequest request = - HttpRequest.newBuilder() - .uri(new URI(url)) - .POST(HttpRequest.BodyPublishers.ofString(jsonData)) - .header("Accept", "application/json") - .header("Content-type", "application/json") - .timeout(Duration.ofSeconds(requestTimeoutSeconds)) - .build(); + HttpRequest request = HttpRequest.newBuilder().uri(new URI(url)) + .POST(HttpRequest.BodyPublishers.ofString(jsonData)).header("Accept", "application/json") + .header("Content-type", "application/json") + .timeout(Duration.ofSeconds(requestTimeoutSeconds)).build(); HttpClient client = HttpClient.newHttpClient(); - HttpResponse response = - client.send(request, HttpResponse.BodyHandlers.ofInputStream()); + HttpResponse response = client.send(request, + HttpResponse.BodyHandlers.ofInputStream()); int statusCode = response.statusCode(); InputStream responseBodyStream = response.body(); String responseString = ""; - try (BufferedReader reader = - new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { + try (BufferedReader reader = new BufferedReader( + new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { String line; while ((line = reader.readLine()) != null) { - ModelPullResponse modelPullResponse = - Utils.getObjectMapper().readValue(line, ModelPullResponse.class); + ModelPullResponse modelPullResponse = Utils.getObjectMapper() + .readValue(line, ModelPullResponse.class); if (verbose) { logger.info(modelPullResponse.getStatus()); } @@ -129,14 +132,10 @@ public class OllamaAPI { throws IOException, OllamaBaseException, InterruptedException { String url = this.host + "/api/show"; String jsonData = String.format("{\"name\": \"%s\"}", modelName); - HttpRequest request = - HttpRequest.newBuilder() - .uri(URI.create(url)) - .header("Accept", "application/json") - .header("Content-type", "application/json") - .timeout(Duration.ofSeconds(requestTimeoutSeconds)) - .POST(HttpRequest.BodyPublishers.ofString(jsonData)) - .build(); + HttpRequest request = HttpRequest.newBuilder().uri(URI.create(url)) + .header("Accept", "application/json").header("Content-type", "application/json") + .timeout(Duration.ofSeconds(requestTimeoutSeconds)) + .POST(HttpRequest.BodyPublishers.ofString(jsonData)).build(); HttpClient client = HttpClient.newHttpClient(); HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); @@ -152,22 +151,18 @@ public class OllamaAPI { * Create a custom model from a model file. Read more about custom model file creation here. * - * @param modelName the name of the custom model to be created. + * @param modelName the name of the custom model to be created. * @param modelFilePath the path to model file that exists on the Ollama server. */ public void createModel(String modelName, String modelFilePath) throws IOException, InterruptedException, OllamaBaseException { String url = this.host + "/api/create"; - String jsonData = - String.format("{\"name\": \"%s\", \"path\": \"%s\"}", modelName, modelFilePath); - HttpRequest request = - HttpRequest.newBuilder() - .uri(URI.create(url)) - .header("Accept", "application/json") - .header("Content-Type", "application/json") - .timeout(Duration.ofSeconds(requestTimeoutSeconds)) - .POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)) - .build(); + String jsonData = String.format("{\"name\": \"%s\", \"path\": \"%s\"}", modelName, + modelFilePath); + HttpRequest request = HttpRequest.newBuilder().uri(URI.create(url)) + .header("Accept", "application/json").header("Content-Type", "application/json") + .timeout(Duration.ofSeconds(requestTimeoutSeconds)) + .POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build(); HttpClient client = HttpClient.newHttpClient(); HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); @@ -188,22 +183,18 @@ public class OllamaAPI { /** * Delete a model from Ollama server. * - * @param name the name of the model to be deleted. + * @param name the name of the model to be deleted. * @param ignoreIfNotPresent - ignore errors if the specified model is not present on Ollama - * server. + * server. */ public void deleteModel(String name, boolean ignoreIfNotPresent) throws IOException, InterruptedException, OllamaBaseException { String url = this.host + "/api/delete"; String jsonData = String.format("{\"name\": \"%s\"}", name); - HttpRequest request = - HttpRequest.newBuilder() - .uri(URI.create(url)) - .method("DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)) - .header("Accept", "application/json") - .header("Content-type", "application/json") - .timeout(Duration.ofSeconds(requestTimeoutSeconds)) - .build(); + HttpRequest request = HttpRequest.newBuilder().uri(URI.create(url)) + .method("DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)) + .header("Accept", "application/json").header("Content-type", "application/json") + .timeout(Duration.ofSeconds(requestTimeoutSeconds)).build(); HttpClient client = HttpClient.newHttpClient(); HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); @@ -219,7 +210,7 @@ public class OllamaAPI { /** * Generate embeddings for a given text from a model * - * @param model name of model to generate embeddings from + * @param model name of model to generate embeddings from * @param prompt text to generate embeddings for * @return embeddings */ @@ -228,20 +219,16 @@ public class OllamaAPI { String url = this.host + "/api/embeddings"; String jsonData = String.format("{\"model\": \"%s\", \"prompt\": \"%s\"}", model, prompt); HttpClient httpClient = HttpClient.newHttpClient(); - HttpRequest request = - HttpRequest.newBuilder() - .uri(URI.create(url)) - .header("Accept", "application/json") - .header("Content-type", "application/json") - .timeout(Duration.ofSeconds(requestTimeoutSeconds)) - .POST(HttpRequest.BodyPublishers.ofString(jsonData)) - .build(); + HttpRequest request = HttpRequest.newBuilder().uri(URI.create(url)) + .header("Accept", "application/json").header("Content-type", "application/json") + .timeout(Duration.ofSeconds(requestTimeoutSeconds)) + .POST(HttpRequest.BodyPublishers.ofString(jsonData)).build(); HttpResponse response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); String responseBody = response.body(); if (statusCode == 200) { - EmbeddingResponse embeddingResponse = - Utils.getObjectMapper().readValue(responseBody, EmbeddingResponse.class); + EmbeddingResponse embeddingResponse = Utils.getObjectMapper() + .readValue(responseBody, EmbeddingResponse.class); return embeddingResponse.getEmbedding(); } else { throw new OllamaBaseException(statusCode + " - " + responseBody); @@ -251,40 +238,119 @@ public class OllamaAPI { /** * Ask a question to a model running on Ollama server. This is a sync/blocking call. * - * @param model the ollama model to ask the question to + * @param model the ollama model to ask the question to * @param promptText the prompt/question text * @return OllamaResult - that includes response text and time taken for response */ public OllamaResult ask(String model, String promptText) throws OllamaBaseException, IOException, InterruptedException { OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, promptText); + return askSync(ollamaRequestModel); + } + + /** + * With one or more image files, ask a question to a model running on Ollama server. This is a + * sync/blocking call. + * + * @param model the ollama model to ask the question to + * @param promptText the prompt/question text + * @param imageFiles the list of image files to use for the question + * @return OllamaResult - that includes response text and time taken for response + */ + public OllamaResult askWithImages(String model, String promptText, List imageFiles) + throws OllamaBaseException, IOException, InterruptedException { + List images = new ArrayList<>(); + for (File imageFile : imageFiles) { + images.add(encodeFileToBase64(imageFile)); + } + OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, promptText, images); + return askSync(ollamaRequestModel); + } + + /** + * With one or more image URLs, ask a question to a model running on Ollama server. This is a + * sync/blocking call. + * + * @param model the ollama model to ask the question to + * @param promptText the prompt/question text + * @param imageURLs the list of image URLs to use for the question + * @return OllamaResult - that includes response text and time taken for response + */ + public OllamaResult askWithImageURLs(String model, String promptText, List imageURLs) + throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { + List images = new ArrayList<>(); + for (String imageURL : imageURLs) { + images.add(encodeByteArrayToBase64(loadImageBytesFromUrl(imageURL))); + } + OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, promptText, images); + return askSync(ollamaRequestModel); + } + + public static String encodeFileToBase64(File file) throws IOException { + return Base64.getEncoder().encodeToString(Files.readAllBytes(file.toPath())); + } + + public static String encodeByteArrayToBase64(byte[] bytes) { + return Base64.getEncoder().encodeToString(bytes); + } + + public static byte[] loadImageBytesFromUrl(String imageUrl) + throws IOException, URISyntaxException { + URL url = new URI(imageUrl).toURL(); + try (InputStream in = url.openStream(); ByteArrayOutputStream out = new ByteArrayOutputStream()) { + byte[] buffer = new byte[1024]; + int bytesRead; + while ((bytesRead = in.read(buffer)) != -1) { + out.write(buffer, 0, bytesRead); + } + return out.toByteArray(); + } + } + + /** + * Ask a question to a model running on Ollama server and get a callback handle that can be used + * to check for status and get the response from the model later. This would be an + * async/non-blocking call. + * + * @param model the ollama model to ask the question to + * @param promptText the prompt/question text + * @return the ollama async result callback handle + */ + public OllamaAsyncResultCallback askAsync(String model, String promptText) { + OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, promptText); + HttpClient httpClient = HttpClient.newHttpClient(); + URI uri = URI.create(this.host + "/api/generate"); + OllamaAsyncResultCallback ollamaAsyncResultCallback = new OllamaAsyncResultCallback(httpClient, + uri, ollamaRequestModel, requestTimeoutSeconds); + ollamaAsyncResultCallback.start(); + return ollamaAsyncResultCallback; + } + + private OllamaResult askSync(OllamaRequestModel ollamaRequestModel) + throws OllamaBaseException, IOException, InterruptedException { long startTime = System.currentTimeMillis(); HttpClient httpClient = HttpClient.newHttpClient(); URI uri = URI.create(this.host + "/api/generate"); - HttpRequest request = - HttpRequest.newBuilder(uri) - .POST( - HttpRequest.BodyPublishers.ofString( - Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))) - .header("Content-Type", "application/json") - .timeout(Duration.ofSeconds(requestTimeoutSeconds)) - .build(); - HttpResponse response = - httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream()); + HttpRequest request = HttpRequest.newBuilder(uri).POST(HttpRequest.BodyPublishers.ofString( + Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))) + .header("Content-Type", "application/json") + .timeout(Duration.ofSeconds(requestTimeoutSeconds)).build(); + HttpResponse response = httpClient.send(request, + HttpResponse.BodyHandlers.ofInputStream()); int statusCode = response.statusCode(); InputStream responseBodyStream = response.body(); StringBuilder responseBuffer = new StringBuilder(); - try (BufferedReader reader = - new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { + try (BufferedReader reader = new BufferedReader( + new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { String line; while ((line = reader.readLine()) != null) { if (statusCode == 404) { - OllamaErrorResponseModel ollamaResponseModel = - Utils.getObjectMapper().readValue(line, OllamaErrorResponseModel.class); + OllamaErrorResponseModel ollamaResponseModel = Utils.getObjectMapper() + .readValue(line, OllamaErrorResponseModel.class); responseBuffer.append(ollamaResponseModel.getError()); } else { - OllamaResponseModel ollamaResponseModel = - Utils.getObjectMapper().readValue(line, OllamaResponseModel.class); + OllamaResponseModel ollamaResponseModel = Utils.getObjectMapper() + .readValue(line, OllamaResponseModel.class); if (!ollamaResponseModel.isDone()) { responseBuffer.append(ollamaResponseModel.getResponse()); } @@ -298,23 +364,4 @@ public class OllamaAPI { return new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode); } } - - /** - * Ask a question to a model running on Ollama server and get a callback handle that can be used - * to check for status and get the response from the model later. This would be an - * async/non-blocking call. - * - * @param model the ollama model to ask the question to - * @param promptText the prompt/question text - * @return the ollama async result callback handle - */ - public OllamaAsyncResultCallback askAsync(String model, String promptText) { - OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, promptText); - HttpClient httpClient = HttpClient.newHttpClient(); - URI uri = URI.create(this.host + "/api/generate"); - OllamaAsyncResultCallback ollamaAsyncResultCallback = - new OllamaAsyncResultCallback(httpClient, uri, ollamaRequestModel, requestTimeoutSeconds); - ollamaAsyncResultCallback.start(); - return ollamaAsyncResultCallback; - } } diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaRequestModel.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaRequestModel.java index 40f4655..043f341 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaRequestModel.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaRequestModel.java @@ -1,15 +1,35 @@ package io.github.amithkoujalgi.ollama4j.core.models; +import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper; + +import com.fasterxml.jackson.core.JsonProcessingException; +import java.util.List; import lombok.Data; @Data public class OllamaRequestModel { - private String model; - private String prompt; - public OllamaRequestModel(String model, String prompt) { - this.model = model; - this.prompt = prompt; + private String model; + private String prompt; + private List images; + + public OllamaRequestModel(String model, String prompt) { + this.model = model; + this.prompt = prompt; + } + + public OllamaRequestModel(String model, String prompt, List images) { + this.model = model; + this.prompt = prompt; + this.images = images; + } + + public String toString() { + try { + return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); } + } }