diff --git a/docs/docs/apis-ask/chat.md b/docs/docs/apis-ask/chat.md index 00cda08..a94b51f 100644 --- a/docs/docs/apis-ask/chat.md +++ b/docs/docs/apis-ask/chat.md @@ -95,4 +95,43 @@ public class Main { ``` You will get a response similar to: -> NI. \ No newline at end of file +> NI. + +## Create a conversation about an image (requires model with image recognition skills) + +```java +public class Main { + + public static void main(String[] args) { + + String host = "http://localhost:11434/"; + + OllamaAPI ollamaAPI = new OllamaAPI(host); + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(OllamaModelType.LLAVA); + + // Load Image from File and attach to user message (alternatively images could also be added via URL) + OllamaChatRequestModel requestModel = + builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?", + List.of(getImageFileFromClasspath("dog-on-a-boat.jpg"))).build(); + + OllamaChatResult chatResult = ollamaAPI.chat(requestModel); + System.out.println("First answer: " + chatResult.getResponse()); + + builder.reset(); + + // Use history to ask further questions about the image or assistant answer + requestModel = + builder.withMessages(chatResult.getChatHistory()) + .withMessage(OllamaChatMessageRole.USER, "What's the dogs breed?").build(); + + chatResult = ollamaAPI.chat(requestModel); + System.out.println("Second answer: " + chatResult.getResponse()); + } +} +``` + +You will get a response similar to: + +> First Answer: The image shows a dog sitting on the bow of a boat that is docked in calm water. The boat has two levels, with the lower level containing seating and what appears to be an engine cover. The dog seems relaxed and comfortable on the boat, looking out over the water. The background suggests it might be late afternoon or early evening, given the warm lighting and the low position of the sun in the sky. +> +> Second Answer: Based on the image, it's difficult to definitively determine the breed of the dog. However, the dog appears to be medium-sized with a short coat and a brown coloration, which might suggest that it is a Golden Retriever or a similar breed. Without more details like ear shape and tail length, it's not possible to identify the exact breed confidently. \ No newline at end of file diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java index 4661653..46fd5fb 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java @@ -15,14 +15,12 @@ import io.github.amithkoujalgi.ollama4j.core.models.request.OllamaGenerateEndpoi import io.github.amithkoujalgi.ollama4j.core.utils.Options; import io.github.amithkoujalgi.ollama4j.core.utils.Utils; import java.io.BufferedReader; -import java.io.ByteArrayOutputStream; import java.io.File; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.net.URI; import java.net.URISyntaxException; -import java.net.URL; import java.net.http.HttpClient; import java.net.http.HttpConnectTimeoutException; import java.net.http.HttpRequest; @@ -413,7 +411,7 @@ public class OllamaAPI { throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { List images = new ArrayList<>(); for (String imageURL : imageURLs) { - images.add(encodeByteArrayToBase64(loadImageBytesFromUrl(imageURL))); + images.add(encodeByteArrayToBase64(Utils.loadImageBytesFromUrl(imageURL))); } OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt, images); ollamaRequestModel.setOptions(options.getOptionsMap()); @@ -469,20 +467,6 @@ public class OllamaAPI { return Base64.getEncoder().encodeToString(bytes); } - private static byte[] loadImageBytesFromUrl(String imageUrl) - throws IOException, URISyntaxException { - URL url = new URI(imageUrl).toURL(); - try (InputStream in = url.openStream(); - ByteArrayOutputStream out = new ByteArrayOutputStream()) { - byte[] buffer = new byte[1024]; - int bytesRead; - while ((bytesRead = in.read(buffer)) != -1) { - out.write(buffer, 0, bytesRead); - } - return out.toByteArray(); - } - } - private OllamaResult generateSyncForOllamaRequestModel(OllamaRequestModel ollamaRequestModel) throws OllamaBaseException, IOException, InterruptedException { OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose); diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatMessage.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatMessage.java index f8cca73..0b14315 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatMessage.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatMessage.java @@ -3,7 +3,10 @@ package io.github.amithkoujalgi.ollama4j.core.models.chat; import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper; import com.fasterxml.jackson.core.JsonProcessingException; -import java.io.File; +import com.fasterxml.jackson.databind.annotation.JsonSerialize; + +import io.github.amithkoujalgi.ollama4j.core.utils.FileToBase64Serializer; + import java.util.List; import lombok.AllArgsConstructor; import lombok.Data; @@ -28,7 +31,8 @@ public class OllamaChatMessage { @NonNull private String content; - private List images; + @JsonSerialize(using = FileToBase64Serializer.class) + private List images; @Override public String toString() { diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestBuilder.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestBuilder.java index af28b7b..5abbcde 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestBuilder.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestBuilder.java @@ -1,16 +1,26 @@ package io.github.amithkoujalgi.ollama4j.core.models.chat; import java.io.File; +import java.io.IOException; +import java.net.URISyntaxException; +import java.nio.file.Files; import java.util.ArrayList; import java.util.List; +import java.util.stream.Collectors; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import io.github.amithkoujalgi.ollama4j.core.utils.Options; +import io.github.amithkoujalgi.ollama4j.core.utils.Utils; /** * Helper class for creating {@link OllamaChatRequestModel} objects using the builder-pattern. */ public class OllamaChatRequestBuilder { + private static final Logger LOG = LoggerFactory.getLogger(OllamaChatRequestBuilder.class); + private OllamaChatRequestBuilder(String model, List messages){ request = new OllamaChatRequestModel(model, messages); } @@ -29,9 +39,41 @@ public class OllamaChatRequestBuilder { request = new OllamaChatRequestModel(request.getModel(), new ArrayList<>()); } - public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, File... images){ + public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List images){ List messages = this.request.getMessages(); - messages.add(new OllamaChatMessage(role,content,List.of(images))); + + List binaryImages = images.stream().map(file -> { + try { + return Files.readAllBytes(file.toPath()); + } catch (IOException e) { + LOG.warn(String.format("File '%s' could not be accessed, will not add to message!",file.toPath()), e); + return new byte[0]; + } + }).collect(Collectors.toList()); + + messages.add(new OllamaChatMessage(role,content,binaryImages)); + return this; + } + + public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, String... imageUrls){ + List messages = this.request.getMessages(); + List binaryImages = null; + if(imageUrls.length>0){ + binaryImages = new ArrayList<>(); + for (String imageUrl : imageUrls) { + try{ + binaryImages.add(Utils.loadImageBytesFromUrl(imageUrl)); + } + catch (URISyntaxException e){ + LOG.warn(String.format("URL '%s' could not be accessed, will not add to message!",imageUrl), e); + } + catch (IOException e){ + LOG.warn(String.format("Content of URL '%s' could not be read, will not add to message!",imageUrl), e); + } + } + } + + messages.add(new OllamaChatMessage(role,content,binaryImages)); return this; } diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatResponseModel.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatResponseModel.java index da69dfe..4d0b027 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatResponseModel.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatResponseModel.java @@ -11,6 +11,7 @@ public class OllamaChatResponseModel { private @JsonProperty("created_at") String createdAt; private OllamaChatMessage message; private boolean done; + private String error; private List context; private @JsonProperty("total_duration") Long totalDuration; private @JsonProperty("load_duration") Long loadDuration; diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaEndpointCaller.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaEndpointCaller.java index 134056f..93b2b2f 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaEndpointCaller.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaEndpointCaller.java @@ -90,6 +90,11 @@ public abstract class OllamaEndpointCaller { Utils.getObjectMapper() .readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponseModel.class); responseBuffer.append(ollamaResponseModel.getError()); + } else if (statusCode == 400) { + LOG.warn("Status code: 400 (Bad Request)"); + OllamaErrorResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, + OllamaErrorResponseModel.class); + responseBuffer.append(ollamaResponseModel.getError()); } else { boolean finished = parseResponseAndAddToBuffer(line,responseBuffer); if (finished) { diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/FileToBase64Serializer.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/FileToBase64Serializer.java new file mode 100644 index 0000000..680635b --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/FileToBase64Serializer.java @@ -0,0 +1,30 @@ +package io.github.amithkoujalgi.ollama4j.core.utils; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectOutputStream; +import java.util.Base64; +import java.util.Collection; + +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.databind.JsonSerializer; +import com.fasterxml.jackson.databind.SerializerProvider; + +public class FileToBase64Serializer extends JsonSerializer> { + + @Override + public void serialize(Collection value, JsonGenerator jsonGenerator, SerializerProvider serializers) throws IOException { + jsonGenerator.writeStartArray(); + for (byte[] file : value) { + jsonGenerator.writeString(Base64.getEncoder().encodeToString(file)); + } + jsonGenerator.writeEndArray(); + } + + public static byte[] serialize(Object obj) throws IOException { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + ObjectOutputStream os = new ObjectOutputStream(out); + os.writeObject(obj); + return out.toByteArray(); + } +} \ No newline at end of file diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/Utils.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/Utils.java index 9be49e1..1504c1d 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/Utils.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/Utils.java @@ -1,9 +1,30 @@ package io.github.amithkoujalgi.ollama4j.core.utils; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.URI; +import java.net.URISyntaxException; +import java.net.URL; + import com.fasterxml.jackson.databind.ObjectMapper; public class Utils { public static ObjectMapper getObjectMapper() { return new ObjectMapper(); } + + public static byte[] loadImageBytesFromUrl(String imageUrl) + throws IOException, URISyntaxException { + URL url = new URI(imageUrl).toURL(); + try (InputStream in = url.openStream(); + ByteArrayOutputStream out = new ByteArrayOutputStream()) { + byte[] buffer = new byte[1024]; + int bytesRead; + while ((bytesRead = in.read(buffer)) != -1) { + out.write(buffer, 0, bytesRead); + } + return out.toByteArray(); + } + } } diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java b/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java index 284a52d..ed5c862 100644 --- a/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java +++ b/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java @@ -148,15 +148,65 @@ class TestRealAPIs { testEndpointReachability(); try { OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel()); - OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, "You are a silent bot that only says 'NI'. Do not say anything else under any circumstances!") - .withMessage(OllamaChatMessageRole.USER,"What is the capital of France? And what's France's connection with Mona Lisa?") - .build(); + OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, + "You are a silent bot that only says 'NI'. Do not say anything else under any circumstances!") + .withMessage(OllamaChatMessageRole.USER, + "What is the capital of France? And what's France's connection with Mona Lisa?") + .build(); OllamaChatResult chatResult = ollamaAPI.chat(requestModel); assertNotNull(chatResult); assertFalse(chatResult.getResponse().isBlank()); assertTrue(chatResult.getResponse().startsWith("NI")); - assertEquals(3,chatResult.getChatHistory().size()); + assertEquals(3, chatResult.getChatHistory().size()); + } catch (IOException | OllamaBaseException | InterruptedException e) { + throw new RuntimeException(e); + } + } + + @Test + @Order(3) + void testChatWithImageFromFileWithHistoryRecognition() { + testEndpointReachability(); + try { + OllamaChatRequestBuilder builder = + OllamaChatRequestBuilder.getInstance(config.getImageModel()); + OllamaChatRequestModel requestModel = + builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?", + List.of(getImageFileFromClasspath("dog-on-a-boat.jpg"))).build(); + + OllamaChatResult chatResult = ollamaAPI.chat(requestModel); + assertNotNull(chatResult); + assertNotNull(chatResult.getResponse()); + + builder.reset(); + + requestModel = + builder.withMessages(chatResult.getChatHistory()) + .withMessage(OllamaChatMessageRole.USER, "What's the dogs breed?").build(); + + chatResult = ollamaAPI.chat(requestModel); + assertNotNull(chatResult); + assertNotNull(chatResult.getResponse()); + + + } catch (IOException | OllamaBaseException | InterruptedException e) { + throw new RuntimeException(e); + } + } + + @Test + @Order(3) + void testChatWithImageFromURL() { + testEndpointReachability(); + try { + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getImageModel()); + OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?", + "https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg") + .build(); + + OllamaChatResult chatResult = ollamaAPI.chat(requestModel); + assertNotNull(chatResult); } catch (IOException | OllamaBaseException | InterruptedException e) { throw new RuntimeException(e); }