Merge pull request #25 from AgentSchmecker/feature/chat_messages_with_images

Adds image capability to chat API
This commit is contained in:
Amith Koujalgi 2024-02-14 19:13:22 +05:30 committed by GitHub
commit 32c4231eb5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 202 additions and 26 deletions

View File

@ -95,4 +95,43 @@ public class Main {
``` ```
You will get a response similar to: You will get a response similar to:
> NI. > NI.
## Create a conversation about an image (requires model with image recognition skills)
```java
public class Main {

    /**
     * Asks an image-capable model about a local picture, then reuses the returned chat
     * history to ask a follow-up question about the same picture.
     *
     * NOTE: getImageFileFromClasspath(...) is not defined in this snippet; supply your own
     * helper that returns the image as a java.io.File.
     *
     * @param args unused
     * @throws Exception chat(...) can fail with checked exceptions (I/O problems, interruption,
     *                   API errors); they are propagated here to keep the example short
     */
    public static void main(String[] args) throws Exception {

        String host = "http://localhost:11434/";

        OllamaAPI ollamaAPI = new OllamaAPI(host);
        OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(OllamaModelType.LLAVA);

        // Load Image from File and attach to user message (alternatively images could also be added via URL)
        OllamaChatRequestModel requestModel =
            builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
                List.of(getImageFileFromClasspath("dog-on-a-boat.jpg"))).build();

        OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
        System.out.println("First answer: " + chatResult.getResponse());

        // Start from an empty message list before replaying the history below
        builder.reset();

        // Use history to ask further questions about the image or assistant answer
        requestModel =
            builder.withMessages(chatResult.getChatHistory())
                .withMessage(OllamaChatMessageRole.USER, "What's the dogs breed?").build();

        chatResult = ollamaAPI.chat(requestModel);
        System.out.println("Second answer: " + chatResult.getResponse());
    }
}
```
You will get a response similar to:
> First answer: The image shows a dog sitting on the bow of a boat that is docked in calm water. The boat has two levels, with the lower level containing seating and what appears to be an engine cover. The dog seems relaxed and comfortable on the boat, looking out over the water. The background suggests it might be late afternoon or early evening, given the warm lighting and the low position of the sun in the sky.
>
> Second answer: Based on the image, it's difficult to definitively determine the breed of the dog. However, the dog appears to be medium-sized with a short coat and a brown coloration, which might suggest that it is a Golden Retriever or a similar breed. Without more details like ear shape and tail length, it's not possible to identify the exact breed confidently.

View File

@ -15,14 +15,12 @@ import io.github.amithkoujalgi.ollama4j.core.models.request.OllamaGenerateEndpoi
import io.github.amithkoujalgi.ollama4j.core.utils.Options; import io.github.amithkoujalgi.ollama4j.core.utils.Options;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils; import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
import java.io.BufferedReader; import java.io.BufferedReader;
import java.io.ByteArrayOutputStream;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.io.InputStreamReader; import java.io.InputStreamReader;
import java.net.URI; import java.net.URI;
import java.net.URISyntaxException; import java.net.URISyntaxException;
import java.net.URL;
import java.net.http.HttpClient; import java.net.http.HttpClient;
import java.net.http.HttpConnectTimeoutException; import java.net.http.HttpConnectTimeoutException;
import java.net.http.HttpRequest; import java.net.http.HttpRequest;
@ -413,7 +411,7 @@ public class OllamaAPI {
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
List<String> images = new ArrayList<>(); List<String> images = new ArrayList<>();
for (String imageURL : imageURLs) { for (String imageURL : imageURLs) {
images.add(encodeByteArrayToBase64(loadImageBytesFromUrl(imageURL))); images.add(encodeByteArrayToBase64(Utils.loadImageBytesFromUrl(imageURL)));
} }
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt, images); OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt, images);
ollamaRequestModel.setOptions(options.getOptionsMap()); ollamaRequestModel.setOptions(options.getOptionsMap());
@ -469,20 +467,6 @@ public class OllamaAPI {
return Base64.getEncoder().encodeToString(bytes); return Base64.getEncoder().encodeToString(bytes);
} }
/**
 * Reads the complete content behind the given URL into memory.
 *
 * @param imageUrl textual URL of the image to fetch
 * @return every byte read from the URL's stream, in order
 * @throws IOException if opening or reading the stream fails
 * @throws URISyntaxException if {@code imageUrl} cannot be parsed as a URI
 */
private static byte[] loadImageBytesFromUrl(String imageUrl)
        throws IOException, URISyntaxException {
    final URL source = new URI(imageUrl).toURL();
    try (InputStream stream = source.openStream();
            ByteArrayOutputStream collected = new ByteArrayOutputStream()) {
        final byte[] chunk = new byte[1024];
        // read(...) returns -1 once the end of the stream has been reached
        for (int count = stream.read(chunk); count != -1; count = stream.read(chunk)) {
            collected.write(chunk, 0, count);
        }
        return collected.toByteArray();
    }
}
private OllamaResult generateSyncForOllamaRequestModel(OllamaRequestModel ollamaRequestModel) private OllamaResult generateSyncForOllamaRequestModel(OllamaRequestModel ollamaRequestModel)
throws OllamaBaseException, IOException, InterruptedException { throws OllamaBaseException, IOException, InterruptedException {
OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose); OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);

View File

@ -3,7 +3,10 @@ package io.github.amithkoujalgi.ollama4j.core.models.chat;
import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper; import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import java.io.File; import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import io.github.amithkoujalgi.ollama4j.core.utils.FileToBase64Serializer;
import java.util.List; import java.util.List;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Data; import lombok.Data;
@ -28,7 +31,8 @@ public class OllamaChatMessage {
@NonNull @NonNull
private String content; private String content;
private List<File> images; @JsonSerialize(using = FileToBase64Serializer.class)
private List<byte[]> images;
@Override @Override
public String toString() { public String toString() {

View File

@ -1,16 +1,26 @@
package io.github.amithkoujalgi.ollama4j.core.models.chat; package io.github.amithkoujalgi.ollama4j.core.models.chat;
import java.io.File; import java.io.File;
import java.io.IOException;
import java.net.URISyntaxException;
import java.nio.file.Files;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import io.github.amithkoujalgi.ollama4j.core.utils.Options; import io.github.amithkoujalgi.ollama4j.core.utils.Options;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
/** /**
* Helper class for creating {@link OllamaChatRequestModel} objects using the builder-pattern. * Helper class for creating {@link OllamaChatRequestModel} objects using the builder-pattern.
*/ */
public class OllamaChatRequestBuilder { public class OllamaChatRequestBuilder {
private static final Logger LOG = LoggerFactory.getLogger(OllamaChatRequestBuilder.class);
private OllamaChatRequestBuilder(String model, List<OllamaChatMessage> messages){ private OllamaChatRequestBuilder(String model, List<OllamaChatMessage> messages){
request = new OllamaChatRequestModel(model, messages); request = new OllamaChatRequestModel(model, messages);
} }
@ -29,9 +39,41 @@ public class OllamaChatRequestBuilder {
request = new OllamaChatRequestModel(request.getModel(), new ArrayList<>()); request = new OllamaChatRequestModel(request.getModel(), new ArrayList<>());
} }
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, File... images){ public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List<File> images){
List<OllamaChatMessage> messages = this.request.getMessages(); List<OllamaChatMessage> messages = this.request.getMessages();
messages.add(new OllamaChatMessage(role,content,List.of(images)));
List<byte[]> binaryImages = images.stream().map(file -> {
try {
return Files.readAllBytes(file.toPath());
} catch (IOException e) {
LOG.warn(String.format("File '%s' could not be accessed, will not add to message!",file.toPath()), e);
return new byte[0];
}
}).collect(Collectors.toList());
messages.add(new OllamaChatMessage(role,content,binaryImages));
return this;
}
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, String... imageUrls){
List<OllamaChatMessage> messages = this.request.getMessages();
List<byte[]> binaryImages = null;
if(imageUrls.length>0){
binaryImages = new ArrayList<>();
for (String imageUrl : imageUrls) {
try{
binaryImages.add(Utils.loadImageBytesFromUrl(imageUrl));
}
catch (URISyntaxException e){
LOG.warn(String.format("URL '%s' could not be accessed, will not add to message!",imageUrl), e);
}
catch (IOException e){
LOG.warn(String.format("Content of URL '%s' could not be read, will not add to message!",imageUrl), e);
}
}
}
messages.add(new OllamaChatMessage(role,content,binaryImages));
return this; return this;
} }

View File

@ -11,6 +11,7 @@ public class OllamaChatResponseModel {
private @JsonProperty("created_at") String createdAt; private @JsonProperty("created_at") String createdAt;
private OllamaChatMessage message; private OllamaChatMessage message;
private boolean done; private boolean done;
private String error;
private List<Integer> context; private List<Integer> context;
private @JsonProperty("total_duration") Long totalDuration; private @JsonProperty("total_duration") Long totalDuration;
private @JsonProperty("load_duration") Long loadDuration; private @JsonProperty("load_duration") Long loadDuration;

View File

@ -90,6 +90,11 @@ public abstract class OllamaEndpointCaller {
Utils.getObjectMapper() Utils.getObjectMapper()
.readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponseModel.class); .readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponseModel.class);
responseBuffer.append(ollamaResponseModel.getError()); responseBuffer.append(ollamaResponseModel.getError());
} else if (statusCode == 400) {
LOG.warn("Status code: 400 (Bad Request)");
OllamaErrorResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line,
OllamaErrorResponseModel.class);
responseBuffer.append(ollamaResponseModel.getError());
} else { } else {
boolean finished = parseResponseAndAddToBuffer(line,responseBuffer); boolean finished = parseResponseAndAddToBuffer(line,responseBuffer);
if (finished) { if (finished) {

View File

@ -0,0 +1,30 @@
package io.github.amithkoujalgi.ollama4j.core.utils;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.util.Base64;
import java.util.Collection;
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.databind.JsonSerializer;
import com.fasterxml.jackson.databind.SerializerProvider;
/**
 * Jackson serializer that writes a collection of byte arrays as a JSON array of
 * Base64-encoded strings, one string per byte array.
 */
public class FileToBase64Serializer extends JsonSerializer<Collection<byte[]>> {

    /**
     * Writes {@code value} as a JSON array of Base64 strings.
     *
     * @param value byte arrays to encode, in iteration order
     * @param jsonGenerator generator the array is written to
     * @param serializers unused
     * @throws IOException if the generator fails to write
     */
    @Override
    public void serialize(Collection<byte[]> value, JsonGenerator jsonGenerator, SerializerProvider serializers) throws IOException {
        jsonGenerator.writeStartArray();
        for (byte[] file : value) {
            jsonGenerator.writeString(Base64.getEncoder().encodeToString(file));
        }
        jsonGenerator.writeEndArray();
    }

    /**
     * Converts an object to bytes using Java object serialization.
     *
     * @param obj object to serialize
     * @return the serialized form of {@code obj}
     * @throws IOException if writing the object fails
     */
    public static byte[] serialize(Object obj) throws IOException {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        // Close the object stream before reading the buffer so that everything it
        // has buffered is flushed into 'out' first
        try (ObjectOutputStream os = new ObjectOutputStream(out)) {
            os.writeObject(obj);
        }
        return out.toByteArray();
    }
}

View File

@ -1,9 +1,30 @@
package io.github.amithkoujalgi.ollama4j.core.utils; package io.github.amithkoujalgi.ollama4j.core.utils;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
public class Utils { public class Utils {
public static ObjectMapper getObjectMapper() { public static ObjectMapper getObjectMapper() {
return new ObjectMapper(); return new ObjectMapper();
} }
/**
 * Downloads the complete content behind the given URL into memory.
 *
 * @param imageUrl absolute URL of the image to fetch
 * @return every byte read from the URL's stream
 * @throws IOException if the URL is malformed or its stream cannot be opened or read
 * @throws URISyntaxException if {@code imageUrl} cannot be parsed as a URI or is not absolute
 */
public static byte[] loadImageBytesFromUrl(String imageUrl)
        throws IOException, URISyntaxException {
    URI uri = new URI(imageUrl);
    // URI.toURL() rejects non-absolute URIs with an unchecked IllegalArgumentException;
    // report that case through the declared checked exception so callers can handle it
    if (!uri.isAbsolute()) {
        throw new URISyntaxException(imageUrl, "URI is not absolute");
    }
    try (InputStream in = uri.toURL().openStream()) {
        return in.readAllBytes();
    }
}
} }

View File

@ -148,15 +148,65 @@ class TestRealAPIs {
testEndpointReachability(); testEndpointReachability();
try { try {
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel()); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, "You are a silent bot that only says 'NI'. Do not say anything else under any circumstances!") OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM,
.withMessage(OllamaChatMessageRole.USER,"What is the capital of France? And what's France's connection with Mona Lisa?") "You are a silent bot that only says 'NI'. Do not say anything else under any circumstances!")
.build(); .withMessage(OllamaChatMessageRole.USER,
"What is the capital of France? And what's France's connection with Mona Lisa?")
.build();
OllamaChatResult chatResult = ollamaAPI.chat(requestModel); OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
assertNotNull(chatResult); assertNotNull(chatResult);
assertFalse(chatResult.getResponse().isBlank()); assertFalse(chatResult.getResponse().isBlank());
assertTrue(chatResult.getResponse().startsWith("NI")); assertTrue(chatResult.getResponse().startsWith("NI"));
assertEquals(3,chatResult.getChatHistory().size()); assertEquals(3, chatResult.getChatHistory().size());
} catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e);
}
}
/**
 * Sends a picture loaded from the classpath together with a question, then replays the
 * returned chat history with a follow-up question. Both calls must yield a non-null
 * result carrying a non-null response.
 */
@Test
@Order(3)
void testChatWithImageFromFileWithHistoryRecognition() {
    testEndpointReachability();
    try {
        OllamaChatRequestBuilder chatBuilder =
            OllamaChatRequestBuilder.getInstance(config.getImageModel());

        // First turn: question plus the attached image file
        OllamaChatRequestModel imageQuestion = chatBuilder
            .withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
                List.of(getImageFileFromClasspath("dog-on-a-boat.jpg")))
            .build();
        OllamaChatResult firstTurn = ollamaAPI.chat(imageQuestion);
        assertNotNull(firstTurn);
        assertNotNull(firstTurn.getResponse());

        // Second turn: clear the builder, replay the history, ask the follow-up
        chatBuilder.reset();
        OllamaChatRequestModel followUpQuestion = chatBuilder
            .withMessages(firstTurn.getChatHistory())
            .withMessage(OllamaChatMessageRole.USER, "What's the dogs breed?")
            .build();
        OllamaChatResult secondTurn = ollamaAPI.chat(followUpQuestion);
        assertNotNull(secondTurn);
        assertNotNull(secondTurn.getResponse());
    } catch (IOException | OllamaBaseException | InterruptedException e) {
        throw new RuntimeException(e);
    }
}
@Test
@Order(3)
void testChatWithImageFromURL() {
testEndpointReachability();
try {
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getImageModel());
OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
"https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg")
.build();
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
assertNotNull(chatResult);
} catch (IOException | OllamaBaseException | InterruptedException e) { } catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }