- updated Java version to 11.

- replaced Apache HTTP client code with `Java.net.http.HttpClient`
This commit is contained in:
Amith Koujalgi 2023-11-13 11:32:53 +05:30
parent d2f405dc64
commit 7da4a7ffd4
12 changed files with 269 additions and 309 deletions

View File

@ -9,7 +9,7 @@ name: Test and Publish Package
on: on:
push: push:
branches: ["main"] branches: [ "main" ]
workflow_dispatch: workflow_dispatch:
jobs: jobs:
@ -21,36 +21,36 @@ jobs:
packages: write packages: write
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- name: Set up JDK 8 - name: Set up JDK 8
uses: actions/setup-java@v3 uses: actions/setup-java@v3
with: with:
java-version: '8' java-version: '11'
distribution: 'temurin' distribution: 'adopt-hotspot'
server-id: github # Value of the distributionManagement/repository/id field of the pom.xml server-id: github # Value of the distributionManagement/repository/id field of the pom.xml
settings-path: ${{ github.workspace }} # location for the settings.xml file settings-path: ${{ github.workspace }} # location for the settings.xml file
- name: Build with Maven - name: Build with Maven
run: mvn -U -B clean package --file pom.xml run: mvn -U -B clean package --file pom.xml
- name: Run Tests - name: Run Tests
run: mvn -U clean verify --file pom.xml run: mvn -U clean verify --file pom.xml
- name: Set up Apache Maven Central (Overwrite settings.xml) - name: Set up Apache Maven Central (Overwrite settings.xml)
uses: actions/setup-java@v3 uses: actions/setup-java@v3
with: # running setup-java again overwrites the settings.xml with: # running setup-java again overwrites the settings.xml
java-version: 8 java-version: '11'
distribution: 'temurin' distribution: 'adopt-hotspot'
cache: 'maven' cache: 'maven'
server-id: ossrh server-id: ossrh
server-username: MAVEN_USERNAME server-username: MAVEN_USERNAME
server-password: MAVEN_PASSWORD server-password: MAVEN_PASSWORD
gpg-private-key: ${{ secrets.GPG_PRIVATE_KEY }} gpg-private-key: ${{ secrets.GPG_PRIVATE_KEY }}
gpg-passphrase: MAVEN_GPG_PASSPHRASE gpg-passphrase: MAVEN_GPG_PASSPHRASE
- name: Publish to GitHub Packages Apache Maven - name: Publish to GitHub Packages Apache Maven
run: mvn clean deploy -Dgpg.passphrase="${{ secrets.GPG_PASSPHRASE }}" run: mvn clean deploy -Dgpg.passphrase="${{ secrets.GPG_PASSPHRASE }}"
env: env:
MAVEN_USERNAME: ${{ secrets.OSSRH_USERNAME }} MAVEN_USERNAME: ${{ secrets.OSSRH_USERNAME }}
MAVEN_PASSWORD: ${{ secrets.OSSRH_PASSWORD }} MAVEN_PASSWORD: ${{ secrets.OSSRH_PASSWORD }}
MAVEN_GPG_PASSPHRASE: ${{ secrets.GPG_PASSPHRASE }} MAVEN_GPG_PASSPHRASE: ${{ secrets.GPG_PASSPHRASE }}

View File

@ -34,15 +34,12 @@ jobs:
- name: Set up JDK 8 - name: Set up JDK 8
uses: actions/setup-java@v3 uses: actions/setup-java@v3
with: with:
java-version: '8' java-version: '11'
distribution: 'temurin' distribution: 'adopt-hotspot'
server-id: github # Value of the distributionManagement/repository/id field of the pom.xml server-id: github # Value of the distributionManagement/repository/id field of the pom.xml
settings-path: ${{ github.workspace }} # location for the settings.xml file settings-path: ${{ github.workspace }} # location for the settings.xml file
- name: Build with Maven - name: Build with Maven
run: mvn -U -B clean package --file pom.xml run: mvn -U -B clean package --file pom.xml
# - name: Checkout
# uses: actions/checkout@v3
- name: Setup Pages - name: Setup Pages
uses: actions/configure-pages@v3 uses: actions/configure-pages@v3
- name: Upload artifact - name: Upload artifact

View File

@ -26,7 +26,7 @@ A Java library (wrapper/binding) for [Ollama](https://github.com/jmorganca/ollam
#### Requirements #### Requirements
- Ollama (Either [natively](https://ollama.ai/download) setup or via [Docker](https://hub.docker.com/r/ollama/ollama)) - Ollama (Either [natively](https://ollama.ai/download) setup or via [Docker](https://hub.docker.com/r/ollama/ollama))
- Java 8 or above - Java 11 or above
#### Installation #### Installation
@ -322,7 +322,7 @@ Find the full `Javadoc` (API specifications) [here](https://amithkoujalgi.github
- [x] Use Java-naming conventions for attributes in the request/response models instead of the snake-case conventions. ( - [x] Use Java-naming conventions for attributes in the request/response models instead of the snake-case conventions. (
possibly with Jackson-mapper's `@JsonProperty`) possibly with Jackson-mapper's `@JsonProperty`)
- [ ] Fix deprecated HTTP client code - [x] Fix deprecated HTTP client code
- [ ] Add additional params for `ask` APIs such as: - [ ] Add additional params for `ask` APIs such as:
- `options`: additional model parameters for the Modelfile such as `temperature` - `options`: additional model parameters for the Modelfile such as `temperature`
- `system`: system prompt to (overrides what is defined in the Modelfile) - `system`: system prompt to (overrides what is defined in the Modelfile)

20
pom.xml
View File

@ -1,6 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?> <?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" <project xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xmlns="http://maven.apache.org/POM/4.0.0"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion> <modelVersion>4.0.0</modelVersion>
@ -9,8 +9,8 @@
<version>1.0-SNAPSHOT</version> <version>1.0-SNAPSHOT</version>
<properties> <properties>
<maven.compiler.source>8</maven.compiler.source> <maven.compiler.source>11</maven.compiler.source>
<maven.compiler.target>8</maven.compiler.target> <maven.compiler.target>11</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
</properties> </properties>
@ -19,7 +19,7 @@
<name>Amith Koujalgi</name> <name>Amith Koujalgi</name>
<email>koujalgi.amith@gmail.com</email> <email>koujalgi.amith@gmail.com</email>
<organization>Sonatype</organization> <organization>Sonatype</organization>
<organizationUrl>http://www.sonatype.com</organizationUrl> <organizationUrl>https://www.sonatype.com</organizationUrl>
</developer> </developer>
</developers> </developers>
@ -99,7 +99,6 @@
</plugins> </plugins>
</build> </build>
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>com.fasterxml.jackson.core</groupId> <groupId>com.fasterxml.jackson.core</groupId>
@ -113,9 +112,9 @@
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.apache.httpcomponents.client5</groupId> <groupId>org.slf4j</groupId>
<artifactId>httpclient5</artifactId> <artifactId>slf4j-api</artifactId>
<version>5.2.1</version> <version>2.0.9</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.junit.jupiter</groupId> <groupId>org.junit.jupiter</groupId>
@ -129,11 +128,8 @@
<version>4.1.0</version> <version>4.1.0</version>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
</dependencies> </dependencies>
<distributionManagement> <distributionManagement>
<snapshotRepository> <snapshotRepository>
<id>ossrh</id> <id>ossrh</id>

View File

@ -3,33 +3,25 @@ package io.github.amithkoujalgi.ollama4j.core;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.*; import io.github.amithkoujalgi.ollama4j.core.models.*;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils; import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
import org.apache.hc.client5.http.classic.methods.HttpDelete;
import org.apache.hc.client5.http.classic.methods.HttpGet;
import org.apache.hc.client5.http.classic.methods.HttpPost;
import org.apache.hc.client5.http.impl.classic.CloseableHttpClient;
import org.apache.hc.client5.http.impl.classic.CloseableHttpResponse;
import org.apache.hc.client5.http.impl.classic.HttpClients;
import org.apache.hc.core5.http.HttpEntity;
import org.apache.hc.core5.http.ParseException;
import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.apache.hc.core5.http.io.entity.StringEntity;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import java.io.BufferedReader; import java.io.BufferedReader;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader; import java.io.InputStreamReader;
import java.io.OutputStream;
import java.net.HttpURLConnection; import java.net.HttpURLConnection;
import java.net.URL; import java.net.URI;
import java.net.URISyntaxException;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.util.List; import java.util.List;
import java.util.stream.Collectors;
/** /**
* The base Ollama API class. * The base Ollama API class.
*/ */
@SuppressWarnings({"DuplicatedCode", "ExtractMethodRecommender"})
public class OllamaAPI { public class OllamaAPI {
private static final Logger logger = LoggerFactory.getLogger(OllamaAPI.class); private static final Logger logger = LoggerFactory.getLogger(OllamaAPI.class);
@ -61,28 +53,47 @@ public class OllamaAPI {
* List available models from Ollama server. * List available models from Ollama server.
* *
* @return the list * @return the list
* @throws IOException
* @throws OllamaBaseException
* @throws ParseException
*/ */
public List<Model> listModels() throws IOException, OllamaBaseException, ParseException { public List<Model> listModels() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
String url = this.host + "/api/tags"; String url = this.host + "/api/tags";
final HttpGet httpGet = new HttpGet(url); HttpClient httpClient = HttpClient.newHttpClient();
httpGet.setHeader("Accept", "application/json"); HttpRequest httpRequest = HttpRequest.newBuilder().uri(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build();
httpGet.setHeader("Content-type", "application/json"); HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
try (CloseableHttpClient client = HttpClients.createDefault(); CloseableHttpResponse response = client.execute(httpGet)) { int statusCode = response.statusCode();
final int statusCode = response.getCode(); String responseString = response.body();
HttpEntity responseEntity = response.getEntity(); if (statusCode == 200) {
String responseString = ""; return Utils.getObjectMapper().readValue(responseString, ListModelsResponse.class).getModels();
if (responseEntity != null) { } else {
responseString = EntityUtils.toString(responseEntity, "UTF-8"); throw new OllamaBaseException(statusCode + " - " + responseString);
} }
if (statusCode == 200) { }
return Utils.getObjectMapper().readValue(responseString, ListModelsResponse.class).getModels();
} else { /**
throw new OllamaBaseException(statusCode + " - " + responseString); * Pull a model on the Ollama server from the list of <a href="https://ollama.ai/library">available models</a>.
*
* @param model the name of the model
*/
public void pullModel(String model) throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
String url = this.host + "/api/pull";
String jsonData = String.format("{\"name\": \"%s\"}", model);
HttpRequest request = HttpRequest.newBuilder().uri(new URI(url)).POST(HttpRequest.BodyPublishers.ofString(jsonData)).header("Accept", "application/json").header("Content-type", "application/json").build();
HttpClient client = HttpClient.newHttpClient();
HttpResponse<InputStream> response = client.send(request, HttpResponse.BodyHandlers.ofInputStream());
int statusCode = response.statusCode();
InputStream responseBodyStream = response.body();
String responseString = "";
try (BufferedReader reader = new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
String line;
while ((line = reader.readLine()) != null) {
ModelPullResponse modelPullResponse = Utils.getObjectMapper().readValue(line, ModelPullResponse.class);
if (verbose) {
logger.info(modelPullResponse.getStatus());
}
} }
} }
if (statusCode != 200) {
throw new OllamaBaseException(statusCode + " - " + responseString);
}
} }
/** /**
@ -90,63 +101,23 @@ public class OllamaAPI {
* *
* @param modelName the model * @param modelName the model
* @return the model details * @return the model details
* @throws IOException
* @throws OllamaBaseException
* @throws ParseException
*/ */
public ModelDetail getModelDetails(String modelName) throws IOException, OllamaBaseException, ParseException { public ModelDetail getModelDetails(String modelName) throws IOException, OllamaBaseException, InterruptedException {
String url = this.host + "/api/show"; String url = this.host + "/api/show";
String jsonData = String.format("{\"name\": \"%s\"}", modelName); String jsonData = String.format("{\"name\": \"%s\"}", modelName);
final HttpPost httpPost = new HttpPost(url);
final StringEntity entity = new StringEntity(jsonData);
httpPost.setEntity(entity);
httpPost.setHeader("Accept", "application/json");
httpPost.setHeader("Content-type", "application/json");
try (CloseableHttpClient client = HttpClients.createDefault(); CloseableHttpResponse response = client.execute(httpPost)) {
final int statusCode = response.getCode();
HttpEntity responseEntity = response.getEntity();
String responseString = "";
if (responseEntity != null) {
responseString = EntityUtils.toString(responseEntity, "UTF-8");
}
if (statusCode == 200) {
return Utils.getObjectMapper().readValue(responseString, ModelDetail.class);
} else {
throw new OllamaBaseException(statusCode + " - " + responseString);
}
}
}
/** HttpRequest request = HttpRequest.newBuilder().uri(URI.create(url)).header("Accept", "application/json").header("Content-type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
* Pull a model on the Ollama server from the list of <a href="https://ollama.ai/library">available models</a>.
* HttpClient client = HttpClient.newHttpClient();
* @param model the name of the model HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
* @throws IOException
* @throws ParseException int statusCode = response.statusCode();
* @throws OllamaBaseException String responseBody = response.body();
*/
public void pullModel(String model) throws IOException, ParseException, OllamaBaseException { if (statusCode == 200) {
List<Model> models = listModels().stream().filter(m -> m.getModelName().split(":")[0].equals(model)).collect(Collectors.toList()); return Utils.getObjectMapper().readValue(responseBody, ModelDetail.class);
if (!models.isEmpty()) { } else {
return; throw new OllamaBaseException(statusCode + " - " + responseBody);
}
String url = this.host + "/api/pull";
String jsonData = String.format("{\"name\": \"%s\"}", model);
final HttpPost httpPost = new HttpPost(url);
final StringEntity entity = new StringEntity(jsonData);
httpPost.setEntity(entity);
httpPost.setHeader("Accept", "application/json");
httpPost.setHeader("Content-type", "application/json");
try (CloseableHttpClient client = HttpClients.createDefault(); CloseableHttpResponse response = client.execute(httpPost)) {
final int statusCode = response.getCode();
HttpEntity responseEntity = response.getEntity();
String responseString = "";
if (responseEntity != null) {
responseString = EntityUtils.toString(responseEntity, "UTF-8");
}
if (statusCode != 200) {
throw new OllamaBaseException(statusCode + " - " + responseString);
}
} }
} }
@ -156,35 +127,24 @@ public class OllamaAPI {
* *
* @param modelName the name of the custom model to be created. * @param modelName the name of the custom model to be created.
* @param modelFilePath the path to model file that exists on the Ollama server. * @param modelFilePath the path to model file that exists on the Ollama server.
* @throws IOException
* @throws ParseException
* @throws OllamaBaseException
*/ */
public void createModel(String modelName, String modelFilePath) throws IOException, ParseException, OllamaBaseException { public void createModel(String modelName, String modelFilePath) throws IOException, InterruptedException, OllamaBaseException {
String url = this.host + "/api/create"; String url = this.host + "/api/create";
String jsonData = String.format("{\"name\": \"%s\", \"path\": \"%s\"}", modelName, modelFilePath); String jsonData = String.format("{\"name\": \"%s\", \"path\": \"%s\"}", modelName, modelFilePath);
final HttpPost httpPost = new HttpPost(url); HttpRequest request = HttpRequest.newBuilder().uri(URI.create(url)).header("Accept", "application/json").header("Content-Type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build();
final StringEntity entity = new StringEntity(jsonData); HttpClient client = HttpClient.newHttpClient();
httpPost.setEntity(entity); HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
httpPost.setHeader("Accept", "application/json"); int statusCode = response.statusCode();
httpPost.setHeader("Content-type", "application/json"); String responseString = response.body();
try (CloseableHttpClient client = HttpClients.createDefault(); CloseableHttpResponse response = client.execute(httpPost)) { if (statusCode != 200) {
final int statusCode = response.getCode(); throw new OllamaBaseException(statusCode + " - " + responseString);
HttpEntity responseEntity = response.getEntity(); }
String responseString = ""; // FIXME: Ollama API returns HTTP status code 200 for model creation failure cases. Correct this if the issue is fixed in the Ollama API server.
if (responseEntity != null) { if (responseString.contains("error")) {
responseString = EntityUtils.toString(responseEntity, "UTF-8"); throw new OllamaBaseException(responseString);
// FIXME: Ollama API returns HTTP status code 200 for model creation failure cases. Correct this if the issue is fixed in the Ollama API server. }
if (responseString.contains("error")) { if (verbose) {
throw new OllamaBaseException(responseString); logger.info(responseString);
}
if (verbose) {
logger.info(responseString);
}
}
if (statusCode != 200) {
throw new OllamaBaseException(statusCode + " - " + responseString);
}
} }
} }
@ -193,100 +153,21 @@ public class OllamaAPI {
* *
* @param name the name of the model to be deleted. * @param name the name of the model to be deleted.
* @param ignoreIfNotPresent - ignore errors if the specified model is not present on Ollama server. * @param ignoreIfNotPresent - ignore errors if the specified model is not present on Ollama server.
* @throws IOException
* @throws ParseException
* @throws OllamaBaseException
*/ */
public void deleteModel(String name, boolean ignoreIfNotPresent) throws IOException, ParseException, OllamaBaseException { public void deleteModel(String name, boolean ignoreIfNotPresent) throws IOException, InterruptedException, OllamaBaseException {
String url = this.host + "/api/delete"; String url = this.host + "/api/delete";
String jsonData = String.format("{\"name\": \"%s\"}", name); String jsonData = String.format("{\"name\": \"%s\"}", name);
final HttpDelete httpDelete = new HttpDelete(url); HttpRequest request = HttpRequest.newBuilder().uri(URI.create(url)).method("DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).header("Accept", "application/json").header("Content-type", "application/json").build();
final StringEntity entity = new StringEntity(jsonData); HttpClient client = HttpClient.newHttpClient();
httpDelete.setEntity(entity); HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
httpDelete.setHeader("Accept", "application/json"); int statusCode = response.statusCode();
httpDelete.setHeader("Content-type", "application/json"); String responseBody = response.body();
try (CloseableHttpClient client = HttpClients.createDefault(); CloseableHttpResponse response = client.execute(httpDelete)) { if (statusCode == 404 && responseBody.contains("model") && responseBody.contains("not found")) {
final int statusCode = response.getCode(); return;
HttpEntity responseEntity = response.getEntity();
String responseString = "";
if (responseEntity != null) {
responseString = EntityUtils.toString(responseEntity, "UTF-8");
if (verbose) {
logger.info(responseString);
}
}
if (statusCode == 404 && responseString.contains("model") && responseString.contains("not found")) {
return;
}
if (statusCode != 200) {
throw new OllamaBaseException(statusCode + " - " + responseString);
}
} }
} if (statusCode != 200) {
throw new OllamaBaseException(statusCode + " - " + responseBody);
/**
* Ask a question to a model running on Ollama server. This is a sync/blocking call.
*
* @param ollamaModelType the ollama model to ask the question to
* @param promptText the prompt/question text
* @return the response text from the model
* @throws OllamaBaseException
* @throws IOException
*/
public String ask(String ollamaModelType, String promptText) throws OllamaBaseException, IOException {
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(ollamaModelType, promptText);
URL obj = new URL(this.host + "/api/generate");
HttpURLConnection con = (HttpURLConnection) obj.openConnection();
con.setRequestMethod("POST");
con.setDoOutput(true);
con.setRequestProperty("Content-Type", "application/json");
String jsonReq = Utils.getObjectMapper().writeValueAsString(ollamaRequestModel);
try (OutputStream out = con.getOutputStream()) {
out.write(jsonReq.getBytes(StandardCharsets.UTF_8));
} }
int responseCode = con.getResponseCode();
if (responseCode == HttpURLConnection.HTTP_OK) {
try (BufferedReader in = new BufferedReader(new InputStreamReader(con.getInputStream()))) {
String inputLine;
StringBuilder response = new StringBuilder();
while ((inputLine = in.readLine()) != null) {
OllamaResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(inputLine, OllamaResponseModel.class);
if (!ollamaResponseModel.getDone()) {
response.append(ollamaResponseModel.getResponse());
}
}
in.close();
return response.toString();
}
} else {
throw new OllamaBaseException(con.getResponseCode() + " - " + con.getResponseMessage());
}
}
/**
* Ask a question to a model running on Ollama server and get a callback handle that can be used to check for status and get the response from the model later.
* This would be a async/non-blocking call.
*
* @param ollamaModelType the ollama model to ask the question to
* @param promptText the prompt/question text
* @return the ollama async result callback handle
* @throws IOException
*/
public OllamaAsyncResultCallback askAsync(String ollamaModelType, String promptText) throws IOException {
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(ollamaModelType, promptText);
URL obj = new URL(this.host + "/api/generate");
HttpURLConnection con = (HttpURLConnection) obj.openConnection();
con.setRequestMethod("POST");
con.setDoOutput(true);
con.setRequestProperty("Content-Type", "application/json");
String jsonReq = Utils.getObjectMapper().writeValueAsString(ollamaRequestModel);
try (OutputStream out = con.getOutputStream()) {
out.write(jsonReq.getBytes(StandardCharsets.UTF_8));
}
OllamaAsyncResultCallback ollamaAsyncResultCallback = new OllamaAsyncResultCallback(con);
ollamaAsyncResultCallback.start();
return ollamaAsyncResultCallback;
} }
/** /**
@ -295,29 +176,58 @@ public class OllamaAPI {
* @param model name of model to generate embeddings from * @param model name of model to generate embeddings from
* @param prompt text to generate embeddings for * @param prompt text to generate embeddings for
* @return embeddings * @return embeddings
* @throws IOException
* @throws ParseException
* @throws OllamaBaseException
*/ */
public List<Double> generateEmbeddings(String model, String prompt) throws IOException, ParseException, OllamaBaseException { public List<Double> generateEmbeddings(String model, String prompt) throws IOException, InterruptedException, OllamaBaseException {
String url = this.host + "/api/embeddings"; String url = this.host + "/api/embeddings";
String jsonData = String.format("{\"model\": \"%s\", \"prompt\": \"%s\"}", model, prompt); String jsonData = String.format("{\"model\": \"%s\", \"prompt\": \"%s\"}", model, prompt);
final HttpPost httpPost = new HttpPost(url); HttpClient httpClient = HttpClient.newHttpClient();
final StringEntity entity = new StringEntity(jsonData); HttpRequest request = HttpRequest.newBuilder().uri(URI.create(url)).header("Accept", "application/json").header("Content-type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
httpPost.setEntity(entity); HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
httpPost.setHeader("Accept", "application/json"); int statusCode = response.statusCode();
httpPost.setHeader("Content-type", "application/json"); String responseBody = response.body();
try (CloseableHttpClient client = HttpClients.createDefault(); CloseableHttpResponse response = client.execute(httpPost)) { if (statusCode == 200) {
final int statusCode = response.getCode(); EmbeddingResponse embeddingResponse = Utils.getObjectMapper().readValue(responseBody, EmbeddingResponse.class);
HttpEntity responseEntity = response.getEntity(); return embeddingResponse.getEmbedding();
String responseString = ""; } else {
if (responseEntity != null) { throw new OllamaBaseException(statusCode + " - " + responseBody);
responseString = EntityUtils.toString(responseEntity, "UTF-8");
EmbeddingResponse embeddingResponse = Utils.getObjectMapper().readValue(responseString, EmbeddingResponse.class);
return embeddingResponse.getEmbedding();
} else {
throw new OllamaBaseException(statusCode + " - " + responseString);
}
} }
} }
/**
* Ask a question to a model running on Ollama server. This is a sync/blocking call.
*
* @param ollamaModelType the ollama model to ask the question to
* @param promptText the prompt/question text
* @return the response text from the model
*/
public String ask(String ollamaModelType, String promptText) throws OllamaBaseException, IOException, InterruptedException {
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(ollamaModelType, promptText);
HttpClient httpClient = HttpClient.newHttpClient();
URI uri = URI.create(this.host + "/api/generate");
HttpRequest request = HttpRequest.newBuilder(uri).POST(HttpRequest.BodyPublishers.ofString(Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))).header("Content-Type", "application/json").build();
HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
if (response.statusCode() == HttpURLConnection.HTTP_OK) {
OllamaResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(response.body(), OllamaResponseModel.class);
return ollamaResponseModel.getResponse();
} else {
throw new OllamaBaseException(response.statusCode() + " - " + response.body());
}
}
/**
* Ask a question to a model running on Ollama server and get a callback handle that can be used to check for status and get the response from the model later.
* This would be an async/non-blocking call.
*
* @param ollamaModelType the ollama model to ask the question to
* @param promptText the prompt/question text
* @return the ollama async result callback handle
*/
public OllamaAsyncResultCallback askAsyncNew(String ollamaModelType, String promptText) {
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(ollamaModelType, promptText);
HttpClient httpClient = HttpClient.newHttpClient();
URI uri = URI.create(this.host + "/api/generate");
OllamaAsyncResultCallback ollamaAsyncResultCallback = new OllamaAsyncResultCallback(httpClient, uri, ollamaRequestModel);
ollamaAsyncResultCallback.start();
return ollamaAsyncResultCallback;
}
} }

View File

@ -4,6 +4,7 @@ import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.List; import java.util.List;
@SuppressWarnings("unused")
public class EmbeddingResponse { public class EmbeddingResponse {
@JsonProperty("embedding") @JsonProperty("embedding")
private List<Double> embedding; private List<Double> embedding;

View File

@ -17,6 +17,10 @@ public class Model {
return name; return name;
} }
public void setName(String name) {
this.name = name;
}
/** /**
* Returns the model name without its version * Returns the model name without its version
* @return model name * @return model name
@ -33,10 +37,6 @@ public class Model {
return name.split(":")[1]; return name.split(":")[1];
} }
public void setName(String name) {
this.name = name;
}
public String getModifiedAt() { public String getModifiedAt() {
return modifiedAt; return modifiedAt;
} }

View File

@ -0,0 +1,44 @@
package io.github.amithkoujalgi.ollama4j.core.models;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
@JsonIgnoreProperties(ignoreUnknown = true)
public class ModelPullResponse {
private String status;
private String digest;
private Long total;
private Long completed;
public String getStatus() {
return status;
}
public void setStatus(String status) {
this.status = status;
}
public String getDigest() {
return digest;
}
public void setDigest(String digest) {
this.digest = digest;
}
public Long getTotal() {
return total;
}
public void setTotal(Long total) {
this.total = total;
}
public Long getCompleted() {
return completed;
}
public void setCompleted(Long completed) {
this.completed = completed;
}
}

View File

@ -5,21 +5,30 @@ import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
import java.io.BufferedReader; import java.io.BufferedReader;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader; import java.io.InputStreamReader;
import java.net.HttpURLConnection; import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.Queue; import java.util.Queue;
@SuppressWarnings("unused")
@SuppressWarnings("DuplicatedCode")
public class OllamaAsyncResultCallback extends Thread { public class OllamaAsyncResultCallback extends Thread {
private final HttpURLConnection connection; private final HttpClient client;
private final URI uri;
private final OllamaRequestModel ollamaRequestModel;
private final Queue<String> queue = new LinkedList<>();
private String result; private String result;
private boolean isDone; private boolean isDone;
private final Queue<String> queue = new LinkedList<>();
public OllamaAsyncResultCallback(HttpURLConnection connection) {
this.connection = connection; public OllamaAsyncResultCallback(HttpClient client, URI uri, OllamaRequestModel ollamaRequestModel) {
this.client = client;
this.ollamaRequestModel = ollamaRequestModel;
this.uri = uri;
this.isDone = false; this.isDone = false;
this.result = ""; this.result = "";
this.queue.add(""); this.queue.add("");
@ -27,28 +36,31 @@ public class OllamaAsyncResultCallback extends Thread {
@Override @Override
public void run() { public void run() {
int responseCode = 0;
try { try {
responseCode = this.connection.getResponseCode(); HttpRequest request = HttpRequest.newBuilder(uri).POST(HttpRequest.BodyPublishers.ofString(Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))).header("Content-Type", "application/json").build();
if (responseCode == HttpURLConnection.HTTP_OK) { HttpResponse<InputStream> response = client.send(request, HttpResponse.BodyHandlers.ofInputStream());
try (BufferedReader in = new BufferedReader(new InputStreamReader(this.connection.getInputStream()))) { int statusCode = response.statusCode();
String inputLine;
StringBuilder response = new StringBuilder(); InputStream responseBodyStream = response.body();
while ((inputLine = in.readLine()) != null) { String responseString = "";
OllamaResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(inputLine, OllamaResponseModel.class); try (BufferedReader reader = new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
queue.add(ollamaResponseModel.getResponse()); String line;
if (!ollamaResponseModel.getDone()) { StringBuilder responseBuffer = new StringBuilder();
response.append(ollamaResponseModel.getResponse()); while ((line = reader.readLine()) != null) {
} OllamaResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaResponseModel.class);
queue.add(ollamaResponseModel.getResponse());
if (!ollamaResponseModel.getDone()) {
responseBuffer.append(ollamaResponseModel.getResponse());
} }
in.close();
this.isDone = true;
this.result = response.toString();
} }
} else { reader.close();
throw new OllamaBaseException(connection.getResponseCode() + " - " + connection.getResponseMessage()); this.isDone = true;
this.result = responseBuffer.toString();
} }
} catch (IOException | OllamaBaseException e) { if (statusCode != 200) {
throw new OllamaBaseException(statusCode + " - " + responseString);
}
} catch (IOException | InterruptedException | OllamaBaseException e) {
this.isDone = true; this.isDone = true;
this.result = "FAILED! " + e.getMessage(); this.result = "FAILED! " + e.getMessage();
} }

View File

@ -1,12 +1,12 @@
package io.github.amithkoujalgi.ollama4j.core.types; package io.github.amithkoujalgi.ollama4j.core.types;
public class OllamaModelType { public class OllamaModelType {
public static String LLAMA2 = "llama2"; public static final String LLAMA2 = "llama2";
public static String MISTRAL = "mistral"; public static final String MISTRAL = "mistral";
public static String MEDLLAMA2 = "medllama2"; public static final String MEDLLAMA2 = "medllama2";
public static String CODELLAMA = "codellama"; public static final String CODELLAMA = "codellama";
public static String VICUNA = "vicuna"; public static final String VICUNA = "vicuna";
public static String ORCAMINI = "orca-mini"; public static final String ORCAMINI = "orca-mini";
public static String SQLCODER = "sqlcoder"; public static final String SQLCODER = "sqlcoder";
public static String WIZARDMATH = "wizard-math"; public static final String WIZARDMATH = "wizard-math";
} }

View File

@ -3,11 +3,11 @@ package io.github.amithkoujalgi.ollama4j;
import io.github.amithkoujalgi.ollama4j.core.OllamaAPI; import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.types.OllamaModelType; import io.github.amithkoujalgi.ollama4j.core.types.OllamaModelType;
import org.apache.hc.core5.http.ParseException;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.mockito.Mockito; import org.mockito.Mockito;
import java.io.IOException; import java.io.IOException;
import java.net.URISyntaxException;
import static org.mockito.Mockito.*; import static org.mockito.Mockito.*;
@ -20,7 +20,7 @@ public class TestMockedAPIs {
doNothing().when(ollamaAPI).pullModel(model); doNothing().when(ollamaAPI).pullModel(model);
ollamaAPI.pullModel(model); ollamaAPI.pullModel(model);
verify(ollamaAPI, times(1)).pullModel(model); verify(ollamaAPI, times(1)).pullModel(model);
} catch (IOException | ParseException | OllamaBaseException e) { } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
} }

View File

@ -9,7 +9,7 @@
<root level="info"> <root level="info">
<appender-ref ref="STDOUT"/> <appender-ref ref="STDOUT"/>
</root> </root>
<logger name="org.apache" level="WARN"/> <logger name="org.apache" level="WARN"/>
<logger name="httpclient" level="WARN"/> <logger name="httpclient" level="WARN"/>
</configuration> </configuration>