diff --git a/pom.xml b/pom.xml
index e71ea4a..6d28686 100644
--- a/pom.xml
+++ b/pom.xml
@@ -174,6 +174,12 @@
4.1.0
test
+
+ org.json
+ json
+ 20240205
+ test
+
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java
index e48add1..20e9d3e 100644
--- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java
@@ -6,6 +6,7 @@ import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessage;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResult;
+import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFileContentsRequest;
import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFilePathRequest;
import io.github.amithkoujalgi.ollama4j.core.models.request.ModelEmbeddingsRequest;
@@ -345,7 +346,7 @@ public class OllamaAPI {
*/
public OllamaResult generate(String model, String prompt, Options options)
throws OllamaBaseException, IOException, InterruptedException {
- OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt);
+ OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt);
ollamaRequestModel.setOptions(options.getOptionsMap());
return generateSyncForOllamaRequestModel(ollamaRequestModel);
}
@@ -360,7 +361,7 @@ public class OllamaAPI {
* @return the ollama async result callback handle
*/
public OllamaAsyncResultCallback generateAsync(String model, String prompt) {
- OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt);
+ OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt);
URI uri = URI.create(this.host + "/api/generate");
OllamaAsyncResultCallback ollamaAsyncResultCallback =
@@ -389,7 +390,7 @@ public class OllamaAPI {
for (File imageFile : imageFiles) {
images.add(encodeFileToBase64(imageFile));
}
- OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt, images);
+ OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt, images);
ollamaRequestModel.setOptions(options.getOptionsMap());
return generateSyncForOllamaRequestModel(ollamaRequestModel);
}
@@ -413,7 +414,7 @@ public class OllamaAPI {
for (String imageURL : imageURLs) {
images.add(encodeByteArrayToBase64(Utils.loadImageBytesFromUrl(imageURL)));
}
- OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt, images);
+ OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt, images);
ollamaRequestModel.setOptions(options.getOptionsMap());
return generateSyncForOllamaRequestModel(ollamaRequestModel);
}
@@ -448,7 +449,7 @@ public class OllamaAPI {
* @throws InterruptedException in case the server is not reachable or network issues happen
*/
public OllamaChatResult chat(OllamaChatRequestModel request) throws OllamaBaseException, IOException, InterruptedException{
- return chat(request);
+ return chat(request,null);
}
/**
@@ -486,7 +487,7 @@ public class OllamaAPI {
return Base64.getEncoder().encodeToString(bytes);
}
- private OllamaResult generateSyncForOllamaRequestModel(OllamaRequestModel ollamaRequestModel)
+ private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequestModel ollamaRequestModel)
throws OllamaBaseException, IOException, InterruptedException {
OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
return requestCaller.callSync(ollamaRequestModel);
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java
index 74b8c49..136f1c6 100644
--- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java
@@ -1,6 +1,8 @@
package io.github.amithkoujalgi.ollama4j.core.models;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
+import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel;
+import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateResponseModel;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
import java.io.BufferedReader;
import java.io.IOException;
@@ -22,7 +24,7 @@ import lombok.Getter;
@SuppressWarnings("unused")
public class OllamaAsyncResultCallback extends Thread {
private final HttpRequest.Builder requestBuilder;
- private final OllamaRequestModel ollamaRequestModel;
+ private final OllamaGenerateRequestModel ollamaRequestModel;
private final Queue queue = new LinkedList<>();
private String result;
private boolean isDone;
@@ -47,7 +49,7 @@ public class OllamaAsyncResultCallback extends Thread {
public OllamaAsyncResultCallback(
HttpRequest.Builder requestBuilder,
- OllamaRequestModel ollamaRequestModel,
+ OllamaGenerateRequestModel ollamaRequestModel,
long requestTimeoutSeconds) {
this.requestBuilder = requestBuilder;
this.ollamaRequestModel = ollamaRequestModel;
@@ -87,8 +89,8 @@ public class OllamaAsyncResultCallback extends Thread {
queue.add(ollamaResponseModel.getError());
responseBuffer.append(ollamaResponseModel.getError());
} else {
- OllamaResponseModel ollamaResponseModel =
- Utils.getObjectMapper().readValue(line, OllamaResponseModel.class);
+ OllamaGenerateResponseModel ollamaResponseModel =
+ Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class);
queue.add(ollamaResponseModel.getResponse());
if (!ollamaResponseModel.isDone()) {
responseBuffer.append(ollamaResponseModel.getResponse());
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaCommonRequestModel.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaCommonRequestModel.java
new file mode 100644
index 0000000..6f985ab
--- /dev/null
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaCommonRequestModel.java
@@ -0,0 +1,35 @@
+package io.github.amithkoujalgi.ollama4j.core.models;
+
+import java.util.Map;
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.annotation.JsonSerialize;
+
+import io.github.amithkoujalgi.ollama4j.core.utils.BooleanToJsonFormatFlagSerializer;
+import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
+import lombok.Data;
+
+@Data
+@JsonInclude(JsonInclude.Include.NON_NULL)
+public abstract class OllamaCommonRequestModel {
+
+ protected String model;
+ @JsonSerialize(using = BooleanToJsonFormatFlagSerializer.class)
+ @JsonProperty(value = "format")
+ protected Boolean returnFormatJson;
+ protected Map options;
+ protected String template;
+ protected boolean stream;
+ @JsonProperty(value = "keep_alive")
+ protected String keepAlive;
+
+
+ public String toString() {
+ try {
+ return Utils.getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
+ } catch (JsonProcessingException e) {
+ throw new RuntimeException(e);
+ }
+ }
+}
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaRequestModel.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaRequestModel.java
deleted file mode 100644
index 9c88698..0000000
--- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaRequestModel.java
+++ /dev/null
@@ -1,39 +0,0 @@
-package io.github.amithkoujalgi.ollama4j.core.models;
-
-import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper;
-
-import com.fasterxml.jackson.core.JsonProcessingException;
-
-import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
-
-import java.util.List;
-import java.util.Map;
-import lombok.Data;
-
-@Data
-public class OllamaRequestModel implements OllamaRequestBody{
-
- private String model;
- private String prompt;
- private List images;
- private Map options;
-
- public OllamaRequestModel(String model, String prompt) {
- this.model = model;
- this.prompt = prompt;
- }
-
- public OllamaRequestModel(String model, String prompt, List images) {
- this.model = model;
- this.prompt = prompt;
- this.images = images;
- }
-
- public String toString() {
- try {
- return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
- } catch (JsonProcessingException e) {
- throw new RuntimeException(e);
- }
- }
-}
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestBuilder.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestBuilder.java
index 5abbcde..e07722f 100644
--- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestBuilder.java
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestBuilder.java
@@ -83,12 +83,12 @@ public class OllamaChatRequestBuilder {
}
public OllamaChatRequestBuilder withOptions(Options options){
- this.request.setOptions(options);
+ this.request.setOptions(options.getOptionsMap());
return this;
}
- public OllamaChatRequestBuilder withFormat(String format){
- this.request.setFormat(format);
+ public OllamaChatRequestBuilder withGetJsonResponse(){
+ this.request.setReturnFormatJson(true);
return this;
}
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestModel.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestModel.java
index 2f947a2..e55bf6a 100644
--- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestModel.java
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestModel.java
@@ -1,47 +1,39 @@
package io.github.amithkoujalgi.ollama4j.core.models.chat;
import java.util.List;
-
-import com.fasterxml.jackson.core.JsonProcessingException;
-
+import io.github.amithkoujalgi.ollama4j.core.models.OllamaCommonRequestModel;
import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
-import io.github.amithkoujalgi.ollama4j.core.utils.Options;
-import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper;
-
-import lombok.AllArgsConstructor;
-import lombok.Data;
-import lombok.NonNull;
-import lombok.RequiredArgsConstructor;
+import lombok.Getter;
+import lombok.Setter;
/**
* Defines a Request to use against the ollama /api/chat endpoint.
*
- * @see Generate
- * Chat Completion
+ * @see Generate
+ * Chat Completion
*/
-@Data
-@AllArgsConstructor
-@RequiredArgsConstructor
-public class OllamaChatRequestModel implements OllamaRequestBody {
+@Getter
+@Setter
+public class OllamaChatRequestModel extends OllamaCommonRequestModel implements OllamaRequestBody {
- @NonNull private String model;
+ private List messages;
- @NonNull private List messages;
+ public OllamaChatRequestModel() {}
- private String format;
- private Options options;
- private String template;
- private boolean stream;
- private String keepAlive;
+ public OllamaChatRequestModel(String model, List messages) {
+ this.model = model;
+ this.messages = messages;
+ }
@Override
- public String toString() {
- try {
- return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
- } catch (JsonProcessingException e) {
- throw new RuntimeException(e);
+ public boolean equals(Object o) {
+ if (!(o instanceof OllamaChatRequestModel)) {
+ return false;
}
+
+ return this.toString().equals(o.toString());
}
+
}
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateRequestBuilder.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateRequestBuilder.java
new file mode 100644
index 0000000..48b4d18
--- /dev/null
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateRequestBuilder.java
@@ -0,0 +1,55 @@
+package io.github.amithkoujalgi.ollama4j.core.models.generate;
+
+import io.github.amithkoujalgi.ollama4j.core.utils.Options;
+
+/**
+ * Helper class for creating {@link io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel}
+ * objects using the builder-pattern.
+ */
+public class OllamaGenerateRequestBuilder {
+
+ private OllamaGenerateRequestBuilder(String model, String prompt){
+ request = new OllamaGenerateRequestModel(model, prompt);
+ }
+
+ private OllamaGenerateRequestModel request;
+
+ public static OllamaGenerateRequestBuilder getInstance(String model){
+ return new OllamaGenerateRequestBuilder(model,"");
+ }
+
+ public OllamaGenerateRequestModel build(){
+ return request;
+ }
+
+ public OllamaGenerateRequestBuilder withPrompt(String prompt){
+ request.setPrompt(prompt);
+ return this;
+ }
+
+ public OllamaGenerateRequestBuilder withGetJsonResponse(){
+ this.request.setReturnFormatJson(true);
+ return this;
+ }
+
+ public OllamaGenerateRequestBuilder withOptions(Options options){
+ this.request.setOptions(options.getOptionsMap());
+ return this;
+ }
+
+ public OllamaGenerateRequestBuilder withTemplate(String template){
+ this.request.setTemplate(template);
+ return this;
+ }
+
+ public OllamaGenerateRequestBuilder withStreaming(){
+ this.request.setStream(true);
+ return this;
+ }
+
+ public OllamaGenerateRequestBuilder withKeepAlive(String keepAlive){
+ this.request.setKeepAlive(keepAlive);
+ return this;
+ }
+
+}
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateRequestModel.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateRequestModel.java
new file mode 100644
index 0000000..b060a4c
--- /dev/null
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateRequestModel.java
@@ -0,0 +1,46 @@
+package io.github.amithkoujalgi.ollama4j.core.models.generate;
+
+
+import io.github.amithkoujalgi.ollama4j.core.models.OllamaCommonRequestModel;
+import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
+
+import java.util.List;
+
+import lombok.Getter;
+import lombok.Setter;
+
+@Getter
+@Setter
+public class OllamaGenerateRequestModel extends OllamaCommonRequestModel implements OllamaRequestBody{
+
+ private String prompt;
+ private List images;
+
+ private String system;
+ private String context;
+ private boolean raw;
+
+ public OllamaGenerateRequestModel() {
+ }
+
+ public OllamaGenerateRequestModel(String model, String prompt) {
+ this.model = model;
+ this.prompt = prompt;
+ }
+
+ public OllamaGenerateRequestModel(String model, String prompt, List images) {
+ this.model = model;
+ this.prompt = prompt;
+ this.images = images;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (!(o instanceof OllamaGenerateRequestModel)) {
+ return false;
+ }
+
+ return this.toString().equals(o.toString());
+ }
+
+}
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaResponseModel.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateResponseModel.java
similarity index 88%
rename from src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaResponseModel.java
rename to src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateResponseModel.java
index 9481224..a575a7a 100644
--- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaResponseModel.java
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateResponseModel.java
@@ -1,4 +1,4 @@
-package io.github.amithkoujalgi.ollama4j.core.models;
+package io.github.amithkoujalgi.ollama4j.core.models.generate;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
@@ -8,7 +8,7 @@ import lombok.Data;
@Data
@JsonIgnoreProperties(ignoreUnknown = true)
-public class OllamaResponseModel {
+public class OllamaGenerateResponseModel {
private String model;
private @JsonProperty("created_at") String createdAt;
private String response;
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java
index 8d54db3..ba55159 100644
--- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java
@@ -6,7 +6,7 @@ import org.slf4j.LoggerFactory;
import com.fasterxml.jackson.core.JsonProcessingException;
import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth;
-import io.github.amithkoujalgi.ollama4j.core.models.OllamaResponseModel;
+import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateResponseModel;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{
@@ -25,7 +25,7 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{
@Override
protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) {
try {
- OllamaResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaResponseModel.class);
+ OllamaGenerateResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class);
responseBuffer.append(ollamaResponseModel.getResponse());
return ollamaResponseModel.isDone();
} catch (JsonProcessingException e) {
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/BooleanToJsonFormatFlagSerializer.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/BooleanToJsonFormatFlagSerializer.java
new file mode 100644
index 0000000..f4d4ab3
--- /dev/null
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/BooleanToJsonFormatFlagSerializer.java
@@ -0,0 +1,21 @@
+package io.github.amithkoujalgi.ollama4j.core.utils;
+
+import java.io.IOException;
+
+import com.fasterxml.jackson.core.JsonGenerator;
+import com.fasterxml.jackson.databind.JsonSerializer;
+import com.fasterxml.jackson.databind.SerializerProvider;
+
+public class BooleanToJsonFormatFlagSerializer extends JsonSerializer{
+
+ @Override
+ public void serialize(Boolean value, JsonGenerator gen, SerializerProvider serializers) throws IOException {
+ gen.writeString("json");
+ }
+
+ @Override
+ public boolean isEmpty(SerializerProvider provider,Boolean value){
+ return !value;
+ }
+
+}
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/FileToBase64Serializer.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/FileToBase64Serializer.java
index 680635b..8e862ab 100644
--- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/FileToBase64Serializer.java
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/FileToBase64Serializer.java
@@ -1,8 +1,6 @@
package io.github.amithkoujalgi.ollama4j.core.utils;
-import java.io.ByteArrayOutputStream;
import java.io.IOException;
-import java.io.ObjectOutputStream;
import java.util.Base64;
import java.util.Collection;
@@ -20,11 +18,4 @@ public class FileToBase64Serializer extends JsonSerializer> {
}
jsonGenerator.writeEndArray();
}
-
- public static byte[] serialize(Object obj) throws IOException {
- ByteArrayOutputStream out = new ByteArrayOutputStream();
- ObjectOutputStream os = new ObjectOutputStream(out);
- os.writeObject(obj);
- return out.toByteArray();
- }
}
\ No newline at end of file
diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestChatRequestSerialization.java b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestChatRequestSerialization.java
new file mode 100644
index 0000000..f5fa5c9
--- /dev/null
+++ b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestChatRequestSerialization.java
@@ -0,0 +1,106 @@
+package io.github.amithkoujalgi.ollama4j.unittests.jackson;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.fail;
+
+import java.io.File;
+import java.util.List;
+
+import org.json.JSONObject;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+
+import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessageRole;
+import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder;
+import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel;
+import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
+import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
+
+public class TestChatRequestSerialization {
+
+ private OllamaChatRequestBuilder builder;
+
+ private ObjectMapper mapper = Utils.getObjectMapper();
+
+ @BeforeEach
+ public void init() {
+ builder = OllamaChatRequestBuilder.getInstance("DummyModel");
+ }
+
+ @Test
+ public void testRequestOnlyMandatoryFields() {
+ OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt",
+ List.of(new File("src/test/resources/dog-on-a-boat.jpg"))).build();
+ String jsonRequest = serializeRequest(req);
+ assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req);
+ }
+
+ @Test
+ public void testRequestMultipleMessages() {
+ OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.SYSTEM, "System prompt")
+ .withMessage(OllamaChatMessageRole.USER, "Some prompt")
+ .build();
+ String jsonRequest = serializeRequest(req);
+ assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req);
+ }
+
+ @Test
+ public void testRequestWithMessageAndImage() {
+ OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt",
+ List.of(new File("src/test/resources/dog-on-a-boat.jpg"))).build();
+ String jsonRequest = serializeRequest(req);
+ assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req);
+ }
+
+ @Test
+ public void testRequestWithOptions() {
+ OptionsBuilder b = new OptionsBuilder();
+ OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt")
+ .withOptions(b.setMirostat(1).build()).build();
+
+ String jsonRequest = serializeRequest(req);
+ OllamaChatRequestModel deserializeRequest = deserializeRequest(jsonRequest);
+ assertEqualsAfterUnmarshalling(deserializeRequest, req);
+ assertEquals(1, deserializeRequest.getOptions().get("mirostat"));
+ }
+
+ @Test
+ public void testWithJsonFormat() {
+ OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt")
+ .withGetJsonResponse().build();
+
+ String jsonRequest = serializeRequest(req);
+ // no jackson deserialization as format property is not boolean ==> omit as deserialization
+ // of request is never used in real code anyways
+ JSONObject jsonObject = new JSONObject(jsonRequest);
+ String requestFormatProperty = jsonObject.getString("format");
+ assertEquals("json", requestFormatProperty);
+ }
+
+ private String serializeRequest(OllamaChatRequestModel req) {
+ try {
+ return mapper.writeValueAsString(req);
+ } catch (JsonProcessingException e) {
+ fail("Could not serialize request!", e);
+ return null;
+ }
+ }
+
+ private OllamaChatRequestModel deserializeRequest(String jsonRequest) {
+ try {
+ return mapper.readValue(jsonRequest, OllamaChatRequestModel.class);
+ } catch (JsonProcessingException e) {
+ fail("Could not deserialize jsonRequest!", e);
+ return null;
+ }
+ }
+
+ private void assertEqualsAfterUnmarshalling(OllamaChatRequestModel unmarshalledRequest,
+ OllamaChatRequestModel req) {
+ assertEquals(req, unmarshalledRequest);
+ }
+
+}
diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestGenerateRequestSerialization.java b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestGenerateRequestSerialization.java
new file mode 100644
index 0000000..7cf0513
--- /dev/null
+++ b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestGenerateRequestSerialization.java
@@ -0,0 +1,85 @@
+package io.github.amithkoujalgi.ollama4j.unittests.jackson;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.fail;
+
+import org.json.JSONObject;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+
+import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestBuilder;
+import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel;
+import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
+import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
+
+public class TestGenerateRequestSerialization {
+
+ private OllamaGenerateRequestBuilder builder;
+
+ private ObjectMapper mapper = Utils.getObjectMapper();
+
+ @BeforeEach
+ public void init() {
+ builder = OllamaGenerateRequestBuilder.getInstance("DummyModel");
+ }
+
+ @Test
+ public void testRequestOnlyMandatoryFields() {
+ OllamaGenerateRequestModel req = builder.withPrompt("Some prompt").build();
+
+ String jsonRequest = serializeRequest(req);
+ assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req);
+ }
+
+ @Test
+ public void testRequestWithOptions() {
+ OptionsBuilder b = new OptionsBuilder();
+ OllamaGenerateRequestModel req =
+ builder.withPrompt("Some prompt").withOptions(b.setMirostat(1).build()).build();
+
+ String jsonRequest = serializeRequest(req);
+ OllamaGenerateRequestModel deserializeRequest = deserializeRequest(jsonRequest);
+ assertEqualsAfterUnmarshalling(deserializeRequest, req);
+ assertEquals(1, deserializeRequest.getOptions().get("mirostat"));
+ }
+
+ @Test
+ public void testWithJsonFormat() {
+ OllamaGenerateRequestModel req =
+ builder.withPrompt("Some prompt").withGetJsonResponse().build();
+
+ String jsonRequest = serializeRequest(req);
+ // no jackson deserialization as format property is not boolean ==> omit as deserialization
+ // of request is never used in real code anyways
+ JSONObject jsonObject = new JSONObject(jsonRequest);
+ String requestFormatProperty = jsonObject.getString("format");
+ assertEquals("json", requestFormatProperty);
+ }
+
+ private String serializeRequest(OllamaGenerateRequestModel req) {
+ try {
+ return mapper.writeValueAsString(req);
+ } catch (JsonProcessingException e) {
+ fail("Could not serialize request!", e);
+ return null;
+ }
+ }
+
+ private OllamaGenerateRequestModel deserializeRequest(String jsonRequest) {
+ try {
+ return mapper.readValue(jsonRequest, OllamaGenerateRequestModel.class);
+ } catch (JsonProcessingException e) {
+ fail("Could not deserialize jsonRequest!", e);
+ return null;
+ }
+ }
+
+ private void assertEqualsAfterUnmarshalling(OllamaGenerateRequestModel unmarshalledRequest,
+ OllamaGenerateRequestModel req) {
+ assertEquals(req, unmarshalledRequest);
+ }
+
+}