Changes images property of ChatMessage to type byte[]

This commit is contained in:
Markus Klenke 2024-02-12 22:08:10 +00:00
parent 84a6e57f42
commit 3769386539
3 changed files with 53 additions and 8 deletions

View File

@ -33,7 +33,7 @@ public class OllamaChatMessage {
private String content; private String content;
@JsonSerialize(using = FileToBase64Serializer.class) @JsonSerialize(using = FileToBase64Serializer.class)
private List<File> images; private List<byte[]> images;
@Override @Override
public String toString() { public String toString() {

View File

@ -1,16 +1,26 @@
package io.github.amithkoujalgi.ollama4j.core.models.chat; package io.github.amithkoujalgi.ollama4j.core.models.chat;
import java.io.File; import java.io.File;
import java.io.IOException;
import java.net.URISyntaxException;
import java.nio.file.Files;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import io.github.amithkoujalgi.ollama4j.core.utils.Options; import io.github.amithkoujalgi.ollama4j.core.utils.Options;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
/** /**
* Helper class for creating {@link OllamaChatRequestModel} objects using the builder-pattern. * Helper class for creating {@link OllamaChatRequestModel} objects using the builder-pattern.
*/ */
public class OllamaChatRequestBuilder { public class OllamaChatRequestBuilder {
private static final Logger LOG = LoggerFactory.getLogger(OllamaChatRequestBuilder.class);
private OllamaChatRequestBuilder(String model, List<OllamaChatMessage> messages){ private OllamaChatRequestBuilder(String model, List<OllamaChatMessage> messages){
request = new OllamaChatRequestModel(model, messages); request = new OllamaChatRequestModel(model, messages);
} }
@ -29,9 +39,45 @@ public class OllamaChatRequestBuilder {
request = new OllamaChatRequestModel(request.getModel(), new ArrayList<>()); request = new OllamaChatRequestModel(request.getModel(), new ArrayList<>());
} }
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, File... images){ public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content){
return withMessage(role, content, (String)null);
}
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List<File> images){
List<OllamaChatMessage> messages = this.request.getMessages(); List<OllamaChatMessage> messages = this.request.getMessages();
messages.add(new OllamaChatMessage(role,content,List.of(images)));
List<byte[]> binaryImages = images.stream().map(file -> {
try {
return Files.readAllBytes(file.toPath());
} catch (IOException e) {
LOG.warn(String.format("File '%s' could not be accessed, will not add to message!",file.toPath()), e);
return new byte[0];
}
}).collect(Collectors.toList());
messages.add(new OllamaChatMessage(role,content,binaryImages));
return this;
}
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, String... imageUrls){
List<OllamaChatMessage> messages = this.request.getMessages();
List<byte[]> binaryImages = null;
if(imageUrls.length>0){
binaryImages = new ArrayList<>();
for (String imageUrl : imageUrls) {
try{
binaryImages.add(Utils.loadImageBytesFromUrl(imageUrl));
}
catch (URISyntaxException e){
LOG.warn(String.format("URL '%s' could not be accessed, will not add to message!",imageUrl), e);
}
catch (IOException e){
LOG.warn(String.format("Content of URL '%s' could not be read, will not add to message!",imageUrl), e);
}
}
}
messages.add(new OllamaChatMessage(role,content,binaryImages));
return this; return this;
} }

View File

@ -1,7 +1,6 @@
package io.github.amithkoujalgi.ollama4j.core.utils; package io.github.amithkoujalgi.ollama4j.core.utils;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.io.ObjectOutputStream; import java.io.ObjectOutputStream;
import java.util.Base64; import java.util.Base64;
@ -11,13 +10,13 @@ import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.databind.JsonSerializer; import com.fasterxml.jackson.databind.JsonSerializer;
import com.fasterxml.jackson.databind.SerializerProvider; import com.fasterxml.jackson.databind.SerializerProvider;
public class FileToBase64Serializer extends JsonSerializer<Collection<File>> { public class FileToBase64Serializer extends JsonSerializer<Collection<byte[]>> {
@Override @Override
public void serialize(Collection<File> value, JsonGenerator jsonGenerator, SerializerProvider serializers) throws IOException { public void serialize(Collection<byte[]> value, JsonGenerator jsonGenerator, SerializerProvider serializers) throws IOException {
jsonGenerator.writeStartArray(); jsonGenerator.writeStartArray();
for (File file : value) { for (byte[] file : value) {
jsonGenerator.writeString(Base64.getEncoder().encodeToString(serialize(file))); jsonGenerator.writeString(Base64.getEncoder().encodeToString(file));
} }
jsonGenerator.writeEndArray(); jsonGenerator.writeEndArray();
} }