Added APIs to pass images and ask questions on it with LLaVA model

This commit is contained in:
Amith Koujalgi 2023-12-17 15:02:30 +05:30
parent 08118b88fc
commit 251968e5e0
3 changed files with 355 additions and 214 deletions

104
README.md
View File

@ -2,7 +2,8 @@
<img src='https://raw.githubusercontent.com/amithkoujalgi/ollama4j/65a9d526150da8fcd98e2af6a164f055572bf722/ollama4j.jpeg' width='100' alt="ollama4j-icon"> <img src='https://raw.githubusercontent.com/amithkoujalgi/ollama4j/65a9d526150da8fcd98e2af6a164f055572bf722/ollama4j.jpeg' width='100' alt="ollama4j-icon">
A Java library (wrapper/binding) for [Ollama](https://github.com/jmorganca/ollama/blob/main/docs/api.md) APIs. A Java library (wrapper/binding)
for [Ollama](https://github.com/jmorganca/ollama/blob/main/docs/api.md) APIs.
```mermaid ```mermaid
flowchart LR flowchart LR
@ -38,7 +39,8 @@ A Java library (wrapper/binding) for [Ollama](https://github.com/jmorganca/ollam
#### Requirements #### Requirements
- Ollama (Either [natively](https://ollama.ai/download) setup or via [Docker](https://hub.docker.com/r/ollama/ollama)) - Ollama (Either [natively](https://ollama.ai/download) setup or
via [Docker](https://hub.docker.com/r/ollama/ollama))
- Java 11 or above - Java 11 or above
#### Installation #### Installation
@ -55,7 +57,8 @@ the [Central Repository](https://s01.oss.sonatype.org/#nexus-search;quick~ollama
</dependency> </dependency>
``` ```
You might want to include the Maven repository to pull the ollama4j library from. Include this in your `pom.xml`: You might want to include the Maven repository to pull the ollama4j library from. Include this in
your `pom.xml`:
```xml ```xml
@ -104,6 +107,7 @@ Instantiate `OllamaAPI`
```java ```java
public class Main { public class Main {
public static void main(String[] args) { public static void main(String[] args) {
String host = "http://localhost:11434/"; String host = "http://localhost:11434/";
OllamaAPI ollamaAPI = new OllamaAPI(host); OllamaAPI ollamaAPI = new OllamaAPI(host);
@ -118,6 +122,7 @@ public class Main {
```java ```java
public class Main { public class Main {
public static void main(String[] args) { public static void main(String[] args) {
String host = "http://localhost:11434/"; String host = "http://localhost:11434/";
OllamaAPI ollamaAPI = new OllamaAPI(host); OllamaAPI ollamaAPI = new OllamaAPI(host);
@ -132,6 +137,7 @@ _Find the list of available models from Ollama [here](https://ollama.ai/library)
```java ```java
public class Main { public class Main {
public static void main(String[] args) { public static void main(String[] args) {
String host = "http://localhost:11434/"; String host = "http://localhost:11434/";
OllamaAPI ollamaAPI = new OllamaAPI(host); OllamaAPI ollamaAPI = new OllamaAPI(host);
@ -152,6 +158,7 @@ sqlcoder:latest
```java ```java
public class Main { public class Main {
public static void main(String[] args) { public static void main(String[] args) {
String host = "http://localhost:11434/"; String host = "http://localhost:11434/";
OllamaAPI ollamaAPI = new OllamaAPI(host); OllamaAPI ollamaAPI = new OllamaAPI(host);
@ -176,6 +183,7 @@ Response:
```java ```java
public class Main { public class Main {
public static void main(String[] args) { public static void main(String[] args) {
String host = "http://localhost:11434/"; String host = "http://localhost:11434/";
OllamaAPI ollamaAPI = new OllamaAPI(host); OllamaAPI ollamaAPI = new OllamaAPI(host);
@ -188,6 +196,7 @@ public class Main {
```java ```java
public class Main { public class Main {
public static void main(String[] args) { public static void main(String[] args) {
String host = "http://localhost:11434/"; String host = "http://localhost:11434/";
OllamaAPI ollamaAPI = new OllamaAPI(host); OllamaAPI ollamaAPI = new OllamaAPI(host);
@ -201,10 +210,12 @@ public class Main {
```java ```java
public class Main { public class Main {
public static void main(String[] args) { public static void main(String[] args) {
String host = "http://localhost:11434/"; String host = "http://localhost:11434/";
OllamaAPI ollamaAPI = new OllamaAPI(host); OllamaAPI ollamaAPI = new OllamaAPI(host);
List<Double> embeddings = ollamaAPI.generateEmbeddings(OllamaModelType.LLAMA2, "Here is an article about llamas..."); List<Double> embeddings = ollamaAPI.generateEmbeddings(OllamaModelType.LLAMA2,
"Here is an article about llamas...");
embeddings.forEach(System.out::println); embeddings.forEach(System.out::println);
} }
} }
@ -216,6 +227,7 @@ public class Main {
```java ```java
public class Main { public class Main {
public static void main(String[] args) { public static void main(String[] args) {
String host = "http://localhost:11434/"; String host = "http://localhost:11434/";
OllamaAPI ollamaAPI = new OllamaAPI(host); OllamaAPI ollamaAPI = new OllamaAPI(host);
@ -229,10 +241,12 @@ public class Main {
```java ```java
public class Main { public class Main {
public static void main(String[] args) { public static void main(String[] args) {
String host = "http://localhost:11434/"; String host = "http://localhost:11434/";
OllamaAPI ollamaAPI = new OllamaAPI(host); OllamaAPI ollamaAPI = new OllamaAPI(host);
OllamaAsyncResultCallback ollamaAsyncResultCallback = ollamaAPI.askAsync(OllamaModelType.LLAMA2, "Who are you?"); OllamaAsyncResultCallback ollamaAsyncResultCallback = ollamaAPI.askAsync(OllamaModelType.LLAMA2,
"Who are you?");
while (true) { while (true) {
if (ollamaAsyncResultCallback.isComplete()) { if (ollamaAsyncResultCallback.isComplete()) {
System.out.println(ollamaAsyncResultCallback.getResponse()); System.out.println(ollamaAsyncResultCallback.getResponse());
@ -247,9 +261,12 @@ public class Main {
You'd then get a response from the model: You'd then get a response from the model:
> I am LLaMA, an AI assistant developed by Meta AI that can understand and respond to human input in a conversational > I am LLaMA, an AI assistant developed by Meta AI that can understand and respond to human input in
> manner. I am trained on a massive dataset of text from the internet and can generate human-like responses to a wide > a conversational
> range of topics and questions. I can be used to create chatbots, virtual assistants, and other applications that > manner. I am trained on a massive dataset of text from the internet and can generate human-like
> responses to a wide
> range of topics and questions. I can be used to create chatbots, virtual assistants, and other
> applications that
> require > require
> natural language understanding and generation capabilities. > natural language understanding and generation capabilities.
@ -257,6 +274,7 @@ You'd then get a response from the model:
```java ```java
public class Main { public class Main {
public static void main(String[] args) { public static void main(String[] args) {
String host = "http://localhost:11434/"; String host = "http://localhost:11434/";
OllamaAPI ollamaAPI = new OllamaAPI(host); OllamaAPI ollamaAPI = new OllamaAPI(host);
@ -270,7 +288,8 @@ public class Main {
You'd then get a response from the model: You'd then get a response from the model:
> The 2019 ICC Cricket World Cup was held in England and Wales from May 30 to July 14, 2019. The following teams > The 2019 ICC Cricket World Cup was held in England and Wales from May 30 to July 14, 2019. The
> following teams
> participated in the tournament: > participated in the tournament:
> >
> 1. Australia > 1. Australia
@ -283,18 +302,21 @@ You'd then get a response from the model:
> 8. Sri Lanka > 8. Sri Lanka
> 9. West Indies > 9. West Indies
> >
> These teams competed in a round-robin format, with the top four teams advancing to the semi-finals. The tournament was > These teams competed in a round-robin format, with the top four teams advancing to the
> semi-finals. The tournament was
> won by the England cricket team, who defeated New Zealand in the final. > won by the England cricket team, who defeated New Zealand in the final.
#### Try asking for a Database query for your data schema: #### Try asking for a Database query for your data schema:
```java ```java
public class Main { public class Main {
public static void main(String[] args) { public static void main(String[] args) {
String host = "http://localhost:11434/"; String host = "http://localhost:11434/";
OllamaAPI ollamaAPI = new OllamaAPI(host); OllamaAPI ollamaAPI = new OllamaAPI(host);
String prompt = SamplePrompts.getSampleDatabasePromptWithQuestion("List all customer names who have bought one or more products"); String prompt = SamplePrompts.getSampleDatabasePromptWithQuestion(
"List all customer names who have bought one or more products");
String response = ollamaAPI.ask(OllamaModelType.SQLCODER, prompt); String response = ollamaAPI.ask(OllamaModelType.SQLCODER, prompt);
System.out.println(response); System.out.println(response);
} }
@ -315,12 +337,60 @@ FROM sales
GROUP BY customers.name; GROUP BY customers.name;
``` ```
#### Try asking some questions by passing images 🖼️:
With Files:
```java
public class Main {
public static void main(String[] args) {
String host = "http://localhost:11434/";
OllamaAPI ollamaAPI = new OllamaAPI(host);
ollamaAPI.setRequestTimeoutSeconds(10);
OllamaResult response = ollamaAPI.askWithImageURLs(OllamaModelType.LLAVA,
"What's in this image?",
List.of(
"/path/to/image"));
System.out.println(response);
}
}
```
With URLs:
```java
public class Main {
public static void main(String[] args) {
String host = "http://localhost:11434/";
OllamaAPI ollamaAPI = new OllamaAPI(host);
ollamaAPI.setRequestTimeoutSeconds(10);
OllamaResult response = ollamaAPI.askWithImageURLs(OllamaModelType.LLAVA,
"What's in this image?",
List.of(
"https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg"));
System.out.println(response);
}
}
```
You'd then get a response from the model:
```
This image features a white boat with brown cushions, where a dog is sitting on the back of the boat.
The dog seems to be enjoying its time outdoors, perhaps on a lake.
```
#### Async API with streaming response #### Async API with streaming response
```java ```java
@SuppressWarnings("ALL") @SuppressWarnings("ALL")
public class Main { public class Main {
public static void main(String[] args) { public static void main(String[] args) {
String host = "http://localhost:11434/"; String host = "http://localhost:11434/";
OllamaAPI ollamaAPI = new OllamaAPI(host); OllamaAPI ollamaAPI = new OllamaAPI(host);
@ -369,14 +439,16 @@ make it
### Areas of improvement ### Areas of improvement
- [x] Use Java-naming conventions for attributes in the request/response models instead of the snake-case conventions. ( - [x] Use Java-naming conventions for attributes in the request/response models instead of the
snake-case conventions. (
possibly with Jackson-mapper's `@JsonProperty`) possibly with Jackson-mapper's `@JsonProperty`)
- [x] Fix deprecated HTTP client code - [x] Fix deprecated HTTP client code
- [ ] Add additional params for `ask` APIs such as: - [ ] Add additional params for `ask` APIs such as:
- `options`: additional model parameters for the Modelfile such as `temperature` - `options`: additional model parameters for the Modelfile such as `temperature`
- `system`: system prompt to (overrides what is defined in the Modelfile) - `system`: system prompt to (overrides what is defined in the Modelfile)
- `template`: the full prompt or prompt template (overrides what is defined in the Modelfile) - `template`: the full prompt or prompt template (overrides what is defined in the Modelfile)
- `context`: the context parameter returned from a previous request, which can be used to keep a short - `context`: the context parameter returned from a previous request, which can be used to keep a
short
conversational memory conversational memory
- `stream`: Add support for streaming responses from the model - `stream`: Add support for streaming responses from the model
- [x] Setup logging - [x] Setup logging
@ -386,9 +458,11 @@ make it
### Get Involved ### Get Involved
Contributions are most welcome! Whether it's reporting a bug, proposing an enhancement, or helping with code - any sort Contributions are most welcome! Whether it's reporting a bug, proposing an enhancement, or helping
with code - any sort
of contribution is much appreciated. of contribution is much appreciated.
### Credits ### Credits
The nomenclature and the icon have been adopted from the incredible [Ollama](https://ollama.ai/) project. The nomenclature and the icon have been adopted from the incredible [Ollama](https://ollama.ai/)
project.

View File

@ -4,28 +4,37 @@ import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.*; import io.github.amithkoujalgi.ollama4j.core.models.*;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils; import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
import java.io.BufferedReader; import java.io.BufferedReader;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.io.InputStreamReader; import java.io.InputStreamReader;
import java.net.URI; import java.net.URI;
import java.net.URISyntaxException; import java.net.URISyntaxException;
import java.net.URL;
import java.net.http.HttpClient; import java.net.http.HttpClient;
import java.net.http.HttpRequest; import java.net.http.HttpRequest;
import java.net.http.HttpResponse; import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.time.Duration; import java.time.Duration;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List; import java.util.List;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
/** The base Ollama API class. */ /**
* The base Ollama API class.
*/
@SuppressWarnings("DuplicatedCode") @SuppressWarnings("DuplicatedCode")
public class OllamaAPI { public class OllamaAPI {
private static final Logger logger = LoggerFactory.getLogger(OllamaAPI.class); private static final Logger logger = LoggerFactory.getLogger(OllamaAPI.class);
private final String host; private final String host;
private final long requestTimeoutSeconds = 3; private long requestTimeoutSeconds = 3;
private boolean verbose = false; private boolean verbose = true;
/** /**
* Instantiates the Ollama API. * Instantiates the Ollama API.
@ -40,6 +49,10 @@ public class OllamaAPI {
} }
} }
public void setRequestTimeoutSeconds(long requestTimeoutSeconds) {
this.requestTimeoutSeconds = requestTimeoutSeconds;
}
/** /**
* Set/unset logging of responses * Set/unset logging of responses
* *
@ -58,21 +71,15 @@ public class OllamaAPI {
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
String url = this.host + "/api/tags"; String url = this.host + "/api/tags";
HttpClient httpClient = HttpClient.newHttpClient(); HttpClient httpClient = HttpClient.newHttpClient();
HttpRequest httpRequest = HttpRequest httpRequest = HttpRequest.newBuilder().uri(new URI(url))
HttpRequest.newBuilder() .header("Accept", "application/json").header("Content-type", "application/json")
.uri(new URI(url)) .timeout(Duration.ofSeconds(requestTimeoutSeconds)).GET().build();
.header("Accept", "application/json") HttpResponse<String> response = httpClient.send(httpRequest,
.header("Content-type", "application/json") HttpResponse.BodyHandlers.ofString());
.timeout(Duration.ofSeconds(requestTimeoutSeconds))
.GET()
.build();
HttpResponse<String> response =
httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode(); int statusCode = response.statusCode();
String responseString = response.body(); String responseString = response.body();
if (statusCode == 200) { if (statusCode == 200) {
return Utils.getObjectMapper() return Utils.getObjectMapper().readValue(responseString, ListModelsResponse.class)
.readValue(responseString, ListModelsResponse.class)
.getModels(); .getModels();
} else { } else {
throw new OllamaBaseException(statusCode + " - " + responseString); throw new OllamaBaseException(statusCode + " - " + responseString);
@ -89,26 +96,22 @@ public class OllamaAPI {
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
String url = this.host + "/api/pull"; String url = this.host + "/api/pull";
String jsonData = String.format("{\"name\": \"%s\"}", model); String jsonData = String.format("{\"name\": \"%s\"}", model);
HttpRequest request = HttpRequest request = HttpRequest.newBuilder().uri(new URI(url))
HttpRequest.newBuilder() .POST(HttpRequest.BodyPublishers.ofString(jsonData)).header("Accept", "application/json")
.uri(new URI(url))
.POST(HttpRequest.BodyPublishers.ofString(jsonData))
.header("Accept", "application/json")
.header("Content-type", "application/json") .header("Content-type", "application/json")
.timeout(Duration.ofSeconds(requestTimeoutSeconds)) .timeout(Duration.ofSeconds(requestTimeoutSeconds)).build();
.build();
HttpClient client = HttpClient.newHttpClient(); HttpClient client = HttpClient.newHttpClient();
HttpResponse<InputStream> response = HttpResponse<InputStream> response = client.send(request,
client.send(request, HttpResponse.BodyHandlers.ofInputStream()); HttpResponse.BodyHandlers.ofInputStream());
int statusCode = response.statusCode(); int statusCode = response.statusCode();
InputStream responseBodyStream = response.body(); InputStream responseBodyStream = response.body();
String responseString = ""; String responseString = "";
try (BufferedReader reader = try (BufferedReader reader = new BufferedReader(
new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
String line; String line;
while ((line = reader.readLine()) != null) { while ((line = reader.readLine()) != null) {
ModelPullResponse modelPullResponse = ModelPullResponse modelPullResponse = Utils.getObjectMapper()
Utils.getObjectMapper().readValue(line, ModelPullResponse.class); .readValue(line, ModelPullResponse.class);
if (verbose) { if (verbose) {
logger.info(modelPullResponse.getStatus()); logger.info(modelPullResponse.getStatus());
} }
@ -129,14 +132,10 @@ public class OllamaAPI {
throws IOException, OllamaBaseException, InterruptedException { throws IOException, OllamaBaseException, InterruptedException {
String url = this.host + "/api/show"; String url = this.host + "/api/show";
String jsonData = String.format("{\"name\": \"%s\"}", modelName); String jsonData = String.format("{\"name\": \"%s\"}", modelName);
HttpRequest request = HttpRequest request = HttpRequest.newBuilder().uri(URI.create(url))
HttpRequest.newBuilder() .header("Accept", "application/json").header("Content-type", "application/json")
.uri(URI.create(url))
.header("Accept", "application/json")
.header("Content-type", "application/json")
.timeout(Duration.ofSeconds(requestTimeoutSeconds)) .timeout(Duration.ofSeconds(requestTimeoutSeconds))
.POST(HttpRequest.BodyPublishers.ofString(jsonData)) .POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
.build();
HttpClient client = HttpClient.newHttpClient(); HttpClient client = HttpClient.newHttpClient();
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString()); HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode(); int statusCode = response.statusCode();
@ -158,16 +157,12 @@ public class OllamaAPI {
public void createModel(String modelName, String modelFilePath) public void createModel(String modelName, String modelFilePath)
throws IOException, InterruptedException, OllamaBaseException { throws IOException, InterruptedException, OllamaBaseException {
String url = this.host + "/api/create"; String url = this.host + "/api/create";
String jsonData = String jsonData = String.format("{\"name\": \"%s\", \"path\": \"%s\"}", modelName,
String.format("{\"name\": \"%s\", \"path\": \"%s\"}", modelName, modelFilePath); modelFilePath);
HttpRequest request = HttpRequest request = HttpRequest.newBuilder().uri(URI.create(url))
HttpRequest.newBuilder() .header("Accept", "application/json").header("Content-Type", "application/json")
.uri(URI.create(url))
.header("Accept", "application/json")
.header("Content-Type", "application/json")
.timeout(Duration.ofSeconds(requestTimeoutSeconds)) .timeout(Duration.ofSeconds(requestTimeoutSeconds))
.POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)) .POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build();
.build();
HttpClient client = HttpClient.newHttpClient(); HttpClient client = HttpClient.newHttpClient();
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString()); HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode(); int statusCode = response.statusCode();
@ -196,14 +191,10 @@ public class OllamaAPI {
throws IOException, InterruptedException, OllamaBaseException { throws IOException, InterruptedException, OllamaBaseException {
String url = this.host + "/api/delete"; String url = this.host + "/api/delete";
String jsonData = String.format("{\"name\": \"%s\"}", name); String jsonData = String.format("{\"name\": \"%s\"}", name);
HttpRequest request = HttpRequest request = HttpRequest.newBuilder().uri(URI.create(url))
HttpRequest.newBuilder()
.uri(URI.create(url))
.method("DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)) .method("DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8))
.header("Accept", "application/json") .header("Accept", "application/json").header("Content-type", "application/json")
.header("Content-type", "application/json") .timeout(Duration.ofSeconds(requestTimeoutSeconds)).build();
.timeout(Duration.ofSeconds(requestTimeoutSeconds))
.build();
HttpClient client = HttpClient.newHttpClient(); HttpClient client = HttpClient.newHttpClient();
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString()); HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode(); int statusCode = response.statusCode();
@ -228,20 +219,16 @@ public class OllamaAPI {
String url = this.host + "/api/embeddings"; String url = this.host + "/api/embeddings";
String jsonData = String.format("{\"model\": \"%s\", \"prompt\": \"%s\"}", model, prompt); String jsonData = String.format("{\"model\": \"%s\", \"prompt\": \"%s\"}", model, prompt);
HttpClient httpClient = HttpClient.newHttpClient(); HttpClient httpClient = HttpClient.newHttpClient();
HttpRequest request = HttpRequest request = HttpRequest.newBuilder().uri(URI.create(url))
HttpRequest.newBuilder() .header("Accept", "application/json").header("Content-type", "application/json")
.uri(URI.create(url))
.header("Accept", "application/json")
.header("Content-type", "application/json")
.timeout(Duration.ofSeconds(requestTimeoutSeconds)) .timeout(Duration.ofSeconds(requestTimeoutSeconds))
.POST(HttpRequest.BodyPublishers.ofString(jsonData)) .POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
.build();
HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode(); int statusCode = response.statusCode();
String responseBody = response.body(); String responseBody = response.body();
if (statusCode == 200) { if (statusCode == 200) {
EmbeddingResponse embeddingResponse = EmbeddingResponse embeddingResponse = Utils.getObjectMapper()
Utils.getObjectMapper().readValue(responseBody, EmbeddingResponse.class); .readValue(responseBody, EmbeddingResponse.class);
return embeddingResponse.getEmbedding(); return embeddingResponse.getEmbedding();
} else { } else {
throw new OllamaBaseException(statusCode + " - " + responseBody); throw new OllamaBaseException(statusCode + " - " + responseBody);
@ -258,44 +245,65 @@ public class OllamaAPI {
public OllamaResult ask(String model, String promptText) public OllamaResult ask(String model, String promptText)
throws OllamaBaseException, IOException, InterruptedException { throws OllamaBaseException, IOException, InterruptedException {
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, promptText); OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, promptText);
long startTime = System.currentTimeMillis(); return askSync(ollamaRequestModel);
HttpClient httpClient = HttpClient.newHttpClient();
URI uri = URI.create(this.host + "/api/generate");
HttpRequest request =
HttpRequest.newBuilder(uri)
.POST(
HttpRequest.BodyPublishers.ofString(
Utils.getObjectMapper().writeValueAsString(ollamaRequestModel)))
.header("Content-Type", "application/json")
.timeout(Duration.ofSeconds(requestTimeoutSeconds))
.build();
HttpResponse<InputStream> response =
httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
int statusCode = response.statusCode();
InputStream responseBodyStream = response.body();
StringBuilder responseBuffer = new StringBuilder();
try (BufferedReader reader =
new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
String line;
while ((line = reader.readLine()) != null) {
if (statusCode == 404) {
OllamaErrorResponseModel ollamaResponseModel =
Utils.getObjectMapper().readValue(line, OllamaErrorResponseModel.class);
responseBuffer.append(ollamaResponseModel.getError());
} else {
OllamaResponseModel ollamaResponseModel =
Utils.getObjectMapper().readValue(line, OllamaResponseModel.class);
if (!ollamaResponseModel.isDone()) {
responseBuffer.append(ollamaResponseModel.getResponse());
} }
/**
* With one or more image files, ask a question to a model running on Ollama server. This is a
* sync/blocking call.
*
* @param model the ollama model to ask the question to
* @param promptText the prompt/question text
* @param imageFiles the list of image files to use for the question
* @return OllamaResult - that includes response text and time taken for response
*/
public OllamaResult askWithImages(String model, String promptText, List<File> imageFiles)
throws OllamaBaseException, IOException, InterruptedException {
List<String> images = new ArrayList<>();
for (File imageFile : imageFiles) {
images.add(encodeFileToBase64(imageFile));
} }
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, promptText, images);
return askSync(ollamaRequestModel);
} }
/**
* With one or more image URLs, ask a question to a model running on Ollama server. This is a
* sync/blocking call.
*
* @param model the ollama model to ask the question to
* @param promptText the prompt/question text
* @param imageURLs the list of image URLs to use for the question
* @return OllamaResult - that includes response text and time taken for response
*/
public OllamaResult askWithImageURLs(String model, String promptText, List<String> imageURLs)
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
List<String> images = new ArrayList<>();
for (String imageURL : imageURLs) {
images.add(encodeByteArrayToBase64(loadImageBytesFromUrl(imageURL)));
} }
if (statusCode != 200) { OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, promptText, images);
throw new OllamaBaseException(responseBuffer.toString()); return askSync(ollamaRequestModel);
} else { }
long endTime = System.currentTimeMillis();
return new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode); public static String encodeFileToBase64(File file) throws IOException {
return Base64.getEncoder().encodeToString(Files.readAllBytes(file.toPath()));
}
public static String encodeByteArrayToBase64(byte[] bytes) {
return Base64.getEncoder().encodeToString(bytes);
}
public static byte[] loadImageBytesFromUrl(String imageUrl)
throws IOException, URISyntaxException {
URL url = new URI(imageUrl).toURL();
try (InputStream in = url.openStream(); ByteArrayOutputStream out = new ByteArrayOutputStream()) {
byte[] buffer = new byte[1024];
int bytesRead;
while ((bytesRead = in.read(buffer)) != -1) {
out.write(buffer, 0, bytesRead);
}
return out.toByteArray();
} }
} }
@ -312,9 +320,48 @@ public class OllamaAPI {
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, promptText); OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, promptText);
HttpClient httpClient = HttpClient.newHttpClient(); HttpClient httpClient = HttpClient.newHttpClient();
URI uri = URI.create(this.host + "/api/generate"); URI uri = URI.create(this.host + "/api/generate");
OllamaAsyncResultCallback ollamaAsyncResultCallback = OllamaAsyncResultCallback ollamaAsyncResultCallback = new OllamaAsyncResultCallback(httpClient,
new OllamaAsyncResultCallback(httpClient, uri, ollamaRequestModel, requestTimeoutSeconds); uri, ollamaRequestModel, requestTimeoutSeconds);
ollamaAsyncResultCallback.start(); ollamaAsyncResultCallback.start();
return ollamaAsyncResultCallback; return ollamaAsyncResultCallback;
} }
private OllamaResult askSync(OllamaRequestModel ollamaRequestModel)
throws OllamaBaseException, IOException, InterruptedException {
long startTime = System.currentTimeMillis();
HttpClient httpClient = HttpClient.newHttpClient();
URI uri = URI.create(this.host + "/api/generate");
HttpRequest request = HttpRequest.newBuilder(uri).POST(HttpRequest.BodyPublishers.ofString(
Utils.getObjectMapper().writeValueAsString(ollamaRequestModel)))
.header("Content-Type", "application/json")
.timeout(Duration.ofSeconds(requestTimeoutSeconds)).build();
HttpResponse<InputStream> response = httpClient.send(request,
HttpResponse.BodyHandlers.ofInputStream());
int statusCode = response.statusCode();
InputStream responseBodyStream = response.body();
StringBuilder responseBuffer = new StringBuilder();
try (BufferedReader reader = new BufferedReader(
new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
String line;
while ((line = reader.readLine()) != null) {
if (statusCode == 404) {
OllamaErrorResponseModel ollamaResponseModel = Utils.getObjectMapper()
.readValue(line, OllamaErrorResponseModel.class);
responseBuffer.append(ollamaResponseModel.getError());
} else {
OllamaResponseModel ollamaResponseModel = Utils.getObjectMapper()
.readValue(line, OllamaResponseModel.class);
if (!ollamaResponseModel.isDone()) {
responseBuffer.append(ollamaResponseModel.getResponse());
}
}
}
}
if (statusCode != 200) {
throw new OllamaBaseException(responseBuffer.toString());
} else {
long endTime = System.currentTimeMillis();
return new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode);
}
}
} }

View File

@ -1,15 +1,35 @@
package io.github.amithkoujalgi.ollama4j.core.models; package io.github.amithkoujalgi.ollama4j.core.models;
import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper;
import com.fasterxml.jackson.core.JsonProcessingException;
import java.util.List;
import lombok.Data; import lombok.Data;
@Data @Data
public class OllamaRequestModel { public class OllamaRequestModel {
private String model; private String model;
private String prompt; private String prompt;
private List<String> images;
public OllamaRequestModel(String model, String prompt) { public OllamaRequestModel(String model, String prompt) {
this.model = model; this.model = model;
this.prompt = prompt; this.prompt = prompt;
} }
public OllamaRequestModel(String model, String prompt, List<String> images) {
this.model = model;
this.prompt = prompt;
this.images = images;
}
public String toString() {
try {
return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
} }