From 8fe869afdb3b16dc361deb920f7c40913d6ad0cd Mon Sep 17 00:00:00 2001 From: Markus Klenke Date: Fri, 16 Feb 2024 13:15:24 +0000 Subject: [PATCH 1/7] Adds additional request properties and refactors common request fields to OllamaCommonRequestModel --- .../ollama4j/core/OllamaAPI.java | 11 ++-- .../models/OllamaAsyncResultCallback.java | 10 ++-- .../core/models/OllamaCommonRequestModel.java | 26 +++++++++ .../models/chat/OllamaChatRequestBuilder.java | 6 +- .../models/chat/OllamaChatRequestModel.java | 21 +++---- .../OllamaGenerateRequestBuilder.java | 55 +++++++++++++++++++ .../OllamaGenerateRequestModel.java} | 25 ++++++--- .../OllamaGenerateResponseModel.java} | 4 +- .../request/OllamaGenerateEndpointCaller.java | 4 +- .../BooleanToJsonFormatFlagSerializer.java | 18 ++++++ 10 files changed, 142 insertions(+), 38 deletions(-) create mode 100644 src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaCommonRequestModel.java create mode 100644 src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateRequestBuilder.java rename src/main/java/io/github/amithkoujalgi/ollama4j/core/models/{OllamaRequestModel.java => generate/OllamaGenerateRequestModel.java} (50%) rename src/main/java/io/github/amithkoujalgi/ollama4j/core/models/{OllamaResponseModel.java => generate/OllamaGenerateResponseModel.java} (88%) create mode 100644 src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/BooleanToJsonFormatFlagSerializer.java diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java index e48add1..7e82ee8 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java @@ -6,6 +6,7 @@ import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessage; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResult; +import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel; import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFileContentsRequest; import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFilePathRequest; import io.github.amithkoujalgi.ollama4j.core.models.request.ModelEmbeddingsRequest; @@ -345,7 +346,7 @@ public class OllamaAPI { */ public OllamaResult generate(String model, String prompt, Options options) throws OllamaBaseException, IOException, InterruptedException { - OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt); + OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt); ollamaRequestModel.setOptions(options.getOptionsMap()); return generateSyncForOllamaRequestModel(ollamaRequestModel); } @@ -360,7 +361,7 @@ public class OllamaAPI { * @return the ollama async result callback handle */ public OllamaAsyncResultCallback generateAsync(String model, String prompt) { - OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt); + OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt); URI uri = URI.create(this.host + "/api/generate"); OllamaAsyncResultCallback ollamaAsyncResultCallback = @@ -389,7 +390,7 @@ public class OllamaAPI { for (File imageFile : imageFiles) { images.add(encodeFileToBase64(imageFile)); } - OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt, images); + OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt, images); ollamaRequestModel.setOptions(options.getOptionsMap()); return generateSyncForOllamaRequestModel(ollamaRequestModel); } @@ -413,7 +414,7 @@ public class OllamaAPI { for (String imageURL : imageURLs) { images.add(encodeByteArrayToBase64(Utils.loadImageBytesFromUrl(imageURL))); } - OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt, images); + OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt, images); ollamaRequestModel.setOptions(options.getOptionsMap()); return generateSyncForOllamaRequestModel(ollamaRequestModel); } @@ -486,7 +487,7 @@ public class OllamaAPI { return Base64.getEncoder().encodeToString(bytes); } - private OllamaResult generateSyncForOllamaRequestModel(OllamaRequestModel ollamaRequestModel) + private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequestModel ollamaRequestModel) throws OllamaBaseException, IOException, InterruptedException { OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose); return requestCaller.callSync(ollamaRequestModel); diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java index 74b8c49..136f1c6 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java @@ -1,6 +1,8 @@ package io.github.amithkoujalgi.ollama4j.core.models; import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; +import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel; +import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateResponseModel; import io.github.amithkoujalgi.ollama4j.core.utils.Utils; import java.io.BufferedReader; import java.io.IOException; @@ -22,7 +24,7 @@ import lombok.Getter; @SuppressWarnings("unused") public class OllamaAsyncResultCallback extends Thread { private final HttpRequest.Builder requestBuilder; - private final OllamaRequestModel ollamaRequestModel; + private final OllamaGenerateRequestModel ollamaRequestModel; private final Queue queue = new LinkedList<>(); private String result; private boolean isDone; @@ -47,7 +49,7 @@ public class OllamaAsyncResultCallback extends Thread { public OllamaAsyncResultCallback( HttpRequest.Builder requestBuilder, - OllamaRequestModel ollamaRequestModel, + OllamaGenerateRequestModel ollamaRequestModel, long requestTimeoutSeconds) { this.requestBuilder = requestBuilder; this.ollamaRequestModel = ollamaRequestModel; @@ -87,8 +89,8 @@ public class OllamaAsyncResultCallback extends Thread { queue.add(ollamaResponseModel.getError()); responseBuffer.append(ollamaResponseModel.getError()); } else { - OllamaResponseModel ollamaResponseModel = - Utils.getObjectMapper().readValue(line, OllamaResponseModel.class); + OllamaGenerateResponseModel ollamaResponseModel = + Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class); queue.add(ollamaResponseModel.getResponse()); if (!ollamaResponseModel.isDone()) { responseBuffer.append(ollamaResponseModel.getResponse()); diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaCommonRequestModel.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaCommonRequestModel.java new file mode 100644 index 0000000..f554638 --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaCommonRequestModel.java @@ -0,0 +1,26 @@ +package io.github.amithkoujalgi.ollama4j.core.models; + +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.annotation.JsonSerialize; + +import io.github.amithkoujalgi.ollama4j.core.utils.BooleanToJsonFormatFlagSerializer; +import lombok.Data; +import lombok.NonNull; +import lombok.RequiredArgsConstructor; + +@Data +@RequiredArgsConstructor +public abstract class OllamaCommonRequestModel { + + @NonNull + protected String model; + @JsonSerialize(using = BooleanToJsonFormatFlagSerializer.class) + protected boolean returnFormatJson; + protected Map options; + protected String template; + protected boolean stream; + @JsonProperty(value = "keep_alive") + protected String keepAlive; +} diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestBuilder.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestBuilder.java index 5abbcde..e07722f 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestBuilder.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestBuilder.java @@ -83,12 +83,12 @@ public class OllamaChatRequestBuilder { } public OllamaChatRequestBuilder withOptions(Options options){ - this.request.setOptions(options); + this.request.setOptions(options.getOptionsMap()); return this; } - public OllamaChatRequestBuilder withFormat(String format){ - this.request.setFormat(format); + public OllamaChatRequestBuilder withGetJsonResponse(){ + this.request.setReturnFormatJson(true); return this; } diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestModel.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestModel.java index 2f947a2..82c9010 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestModel.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestModel.java @@ -4,15 +4,14 @@ import java.util.List; import com.fasterxml.jackson.core.JsonProcessingException; +import io.github.amithkoujalgi.ollama4j.core.models.OllamaCommonRequestModel; import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody; -import io.github.amithkoujalgi.ollama4j.core.utils.Options; import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper; -import lombok.AllArgsConstructor; import lombok.Data; +import lombok.EqualsAndHashCode; import lombok.NonNull; -import lombok.RequiredArgsConstructor; /** * Defines a Request to use against the ollama /api/chat endpoint. @@ -22,19 +21,15 @@ import lombok.RequiredArgsConstructor; * Chat Completion */ @Data -@AllArgsConstructor -@RequiredArgsConstructor -public class OllamaChatRequestModel implements OllamaRequestBody { - - @NonNull private String model; +@EqualsAndHashCode(callSuper = true) +public class OllamaChatRequestModel extends OllamaCommonRequestModel implements OllamaRequestBody { @NonNull private List messages; - private String format; - private Options options; - private String template; - private boolean stream; - private String keepAlive; + public OllamaChatRequestModel(String model,List messages){ + super(model); + this.messages = messages; + } @Override public String toString() { diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateRequestBuilder.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateRequestBuilder.java new file mode 100644 index 0000000..48b4d18 --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateRequestBuilder.java @@ -0,0 +1,55 @@ +package io.github.amithkoujalgi.ollama4j.core.models.generate; + +import io.github.amithkoujalgi.ollama4j.core.utils.Options; + +/** + * Helper class for creating {@link io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel} + * objects using the builder-pattern. + */ +public class OllamaGenerateRequestBuilder { + + private OllamaGenerateRequestBuilder(String model, String prompt){ + request = new OllamaGenerateRequestModel(model, prompt); + } + + private OllamaGenerateRequestModel request; + + public static OllamaGenerateRequestBuilder getInstance(String model){ + return new OllamaGenerateRequestBuilder(model,""); + } + + public OllamaGenerateRequestModel build(){ + return request; + } + + public OllamaGenerateRequestBuilder withPrompt(String prompt){ + request.setPrompt(prompt); + return this; + } + + public OllamaGenerateRequestBuilder withGetJsonResponse(){ + this.request.setReturnFormatJson(true); + return this; + } + + public OllamaGenerateRequestBuilder withOptions(Options options){ + this.request.setOptions(options.getOptionsMap()); + return this; + } + + public OllamaGenerateRequestBuilder withTemplate(String template){ + this.request.setTemplate(template); + return this; + } + + public OllamaGenerateRequestBuilder withStreaming(){ + this.request.setStream(true); + return this; + } + + public OllamaGenerateRequestBuilder withKeepAlive(String keepAlive){ + this.request.setKeepAlive(keepAlive); + return this; + } + +} diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaRequestModel.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateRequestModel.java similarity index 50% rename from src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaRequestModel.java rename to src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateRequestModel.java index 9c88698..861be1a 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaRequestModel.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateRequestModel.java @@ -1,30 +1,37 @@ -package io.github.amithkoujalgi.ollama4j.core.models; +package io.github.amithkoujalgi.ollama4j.core.models.generate; import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper; import com.fasterxml.jackson.core.JsonProcessingException; +import io.github.amithkoujalgi.ollama4j.core.models.OllamaCommonRequestModel; import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody; import java.util.List; -import java.util.Map; import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NonNull; @Data -public class OllamaRequestModel implements OllamaRequestBody{ +@EqualsAndHashCode(callSuper = true) +public class OllamaGenerateRequestModel extends OllamaCommonRequestModel implements OllamaRequestBody{ - private String model; + @NonNull private String prompt; private List images; - private Map options; - public OllamaRequestModel(String model, String prompt) { - this.model = model; + private String system; + private String context; + private boolean raw; + + + public OllamaGenerateRequestModel(String model, String prompt) { + super(model); this.prompt = prompt; } - public OllamaRequestModel(String model, String prompt, List images) { - this.model = model; + public OllamaGenerateRequestModel(String model, String prompt, List images) { + super(model); this.prompt = prompt; this.images = images; } diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaResponseModel.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateResponseModel.java similarity index 88% rename from src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaResponseModel.java rename to src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateResponseModel.java index 9481224..a575a7a 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaResponseModel.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateResponseModel.java @@ -1,4 +1,4 @@ -package io.github.amithkoujalgi.ollama4j.core.models; +package io.github.amithkoujalgi.ollama4j.core.models.generate; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; @@ -8,7 +8,7 @@ import lombok.Data; @Data @JsonIgnoreProperties(ignoreUnknown = true) -public class OllamaResponseModel { +public class OllamaGenerateResponseModel { private String model; private @JsonProperty("created_at") String createdAt; private String response; diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java index 8d54db3..ba55159 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java @@ -6,7 +6,7 @@ import org.slf4j.LoggerFactory; import com.fasterxml.jackson.core.JsonProcessingException; import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth; -import io.github.amithkoujalgi.ollama4j.core.models.OllamaResponseModel; +import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateResponseModel; import io.github.amithkoujalgi.ollama4j.core.utils.Utils; public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{ @@ -25,7 +25,7 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{ @Override protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) { try { - OllamaResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaResponseModel.class); + OllamaGenerateResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class); responseBuffer.append(ollamaResponseModel.getResponse()); return ollamaResponseModel.isDone(); } catch (JsonProcessingException e) { diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/BooleanToJsonFormatFlagSerializer.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/BooleanToJsonFormatFlagSerializer.java new file mode 100644 index 0000000..7513fc3 --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/BooleanToJsonFormatFlagSerializer.java @@ -0,0 +1,18 @@ +package io.github.amithkoujalgi.ollama4j.core.utils; + +import java.io.IOException; + +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.databind.JsonSerializer; +import com.fasterxml.jackson.databind.SerializerProvider; + +public class BooleanToJsonFormatFlagSerializer extends JsonSerializer{ + + @Override + public void serialize(Boolean value, JsonGenerator gen, SerializerProvider serializers) throws IOException { + if(value){ + gen.writeString("json"); + } + } + +} From 0f73ea75ab2fa147c348f7d61859b8a3ad4dc9e0 Mon Sep 17 00:00:00 2001 From: Markus Klenke Date: Fri, 16 Feb 2024 15:56:02 +0000 Subject: [PATCH 2/7] Removes unnecessary serialize method of Serializer --- .../ollama4j/core/utils/FileToBase64Serializer.java | 9 --------- 1 file changed, 9 deletions(-) diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/FileToBase64Serializer.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/FileToBase64Serializer.java index 680635b..8e862ab 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/FileToBase64Serializer.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/FileToBase64Serializer.java @@ -1,8 +1,6 @@ package io.github.amithkoujalgi.ollama4j.core.utils; -import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.io.ObjectOutputStream; import java.util.Base64; import java.util.Collection; @@ -20,11 +18,4 @@ public class FileToBase64Serializer extends JsonSerializer> { } jsonGenerator.writeEndArray(); } - - public static byte[] serialize(Object obj) throws IOException { - ByteArrayOutputStream out = new ByteArrayOutputStream(); - ObjectOutputStream os = new ObjectOutputStream(out); - os.writeObject(obj); - return out.toByteArray(); - } } \ No newline at end of file From f38a00ebdc6daf8a552c220094459b70ea7e6564 Mon Sep 17 00:00:00 2001 From: Markus Klenke Date: Fri, 16 Feb 2024 15:56:32 +0000 Subject: [PATCH 3/7] Fixes BooleanToJsonFormatFlagSerializer --- .../core/utils/BooleanToJsonFormatFlagSerializer.java | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/BooleanToJsonFormatFlagSerializer.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/BooleanToJsonFormatFlagSerializer.java index 7513fc3..210613f 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/BooleanToJsonFormatFlagSerializer.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/BooleanToJsonFormatFlagSerializer.java @@ -10,9 +10,12 @@ public class BooleanToJsonFormatFlagSerializer extends JsonSerializer{ @Override public void serialize(Boolean value, JsonGenerator gen, SerializerProvider serializers) throws IOException { - if(value){ gen.writeString("json"); - } + } + + @Override + public boolean isEmpty(Boolean value){ + return !value; } } From 91aab6cbd15a682ba4c972e5f7d21372e661c50a Mon Sep 17 00:00:00 2001 From: Markus Klenke Date: Fri, 16 Feb 2024 15:57:14 +0000 Subject: [PATCH 4/7] Fixes recursive call in non streamed chat API --- .../java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java index 7e82ee8..20e9d3e 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java @@ -449,7 +449,7 @@ public class OllamaAPI { * @throws InterruptedException in case the server is not reachable or network issues happen */ public OllamaChatResult chat(OllamaChatRequestModel request) throws OllamaBaseException, IOException, InterruptedException{ - return chat(request); + return chat(request,null); } /** From 06c5daa2534c253295b92c592c55cfdd120c054b Mon Sep 17 00:00:00 2001 From: Markus Klenke Date: Fri, 16 Feb 2024 15:57:48 +0000 Subject: [PATCH 5/7] Adds additional properties to chat and generate requests --- .../core/models/OllamaCommonRequestModel.java | 21 +++- .../models/chat/OllamaChatRequestModel.java | 39 +++---- .../generate/OllamaGenerateRequestModel.java | 32 +++--- .../jackson/TestChatRequestSerialization.java | 106 ++++++++++++++++++ .../TestGenerateRequestSerialization.java | 85 ++++++++++++++ 5 files changed, 240 insertions(+), 43 deletions(-) create mode 100644 src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestChatRequestSerialization.java create mode 100644 src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestGenerateRequestSerialization.java diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaCommonRequestModel.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaCommonRequestModel.java index f554638..6f985ab 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaCommonRequestModel.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaCommonRequestModel.java @@ -1,26 +1,35 @@ package io.github.amithkoujalgi.ollama4j.core.models; import java.util.Map; - +import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.annotation.JsonSerialize; import io.github.amithkoujalgi.ollama4j.core.utils.BooleanToJsonFormatFlagSerializer; +import io.github.amithkoujalgi.ollama4j.core.utils.Utils; import lombok.Data; -import lombok.NonNull; -import lombok.RequiredArgsConstructor; @Data -@RequiredArgsConstructor +@JsonInclude(JsonInclude.Include.NON_NULL) public abstract class OllamaCommonRequestModel { - @NonNull protected String model; @JsonSerialize(using = BooleanToJsonFormatFlagSerializer.class) - protected boolean returnFormatJson; + @JsonProperty(value = "format") + protected Boolean returnFormatJson; protected Map options; protected String template; protected boolean stream; @JsonProperty(value = "keep_alive") protected String keepAlive; + + + public String toString() { + try { + return Utils.getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } } diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestModel.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestModel.java index 82c9010..e55bf6a 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestModel.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestModel.java @@ -1,42 +1,39 @@ package io.github.amithkoujalgi.ollama4j.core.models.chat; import java.util.List; - -import com.fasterxml.jackson.core.JsonProcessingException; - import io.github.amithkoujalgi.ollama4j.core.models.OllamaCommonRequestModel; import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody; -import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper; - -import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.NonNull; +import lombok.Getter; +import lombok.Setter; /** * Defines a Request to use against the ollama /api/chat endpoint. * - * @see Generate - * Chat Completion + * @see Generate + * Chat Completion */ -@Data -@EqualsAndHashCode(callSuper = true) +@Getter +@Setter public class OllamaChatRequestModel extends OllamaCommonRequestModel implements OllamaRequestBody { - @NonNull private List messages; + private List messages; - public OllamaChatRequestModel(String model,List messages){ - super(model); + public OllamaChatRequestModel() {} + + public OllamaChatRequestModel(String model, List messages) { + this.model = model; this.messages = messages; } @Override - public String toString() { - try { - return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this); - } catch (JsonProcessingException e) { - throw new RuntimeException(e); + public boolean equals(Object o) { + if (!(o instanceof OllamaChatRequestModel)) { + return false; } + + return this.toString().equals(o.toString()); } + } diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateRequestModel.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateRequestModel.java index 861be1a..b060a4c 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateRequestModel.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateRequestModel.java @@ -1,22 +1,18 @@ package io.github.amithkoujalgi.ollama4j.core.models.generate; -import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper; - -import com.fasterxml.jackson.core.JsonProcessingException; import io.github.amithkoujalgi.ollama4j.core.models.OllamaCommonRequestModel; import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody; import java.util.List; -import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.NonNull; -@Data -@EqualsAndHashCode(callSuper = true) +import lombok.Getter; +import lombok.Setter; + +@Getter +@Setter public class OllamaGenerateRequestModel extends OllamaCommonRequestModel implements OllamaRequestBody{ - @NonNull private String prompt; private List images; @@ -24,23 +20,27 @@ public class OllamaGenerateRequestModel extends OllamaCommonRequestModel impleme private String context; private boolean raw; + public OllamaGenerateRequestModel() { + } public OllamaGenerateRequestModel(String model, String prompt) { - super(model); + this.model = model; this.prompt = prompt; } public OllamaGenerateRequestModel(String model, String prompt, List images) { - super(model); + this.model = model; this.prompt = prompt; this.images = images; } - public String toString() { - try { - return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this); - } catch (JsonProcessingException e) { - throw new RuntimeException(e); + @Override + public boolean equals(Object o) { + if (!(o instanceof OllamaGenerateRequestModel)) { + return false; } + + return this.toString().equals(o.toString()); } + } diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestChatRequestSerialization.java b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestChatRequestSerialization.java new file mode 100644 index 0000000..f5fa5c9 --- /dev/null +++ b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestChatRequestSerialization.java @@ -0,0 +1,106 @@ +package io.github.amithkoujalgi.ollama4j.unittests.jackson; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; + +import java.io.File; +import java.util.List; + +import org.json.JSONObject; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessageRole; +import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder; +import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel; +import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder; +import io.github.amithkoujalgi.ollama4j.core.utils.Utils; + +public class TestChatRequestSerialization { + + private OllamaChatRequestBuilder builder; + + private ObjectMapper mapper = Utils.getObjectMapper(); + + @BeforeEach + public void init() { + builder = OllamaChatRequestBuilder.getInstance("DummyModel"); + } + + @Test + public void testRequestOnlyMandatoryFields() { + OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt", + List.of(new File("src/test/resources/dog-on-a-boat.jpg"))).build(); + String jsonRequest = serializeRequest(req); + assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req); + } + + @Test + public void testRequestMultipleMessages() { + OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.SYSTEM, "System prompt") + .withMessage(OllamaChatMessageRole.USER, "Some prompt") + .build(); + String jsonRequest = serializeRequest(req); + assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req); + } + + @Test + public void testRequestWithMessageAndImage() { + OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt", + List.of(new File("src/test/resources/dog-on-a-boat.jpg"))).build(); + String jsonRequest = serializeRequest(req); + assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req); + } + + @Test + public void testRequestWithOptions() { + OptionsBuilder b = new OptionsBuilder(); + OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt") + .withOptions(b.setMirostat(1).build()).build(); + + String jsonRequest = serializeRequest(req); + OllamaChatRequestModel deserializeRequest = deserializeRequest(jsonRequest); + assertEqualsAfterUnmarshalling(deserializeRequest, req); + assertEquals(1, deserializeRequest.getOptions().get("mirostat")); + } + + @Test + public void testWithJsonFormat() { + OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt") + .withGetJsonResponse().build(); + + String jsonRequest = serializeRequest(req); + // no jackson deserialization as format property is not boolean ==> omit as deserialization + // of request is never used in real code anyways + JSONObject jsonObject = new JSONObject(jsonRequest); + String requestFormatProperty = jsonObject.getString("format"); + assertEquals("json", requestFormatProperty); + } + + private String serializeRequest(OllamaChatRequestModel req) { + try { + return mapper.writeValueAsString(req); + } catch (JsonProcessingException e) { + fail("Could not serialize request!", e); + return null; + } + } + + private OllamaChatRequestModel deserializeRequest(String jsonRequest) { + try { + return mapper.readValue(jsonRequest, OllamaChatRequestModel.class); + } catch (JsonProcessingException e) { + fail("Could not deserialize jsonRequest!", e); + return null; + } + } + + private void assertEqualsAfterUnmarshalling(OllamaChatRequestModel unmarshalledRequest, + OllamaChatRequestModel req) { + assertEquals(req, unmarshalledRequest); + } + +} diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestGenerateRequestSerialization.java b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestGenerateRequestSerialization.java new file mode 100644 index 0000000..7cf0513 --- /dev/null +++ b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestGenerateRequestSerialization.java @@ -0,0 +1,85 @@ +package io.github.amithkoujalgi.ollama4j.unittests.jackson; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; + +import org.json.JSONObject; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestBuilder; +import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel; +import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder; +import io.github.amithkoujalgi.ollama4j.core.utils.Utils; + +public class TestGenerateRequestSerialization { + + private OllamaGenerateRequestBuilder builder; + + private ObjectMapper mapper = Utils.getObjectMapper(); + + @BeforeEach + public void init() { + builder = OllamaGenerateRequestBuilder.getInstance("DummyModel"); + } + + @Test + public void testRequestOnlyMandatoryFields() { + OllamaGenerateRequestModel req = builder.withPrompt("Some prompt").build(); + + String jsonRequest = serializeRequest(req); + assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req); + } + + @Test + public void testRequestWithOptions() { + OptionsBuilder b = new OptionsBuilder(); + OllamaGenerateRequestModel req = + builder.withPrompt("Some prompt").withOptions(b.setMirostat(1).build()).build(); + + String jsonRequest = serializeRequest(req); + OllamaGenerateRequestModel deserializeRequest = deserializeRequest(jsonRequest); + assertEqualsAfterUnmarshalling(deserializeRequest, req); + assertEquals(1, deserializeRequest.getOptions().get("mirostat")); + } + + @Test + public void testWithJsonFormat() { + OllamaGenerateRequestModel req = + builder.withPrompt("Some prompt").withGetJsonResponse().build(); + + String jsonRequest = serializeRequest(req); + // no jackson deserialization as format property is not boolean ==> omit as deserialization + // of request is never used in real code anyways + JSONObject jsonObject = new JSONObject(jsonRequest); + String requestFormatProperty = jsonObject.getString("format"); + assertEquals("json", requestFormatProperty); + } + + private String serializeRequest(OllamaGenerateRequestModel req) { + try { + return mapper.writeValueAsString(req); + } catch (JsonProcessingException e) { + fail("Could not serialize request!", e); + return null; + } + } + + private OllamaGenerateRequestModel deserializeRequest(String jsonRequest) { + try { + return mapper.readValue(jsonRequest, OllamaGenerateRequestModel.class); + } catch (JsonProcessingException e) { + fail("Could not deserialize jsonRequest!", e); + return null; + } + } + + private void assertEqualsAfterUnmarshalling(OllamaGenerateRequestModel unmarshalledRequest, + OllamaGenerateRequestModel req) { + assertEquals(req, unmarshalledRequest); + } + +} From 2b700fdad857e92f03b1d7cdafc9d79b067d244f Mon Sep 17 00:00:00 2001 From: Markus Klenke Date: Fri, 16 Feb 2024 15:58:48 +0000 Subject: [PATCH 6/7] Adds missing pom dependency for JSON comparison tests --- pom.xml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pom.xml b/pom.xml index e71ea4a..6d28686 100644 --- a/pom.xml +++ b/pom.xml @@ -174,6 +174,12 @@ 4.1.0 test + + org.json + json + 20240205 + test + From 0f414f71a307f7ff49fc992cd205423223602288 Mon Sep 17 00:00:00 2001 From: Markus Klenke Date: Fri, 16 Feb 2024 16:01:18 +0000 Subject: [PATCH 7/7] Changes isEmpty method for BooleanToJsonFormatFlagSerializer to override non deprecated supermethod --- .../ollama4j/core/utils/BooleanToJsonFormatFlagSerializer.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/BooleanToJsonFormatFlagSerializer.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/BooleanToJsonFormatFlagSerializer.java index 210613f..f4d4ab3 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/BooleanToJsonFormatFlagSerializer.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/BooleanToJsonFormatFlagSerializer.java @@ -14,7 +14,7 @@ public class BooleanToJsonFormatFlagSerializer extends JsonSerializer{ } @Override - public boolean isEmpty(Boolean value){ + public boolean isEmpty(SerializerProvider provider,Boolean value){ return !value; }