From 7da4a7ffd42333a0d40da76121ad37e5cca1be4a Mon Sep 17 00:00:00 2001 From: Amith Koujalgi Date: Mon, 13 Nov 2023 11:32:53 +0530 Subject: [PATCH] - updated Java version to 11. - replaced Apache HTTP client code with `Java.net.http.HttpClient` --- .github/workflows/maven-publish.yml | 60 +-- .github/workflows/publish-javadoc.yml | 7 +- README.md | 4 +- pom.xml | 20 +- .../ollama4j/core/OllamaAPI.java | 350 +++++++----------- .../core/models/EmbeddingResponse.java | 1 + .../ollama4j/core/models/Model.java | 8 +- .../core/models/ModelPullResponse.java | 44 +++ .../models/OllamaAsyncResultCallback.java | 62 ++-- .../ollama4j/core/types/OllamaModelType.java | 16 +- .../ollama4j/TestMockedAPIs.java | 4 +- src/test/resources/logback.xml | 2 +- 12 files changed, 269 insertions(+), 309 deletions(-) create mode 100644 src/main/java/io/github/amithkoujalgi/ollama4j/core/models/ModelPullResponse.java diff --git a/.github/workflows/maven-publish.yml b/.github/workflows/maven-publish.yml index efdba57..781f444 100644 --- a/.github/workflows/maven-publish.yml +++ b/.github/workflows/maven-publish.yml @@ -9,7 +9,7 @@ name: Test and Publish Package on: push: - branches: ["main"] + branches: [ "main" ] workflow_dispatch: jobs: @@ -21,36 +21,36 @@ jobs: packages: write steps: - - uses: actions/checkout@v3 - - name: Set up JDK 8 - uses: actions/setup-java@v3 - with: - java-version: '8' - distribution: 'temurin' - server-id: github # Value of the distributionManagement/repository/id field of the pom.xml - settings-path: ${{ github.workspace }} # location for the settings.xml file + - uses: actions/checkout@v3 + - name: Set up JDK 8 + uses: actions/setup-java@v3 + with: + java-version: '11' + distribution: 'adopt-hotspot' + server-id: github # Value of the distributionManagement/repository/id field of the pom.xml + settings-path: ${{ github.workspace }} # location for the settings.xml file - - name: Build with Maven - run: mvn -U -B clean package --file pom.xml + - name: Build with Maven + run: mvn -U -B clean package --file pom.xml - - name: Run Tests - run: mvn -U clean verify --file pom.xml + - name: Run Tests + run: mvn -U clean verify --file pom.xml - - name: Set up Apache Maven Central (Overwrite settings.xml) - uses: actions/setup-java@v3 - with: # running setup-java again overwrites the settings.xml - java-version: 8 - distribution: 'temurin' - cache: 'maven' - server-id: ossrh - server-username: MAVEN_USERNAME - server-password: MAVEN_PASSWORD - gpg-private-key: ${{ secrets.GPG_PRIVATE_KEY }} - gpg-passphrase: MAVEN_GPG_PASSPHRASE + - name: Set up Apache Maven Central (Overwrite settings.xml) + uses: actions/setup-java@v3 + with: # running setup-java again overwrites the settings.xml + java-version: '11' + distribution: 'adopt-hotspot' + cache: 'maven' + server-id: ossrh + server-username: MAVEN_USERNAME + server-password: MAVEN_PASSWORD + gpg-private-key: ${{ secrets.GPG_PRIVATE_KEY }} + gpg-passphrase: MAVEN_GPG_PASSPHRASE - - name: Publish to GitHub Packages Apache Maven - run: mvn clean deploy -Dgpg.passphrase="${{ secrets.GPG_PASSPHRASE }}" - env: - MAVEN_USERNAME: ${{ secrets.OSSRH_USERNAME }} - MAVEN_PASSWORD: ${{ secrets.OSSRH_PASSWORD }} - MAVEN_GPG_PASSPHRASE: ${{ secrets.GPG_PASSPHRASE }} \ No newline at end of file + - name: Publish to GitHub Packages Apache Maven + run: mvn clean deploy -Dgpg.passphrase="${{ secrets.GPG_PASSPHRASE }}" + env: + MAVEN_USERNAME: ${{ secrets.OSSRH_USERNAME }} + MAVEN_PASSWORD: ${{ secrets.OSSRH_PASSWORD }} + MAVEN_GPG_PASSPHRASE: ${{ secrets.GPG_PASSPHRASE }} \ No newline at end of file diff --git a/.github/workflows/publish-javadoc.yml b/.github/workflows/publish-javadoc.yml index 3f3678e..9878a24 100644 --- a/.github/workflows/publish-javadoc.yml +++ b/.github/workflows/publish-javadoc.yml @@ -34,15 +34,12 @@ jobs: - name: Set up JDK 8 uses: actions/setup-java@v3 with: - java-version: '8' - distribution: 'temurin' + java-version: '11' + distribution: 'adopt-hotspot' server-id: github # Value of the distributionManagement/repository/id field of the pom.xml settings-path: ${{ github.workspace }} # location for the settings.xml file - name: Build with Maven run: mvn -U -B clean package --file pom.xml - - # - name: Checkout - # uses: actions/checkout@v3 - name: Setup Pages uses: actions/configure-pages@v3 - name: Upload artifact diff --git a/README.md b/README.md index 3a27473..bfb3aa5 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ A Java library (wrapper/binding) for [Ollama](https://github.com/jmorganca/ollam #### Requirements - Ollama (Either [natively](https://ollama.ai/download) setup or via [Docker](https://hub.docker.com/r/ollama/ollama)) -- Java 8 or above +- Java 11 or above #### Installation @@ -322,7 +322,7 @@ Find the full `Javadoc` (API specifications) [here](https://amithkoujalgi.github - [x] Use Java-naming conventions for attributes in the request/response models instead of the snake-case conventions. ( possibly with Jackson-mapper's `@JsonProperty`) -- [ ] Fix deprecated HTTP client code +- [x] Fix deprecated HTTP client code - [ ] Add additional params for `ask` APIs such as: - `options`: additional model parameters for the Modelfile such as `temperature` - `system`: system prompt to (overrides what is defined in the Modelfile) diff --git a/pom.xml b/pom.xml index a8b3b50..f436728 100644 --- a/pom.xml +++ b/pom.xml @@ -1,6 +1,6 @@ - 4.0.0 @@ -9,8 +9,8 @@ 1.0-SNAPSHOT - 8 - 8 + 11 + 11 UTF-8 @@ -19,7 +19,7 @@ Amith Koujalgi koujalgi.amith@gmail.com Sonatype - http://www.sonatype.com + https://www.sonatype.com @@ -99,7 +99,6 @@ - com.fasterxml.jackson.core @@ -113,9 +112,9 @@ test - org.apache.httpcomponents.client5 - httpclient5 - 5.2.1 + org.slf4j + slf4j-api + 2.0.9 org.junit.jupiter @@ -129,11 +128,8 @@ 4.1.0 test - - - ossrh diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java index 98bb327..afc0511 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java @@ -3,33 +3,25 @@ package io.github.amithkoujalgi.ollama4j.core; import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; import io.github.amithkoujalgi.ollama4j.core.models.*; import io.github.amithkoujalgi.ollama4j.core.utils.Utils; -import org.apache.hc.client5.http.classic.methods.HttpDelete; -import org.apache.hc.client5.http.classic.methods.HttpGet; -import org.apache.hc.client5.http.classic.methods.HttpPost; -import org.apache.hc.client5.http.impl.classic.CloseableHttpClient; -import org.apache.hc.client5.http.impl.classic.CloseableHttpResponse; -import org.apache.hc.client5.http.impl.classic.HttpClients; -import org.apache.hc.core5.http.HttpEntity; -import org.apache.hc.core5.http.ParseException; -import org.apache.hc.core5.http.io.entity.EntityUtils; -import org.apache.hc.core5.http.io.entity.StringEntity; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.BufferedReader; import java.io.IOException; +import java.io.InputStream; import java.io.InputStreamReader; -import java.io.OutputStream; import java.net.HttpURLConnection; -import java.net.URL; +import java.net.URI; +import java.net.URISyntaxException; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; import java.nio.charset.StandardCharsets; import java.util.List; -import java.util.stream.Collectors; /** * The base Ollama API class. */ -@SuppressWarnings({"DuplicatedCode", "ExtractMethodRecommender"}) public class OllamaAPI { private static final Logger logger = LoggerFactory.getLogger(OllamaAPI.class); @@ -61,28 +53,47 @@ public class OllamaAPI { * List available models from Ollama server. * * @return the list - * @throws IOException - * @throws OllamaBaseException - * @throws ParseException */ - public List listModels() throws IOException, OllamaBaseException, ParseException { + public List listModels() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { String url = this.host + "/api/tags"; - final HttpGet httpGet = new HttpGet(url); - httpGet.setHeader("Accept", "application/json"); - httpGet.setHeader("Content-type", "application/json"); - try (CloseableHttpClient client = HttpClients.createDefault(); CloseableHttpResponse response = client.execute(httpGet)) { - final int statusCode = response.getCode(); - HttpEntity responseEntity = response.getEntity(); - String responseString = ""; - if (responseEntity != null) { - responseString = EntityUtils.toString(responseEntity, "UTF-8"); - } - if (statusCode == 200) { - return Utils.getObjectMapper().readValue(responseString, ListModelsResponse.class).getModels(); - } else { - throw new OllamaBaseException(statusCode + " - " + responseString); + HttpClient httpClient = HttpClient.newHttpClient(); + HttpRequest httpRequest = HttpRequest.newBuilder().uri(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build(); + HttpResponse response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString()); + int statusCode = response.statusCode(); + String responseString = response.body(); + if (statusCode == 200) { + return Utils.getObjectMapper().readValue(responseString, ListModelsResponse.class).getModels(); + } else { + throw new OllamaBaseException(statusCode + " - " + responseString); + } + } + + /** + * Pull a model on the Ollama server from the list of available models. + * + * @param model the name of the model + */ + public void pullModel(String model) throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + String url = this.host + "/api/pull"; + String jsonData = String.format("{\"name\": \"%s\"}", model); + HttpRequest request = HttpRequest.newBuilder().uri(new URI(url)).POST(HttpRequest.BodyPublishers.ofString(jsonData)).header("Accept", "application/json").header("Content-type", "application/json").build(); + HttpClient client = HttpClient.newHttpClient(); + HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofInputStream()); + int statusCode = response.statusCode(); + InputStream responseBodyStream = response.body(); + String responseString = ""; + try (BufferedReader reader = new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { + String line; + while ((line = reader.readLine()) != null) { + ModelPullResponse modelPullResponse = Utils.getObjectMapper().readValue(line, ModelPullResponse.class); + if (verbose) { + logger.info(modelPullResponse.getStatus()); + } } } + if (statusCode != 200) { + throw new OllamaBaseException(statusCode + " - " + responseString); + } } /** @@ -90,63 +101,23 @@ public class OllamaAPI { * * @param modelName the model * @return the model details - * @throws IOException - * @throws OllamaBaseException - * @throws ParseException */ - public ModelDetail getModelDetails(String modelName) throws IOException, OllamaBaseException, ParseException { + public ModelDetail getModelDetails(String modelName) throws IOException, OllamaBaseException, InterruptedException { String url = this.host + "/api/show"; String jsonData = String.format("{\"name\": \"%s\"}", modelName); - final HttpPost httpPost = new HttpPost(url); - final StringEntity entity = new StringEntity(jsonData); - httpPost.setEntity(entity); - httpPost.setHeader("Accept", "application/json"); - httpPost.setHeader("Content-type", "application/json"); - try (CloseableHttpClient client = HttpClients.createDefault(); CloseableHttpResponse response = client.execute(httpPost)) { - final int statusCode = response.getCode(); - HttpEntity responseEntity = response.getEntity(); - String responseString = ""; - if (responseEntity != null) { - responseString = EntityUtils.toString(responseEntity, "UTF-8"); - } - if (statusCode == 200) { - return Utils.getObjectMapper().readValue(responseString, ModelDetail.class); - } else { - throw new OllamaBaseException(statusCode + " - " + responseString); - } - } - } - /** - * Pull a model on the Ollama server from the list of available models. - * - * @param model the name of the model - * @throws IOException - * @throws ParseException - * @throws OllamaBaseException - */ - public void pullModel(String model) throws IOException, ParseException, OllamaBaseException { - List models = listModels().stream().filter(m -> m.getModelName().split(":")[0].equals(model)).collect(Collectors.toList()); - if (!models.isEmpty()) { - return; - } - String url = this.host + "/api/pull"; - String jsonData = String.format("{\"name\": \"%s\"}", model); - final HttpPost httpPost = new HttpPost(url); - final StringEntity entity = new StringEntity(jsonData); - httpPost.setEntity(entity); - httpPost.setHeader("Accept", "application/json"); - httpPost.setHeader("Content-type", "application/json"); - try (CloseableHttpClient client = HttpClients.createDefault(); CloseableHttpResponse response = client.execute(httpPost)) { - final int statusCode = response.getCode(); - HttpEntity responseEntity = response.getEntity(); - String responseString = ""; - if (responseEntity != null) { - responseString = EntityUtils.toString(responseEntity, "UTF-8"); - } - if (statusCode != 200) { - throw new OllamaBaseException(statusCode + " - " + responseString); - } + HttpRequest request = HttpRequest.newBuilder().uri(URI.create(url)).header("Accept", "application/json").header("Content-type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build(); + + HttpClient client = HttpClient.newHttpClient(); + HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); + + int statusCode = response.statusCode(); + String responseBody = response.body(); + + if (statusCode == 200) { + return Utils.getObjectMapper().readValue(responseBody, ModelDetail.class); + } else { + throw new OllamaBaseException(statusCode + " - " + responseBody); } } @@ -156,35 +127,24 @@ public class OllamaAPI { * * @param modelName the name of the custom model to be created. * @param modelFilePath the path to model file that exists on the Ollama server. - * @throws IOException - * @throws ParseException - * @throws OllamaBaseException */ - public void createModel(String modelName, String modelFilePath) throws IOException, ParseException, OllamaBaseException { + public void createModel(String modelName, String modelFilePath) throws IOException, InterruptedException, OllamaBaseException { String url = this.host + "/api/create"; String jsonData = String.format("{\"name\": \"%s\", \"path\": \"%s\"}", modelName, modelFilePath); - final HttpPost httpPost = new HttpPost(url); - final StringEntity entity = new StringEntity(jsonData); - httpPost.setEntity(entity); - httpPost.setHeader("Accept", "application/json"); - httpPost.setHeader("Content-type", "application/json"); - try (CloseableHttpClient client = HttpClients.createDefault(); CloseableHttpResponse response = client.execute(httpPost)) { - final int statusCode = response.getCode(); - HttpEntity responseEntity = response.getEntity(); - String responseString = ""; - if (responseEntity != null) { - responseString = EntityUtils.toString(responseEntity, "UTF-8"); - // FIXME: Ollama API returns HTTP status code 200 for model creation failure cases. Correct this if the issue is fixed in the Ollama API server. - if (responseString.contains("error")) { - throw new OllamaBaseException(responseString); - } - if (verbose) { - logger.info(responseString); - } - } - if (statusCode != 200) { - throw new OllamaBaseException(statusCode + " - " + responseString); - } + HttpRequest request = HttpRequest.newBuilder().uri(URI.create(url)).header("Accept", "application/json").header("Content-Type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build(); + HttpClient client = HttpClient.newHttpClient(); + HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); + int statusCode = response.statusCode(); + String responseString = response.body(); + if (statusCode != 200) { + throw new OllamaBaseException(statusCode + " - " + responseString); + } + // FIXME: Ollama API returns HTTP status code 200 for model creation failure cases. Correct this if the issue is fixed in the Ollama API server. + if (responseString.contains("error")) { + throw new OllamaBaseException(responseString); + } + if (verbose) { + logger.info(responseString); } } @@ -193,100 +153,21 @@ public class OllamaAPI { * * @param name the name of the model to be deleted. * @param ignoreIfNotPresent - ignore errors if the specified model is not present on Ollama server. - * @throws IOException - * @throws ParseException - * @throws OllamaBaseException */ - public void deleteModel(String name, boolean ignoreIfNotPresent) throws IOException, ParseException, OllamaBaseException { + public void deleteModel(String name, boolean ignoreIfNotPresent) throws IOException, InterruptedException, OllamaBaseException { String url = this.host + "/api/delete"; String jsonData = String.format("{\"name\": \"%s\"}", name); - final HttpDelete httpDelete = new HttpDelete(url); - final StringEntity entity = new StringEntity(jsonData); - httpDelete.setEntity(entity); - httpDelete.setHeader("Accept", "application/json"); - httpDelete.setHeader("Content-type", "application/json"); - try (CloseableHttpClient client = HttpClients.createDefault(); CloseableHttpResponse response = client.execute(httpDelete)) { - final int statusCode = response.getCode(); - HttpEntity responseEntity = response.getEntity(); - String responseString = ""; - if (responseEntity != null) { - responseString = EntityUtils.toString(responseEntity, "UTF-8"); - if (verbose) { - logger.info(responseString); - } - } - if (statusCode == 404 && responseString.contains("model") && responseString.contains("not found")) { - return; - } - if (statusCode != 200) { - throw new OllamaBaseException(statusCode + " - " + responseString); - } + HttpRequest request = HttpRequest.newBuilder().uri(URI.create(url)).method("DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).header("Accept", "application/json").header("Content-type", "application/json").build(); + HttpClient client = HttpClient.newHttpClient(); + HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); + int statusCode = response.statusCode(); + String responseBody = response.body(); + if (statusCode == 404 && responseBody.contains("model") && responseBody.contains("not found")) { + return; } - } - - - /** - * Ask a question to a model running on Ollama server. This is a sync/blocking call. - * - * @param ollamaModelType the ollama model to ask the question to - * @param promptText the prompt/question text - * @return the response text from the model - * @throws OllamaBaseException - * @throws IOException - */ - public String ask(String ollamaModelType, String promptText) throws OllamaBaseException, IOException { - OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(ollamaModelType, promptText); - URL obj = new URL(this.host + "/api/generate"); - HttpURLConnection con = (HttpURLConnection) obj.openConnection(); - con.setRequestMethod("POST"); - con.setDoOutput(true); - con.setRequestProperty("Content-Type", "application/json"); - String jsonReq = Utils.getObjectMapper().writeValueAsString(ollamaRequestModel); - try (OutputStream out = con.getOutputStream()) { - out.write(jsonReq.getBytes(StandardCharsets.UTF_8)); + if (statusCode != 200) { + throw new OllamaBaseException(statusCode + " - " + responseBody); } - int responseCode = con.getResponseCode(); - if (responseCode == HttpURLConnection.HTTP_OK) { - try (BufferedReader in = new BufferedReader(new InputStreamReader(con.getInputStream()))) { - String inputLine; - StringBuilder response = new StringBuilder(); - while ((inputLine = in.readLine()) != null) { - OllamaResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(inputLine, OllamaResponseModel.class); - if (!ollamaResponseModel.getDone()) { - response.append(ollamaResponseModel.getResponse()); - } - } - in.close(); - return response.toString(); - } - } else { - throw new OllamaBaseException(con.getResponseCode() + " - " + con.getResponseMessage()); - } - } - - /** - * Ask a question to a model running on Ollama server and get a callback handle that can be used to check for status and get the response from the model later. - * This would be a async/non-blocking call. - * - * @param ollamaModelType the ollama model to ask the question to - * @param promptText the prompt/question text - * @return the ollama async result callback handle - * @throws IOException - */ - public OllamaAsyncResultCallback askAsync(String ollamaModelType, String promptText) throws IOException { - OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(ollamaModelType, promptText); - URL obj = new URL(this.host + "/api/generate"); - HttpURLConnection con = (HttpURLConnection) obj.openConnection(); - con.setRequestMethod("POST"); - con.setDoOutput(true); - con.setRequestProperty("Content-Type", "application/json"); - String jsonReq = Utils.getObjectMapper().writeValueAsString(ollamaRequestModel); - try (OutputStream out = con.getOutputStream()) { - out.write(jsonReq.getBytes(StandardCharsets.UTF_8)); - } - OllamaAsyncResultCallback ollamaAsyncResultCallback = new OllamaAsyncResultCallback(con); - ollamaAsyncResultCallback.start(); - return ollamaAsyncResultCallback; } /** @@ -295,29 +176,58 @@ public class OllamaAPI { * @param model name of model to generate embeddings from * @param prompt text to generate embeddings for * @return embeddings - * @throws IOException - * @throws ParseException - * @throws OllamaBaseException */ - public List generateEmbeddings(String model, String prompt) throws IOException, ParseException, OllamaBaseException { + public List generateEmbeddings(String model, String prompt) throws IOException, InterruptedException, OllamaBaseException { String url = this.host + "/api/embeddings"; String jsonData = String.format("{\"model\": \"%s\", \"prompt\": \"%s\"}", model, prompt); - final HttpPost httpPost = new HttpPost(url); - final StringEntity entity = new StringEntity(jsonData); - httpPost.setEntity(entity); - httpPost.setHeader("Accept", "application/json"); - httpPost.setHeader("Content-type", "application/json"); - try (CloseableHttpClient client = HttpClients.createDefault(); CloseableHttpResponse response = client.execute(httpPost)) { - final int statusCode = response.getCode(); - HttpEntity responseEntity = response.getEntity(); - String responseString = ""; - if (responseEntity != null) { - responseString = EntityUtils.toString(responseEntity, "UTF-8"); - EmbeddingResponse embeddingResponse = Utils.getObjectMapper().readValue(responseString, EmbeddingResponse.class); - return embeddingResponse.getEmbedding(); - } else { - throw new OllamaBaseException(statusCode + " - " + responseString); - } + HttpClient httpClient = HttpClient.newHttpClient(); + HttpRequest request = HttpRequest.newBuilder().uri(URI.create(url)).header("Accept", "application/json").header("Content-type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build(); + HttpResponse response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); + int statusCode = response.statusCode(); + String responseBody = response.body(); + if (statusCode == 200) { + EmbeddingResponse embeddingResponse = Utils.getObjectMapper().readValue(responseBody, EmbeddingResponse.class); + return embeddingResponse.getEmbedding(); + } else { + throw new OllamaBaseException(statusCode + " - " + responseBody); } } + + /** + * Ask a question to a model running on Ollama server. This is a sync/blocking call. + * + * @param ollamaModelType the ollama model to ask the question to + * @param promptText the prompt/question text + * @return the response text from the model + */ + public String ask(String ollamaModelType, String promptText) throws OllamaBaseException, IOException, InterruptedException { + OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(ollamaModelType, promptText); + HttpClient httpClient = HttpClient.newHttpClient(); + URI uri = URI.create(this.host + "/api/generate"); + HttpRequest request = HttpRequest.newBuilder(uri).POST(HttpRequest.BodyPublishers.ofString(Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))).header("Content-Type", "application/json").build(); + HttpResponse response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); + if (response.statusCode() == HttpURLConnection.HTTP_OK) { + OllamaResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(response.body(), OllamaResponseModel.class); + return ollamaResponseModel.getResponse(); + } else { + throw new OllamaBaseException(response.statusCode() + " - " + response.body()); + } + } + + /** + * Ask a question to a model running on Ollama server and get a callback handle that can be used to check for status and get the response from the model later. + * This would be an async/non-blocking call. + * + * @param ollamaModelType the ollama model to ask the question to + * @param promptText the prompt/question text + * @return the ollama async result callback handle + */ + public OllamaAsyncResultCallback askAsyncNew(String ollamaModelType, String promptText) { + OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(ollamaModelType, promptText); + HttpClient httpClient = HttpClient.newHttpClient(); + URI uri = URI.create(this.host + "/api/generate"); + OllamaAsyncResultCallback ollamaAsyncResultCallback = new OllamaAsyncResultCallback(httpClient, uri, ollamaRequestModel); + ollamaAsyncResultCallback.start(); + return ollamaAsyncResultCallback; + } } diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/EmbeddingResponse.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/EmbeddingResponse.java index 807688d..7a46e29 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/EmbeddingResponse.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/EmbeddingResponse.java @@ -4,6 +4,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import java.util.List; +@SuppressWarnings("unused") public class EmbeddingResponse { @JsonProperty("embedding") private List embedding; diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/Model.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/Model.java index 0ae46e8..f703859 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/Model.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/Model.java @@ -17,6 +17,10 @@ public class Model { return name; } + public void setName(String name) { + this.name = name; + } + /** * Returns the model name without its version * @return model name @@ -33,10 +37,6 @@ public class Model { return name.split(":")[1]; } - public void setName(String name) { - this.name = name; - } - public String getModifiedAt() { return modifiedAt; } diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/ModelPullResponse.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/ModelPullResponse.java new file mode 100644 index 0000000..a85ccdc --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/ModelPullResponse.java @@ -0,0 +1,44 @@ +package io.github.amithkoujalgi.ollama4j.core.models; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; + +@JsonIgnoreProperties(ignoreUnknown = true) +public class ModelPullResponse { + private String status; + + private String digest; + private Long total; + private Long completed; + + public String getStatus() { + return status; + } + + public void setStatus(String status) { + this.status = status; + } + + public String getDigest() { + return digest; + } + + public void setDigest(String digest) { + this.digest = digest; + } + + public Long getTotal() { + return total; + } + + public void setTotal(Long total) { + this.total = total; + } + + public Long getCompleted() { + return completed; + } + + public void setCompleted(Long completed) { + this.completed = completed; + } +} diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java index 529e9be..6021e3a 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java @@ -5,21 +5,30 @@ import io.github.amithkoujalgi.ollama4j.core.utils.Utils; import java.io.BufferedReader; import java.io.IOException; +import java.io.InputStream; import java.io.InputStreamReader; -import java.net.HttpURLConnection; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.nio.charset.StandardCharsets; import java.util.LinkedList; import java.util.Queue; - -@SuppressWarnings("DuplicatedCode") +@SuppressWarnings("unused") public class OllamaAsyncResultCallback extends Thread { - private final HttpURLConnection connection; + private final HttpClient client; + private final URI uri; + private final OllamaRequestModel ollamaRequestModel; + private final Queue queue = new LinkedList<>(); private String result; private boolean isDone; - private final Queue queue = new LinkedList<>(); - public OllamaAsyncResultCallback(HttpURLConnection connection) { - this.connection = connection; + + public OllamaAsyncResultCallback(HttpClient client, URI uri, OllamaRequestModel ollamaRequestModel) { + this.client = client; + this.ollamaRequestModel = ollamaRequestModel; + this.uri = uri; this.isDone = false; this.result = ""; this.queue.add(""); @@ -27,28 +36,31 @@ public class OllamaAsyncResultCallback extends Thread { @Override public void run() { - int responseCode = 0; try { - responseCode = this.connection.getResponseCode(); - if (responseCode == HttpURLConnection.HTTP_OK) { - try (BufferedReader in = new BufferedReader(new InputStreamReader(this.connection.getInputStream()))) { - String inputLine; - StringBuilder response = new StringBuilder(); - while ((inputLine = in.readLine()) != null) { - OllamaResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(inputLine, OllamaResponseModel.class); - queue.add(ollamaResponseModel.getResponse()); - if (!ollamaResponseModel.getDone()) { - response.append(ollamaResponseModel.getResponse()); - } + HttpRequest request = HttpRequest.newBuilder(uri).POST(HttpRequest.BodyPublishers.ofString(Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))).header("Content-Type", "application/json").build(); + HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofInputStream()); + int statusCode = response.statusCode(); + + InputStream responseBodyStream = response.body(); + String responseString = ""; + try (BufferedReader reader = new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { + String line; + StringBuilder responseBuffer = new StringBuilder(); + while ((line = reader.readLine()) != null) { + OllamaResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaResponseModel.class); + queue.add(ollamaResponseModel.getResponse()); + if (!ollamaResponseModel.getDone()) { + responseBuffer.append(ollamaResponseModel.getResponse()); } - in.close(); - this.isDone = true; - this.result = response.toString(); } - } else { - throw new OllamaBaseException(connection.getResponseCode() + " - " + connection.getResponseMessage()); + reader.close(); + this.isDone = true; + this.result = responseBuffer.toString(); } - } catch (IOException | OllamaBaseException e) { + if (statusCode != 200) { + throw new OllamaBaseException(statusCode + " - " + responseString); + } + } catch (IOException | InterruptedException | OllamaBaseException e) { this.isDone = true; this.result = "FAILED! " + e.getMessage(); } diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/types/OllamaModelType.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/types/OllamaModelType.java index d4dbbb2..674070a 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/types/OllamaModelType.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/types/OllamaModelType.java @@ -1,12 +1,12 @@ package io.github.amithkoujalgi.ollama4j.core.types; public class OllamaModelType { - public static String LLAMA2 = "llama2"; - public static String MISTRAL = "mistral"; - public static String MEDLLAMA2 = "medllama2"; - public static String CODELLAMA = "codellama"; - public static String VICUNA = "vicuna"; - public static String ORCAMINI = "orca-mini"; - public static String SQLCODER = "sqlcoder"; - public static String WIZARDMATH = "wizard-math"; + public static final String LLAMA2 = "llama2"; + public static final String MISTRAL = "mistral"; + public static final String MEDLLAMA2 = "medllama2"; + public static final String CODELLAMA = "codellama"; + public static final String VICUNA = "vicuna"; + public static final String ORCAMINI = "orca-mini"; + public static final String SQLCODER = "sqlcoder"; + public static final String WIZARDMATH = "wizard-math"; } diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/TestMockedAPIs.java b/src/test/java/io/github/amithkoujalgi/ollama4j/TestMockedAPIs.java index 9d34bd4..9f40bab 100644 --- a/src/test/java/io/github/amithkoujalgi/ollama4j/TestMockedAPIs.java +++ b/src/test/java/io/github/amithkoujalgi/ollama4j/TestMockedAPIs.java @@ -3,11 +3,11 @@ package io.github.amithkoujalgi.ollama4j; import io.github.amithkoujalgi.ollama4j.core.OllamaAPI; import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; import io.github.amithkoujalgi.ollama4j.core.types.OllamaModelType; -import org.apache.hc.core5.http.ParseException; import org.junit.jupiter.api.Test; import org.mockito.Mockito; import java.io.IOException; +import java.net.URISyntaxException; import static org.mockito.Mockito.*; @@ -20,7 +20,7 @@ public class TestMockedAPIs { doNothing().when(ollamaAPI).pullModel(model); ollamaAPI.pullModel(model); verify(ollamaAPI, times(1)).pullModel(model); - } catch (IOException | ParseException | OllamaBaseException e) { + } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { throw new RuntimeException(e); } } diff --git a/src/test/resources/logback.xml b/src/test/resources/logback.xml index 8416636..adead55 100644 --- a/src/test/resources/logback.xml +++ b/src/test/resources/logback.xml @@ -9,7 +9,7 @@ - + \ No newline at end of file