diff --git a/pom.xml b/pom.xml index e71ea4a..6d28686 100644 --- a/pom.xml +++ b/pom.xml @@ -174,6 +174,12 @@ 4.1.0 test + + org.json + json + 20240205 + test + diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java index e48add1..20e9d3e 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java @@ -6,6 +6,7 @@ import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessage; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResult; +import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel; import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFileContentsRequest; import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFilePathRequest; import io.github.amithkoujalgi.ollama4j.core.models.request.ModelEmbeddingsRequest; @@ -345,7 +346,7 @@ public class OllamaAPI { */ public OllamaResult generate(String model, String prompt, Options options) throws OllamaBaseException, IOException, InterruptedException { - OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt); + OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt); ollamaRequestModel.setOptions(options.getOptionsMap()); return generateSyncForOllamaRequestModel(ollamaRequestModel); } @@ -360,7 +361,7 @@ public class OllamaAPI { * @return the ollama async result callback handle */ public OllamaAsyncResultCallback generateAsync(String model, String prompt) { - OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt); + OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt); URI uri = URI.create(this.host + "/api/generate"); OllamaAsyncResultCallback ollamaAsyncResultCallback = @@ -389,7 +390,7 @@ public class OllamaAPI { for (File imageFile : imageFiles) { images.add(encodeFileToBase64(imageFile)); } - OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt, images); + OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt, images); ollamaRequestModel.setOptions(options.getOptionsMap()); return generateSyncForOllamaRequestModel(ollamaRequestModel); } @@ -413,7 +414,7 @@ public class OllamaAPI { for (String imageURL : imageURLs) { images.add(encodeByteArrayToBase64(Utils.loadImageBytesFromUrl(imageURL))); } - OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt, images); + OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt, images); ollamaRequestModel.setOptions(options.getOptionsMap()); return generateSyncForOllamaRequestModel(ollamaRequestModel); } @@ -448,7 +449,7 @@ public class OllamaAPI { * @throws InterruptedException in case the server is not reachable or network issues happen */ public OllamaChatResult chat(OllamaChatRequestModel request) throws OllamaBaseException, IOException, InterruptedException{ - return chat(request); + return chat(request,null); } /** @@ -486,7 +487,7 @@ public class OllamaAPI { return Base64.getEncoder().encodeToString(bytes); } - private OllamaResult generateSyncForOllamaRequestModel(OllamaRequestModel ollamaRequestModel) + private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequestModel ollamaRequestModel) throws OllamaBaseException, IOException, InterruptedException { OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose); return requestCaller.callSync(ollamaRequestModel); diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java index 74b8c49..136f1c6 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java @@ -1,6 +1,8 @@ package io.github.amithkoujalgi.ollama4j.core.models; import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; +import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel; +import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateResponseModel; import io.github.amithkoujalgi.ollama4j.core.utils.Utils; import java.io.BufferedReader; import java.io.IOException; @@ -22,7 +24,7 @@ import lombok.Getter; @SuppressWarnings("unused") public class OllamaAsyncResultCallback extends Thread { private final HttpRequest.Builder requestBuilder; - private final OllamaRequestModel ollamaRequestModel; + private final OllamaGenerateRequestModel ollamaRequestModel; private final Queue queue = new LinkedList<>(); private String result; private boolean isDone; @@ -47,7 +49,7 @@ public class OllamaAsyncResultCallback extends Thread { public OllamaAsyncResultCallback( HttpRequest.Builder requestBuilder, - OllamaRequestModel ollamaRequestModel, + OllamaGenerateRequestModel ollamaRequestModel, long requestTimeoutSeconds) { this.requestBuilder = requestBuilder; this.ollamaRequestModel = ollamaRequestModel; @@ -87,8 +89,8 @@ public class OllamaAsyncResultCallback extends Thread { queue.add(ollamaResponseModel.getError()); responseBuffer.append(ollamaResponseModel.getError()); } else { - OllamaResponseModel ollamaResponseModel = - Utils.getObjectMapper().readValue(line, OllamaResponseModel.class); + OllamaGenerateResponseModel ollamaResponseModel = + Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class); queue.add(ollamaResponseModel.getResponse()); if (!ollamaResponseModel.isDone()) { responseBuffer.append(ollamaResponseModel.getResponse()); diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaCommonRequestModel.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaCommonRequestModel.java new file mode 100644 index 0000000..6f985ab --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaCommonRequestModel.java @@ -0,0 +1,35 @@ +package io.github.amithkoujalgi.ollama4j.core.models; + +import java.util.Map; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.annotation.JsonSerialize; + +import io.github.amithkoujalgi.ollama4j.core.utils.BooleanToJsonFormatFlagSerializer; +import io.github.amithkoujalgi.ollama4j.core.utils.Utils; +import lombok.Data; + +@Data +@JsonInclude(JsonInclude.Include.NON_NULL) +public abstract class OllamaCommonRequestModel { + + protected String model; + @JsonSerialize(using = BooleanToJsonFormatFlagSerializer.class) + @JsonProperty(value = "format") + protected Boolean returnFormatJson; + protected Map options; + protected String template; + protected boolean stream; + @JsonProperty(value = "keep_alive") + protected String keepAlive; + + + public String toString() { + try { + return Utils.getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } +} diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaRequestModel.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaRequestModel.java deleted file mode 100644 index 9c88698..0000000 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaRequestModel.java +++ /dev/null @@ -1,39 +0,0 @@ -package io.github.amithkoujalgi.ollama4j.core.models; - -import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper; - -import com.fasterxml.jackson.core.JsonProcessingException; - -import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody; - -import java.util.List; -import java.util.Map; -import lombok.Data; - -@Data -public class OllamaRequestModel implements OllamaRequestBody{ - - private String model; - private String prompt; - private List images; - private Map options; - - public OllamaRequestModel(String model, String prompt) { - this.model = model; - this.prompt = prompt; - } - - public OllamaRequestModel(String model, String prompt, List images) { - this.model = model; - this.prompt = prompt; - this.images = images; - } - - public String toString() { - try { - return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this); - } catch (JsonProcessingException e) { - throw new RuntimeException(e); - } - } -} diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestBuilder.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestBuilder.java index 5abbcde..e07722f 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestBuilder.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestBuilder.java @@ -83,12 +83,12 @@ public class OllamaChatRequestBuilder { } public OllamaChatRequestBuilder withOptions(Options options){ - this.request.setOptions(options); + this.request.setOptions(options.getOptionsMap()); return this; } - public OllamaChatRequestBuilder withFormat(String format){ - this.request.setFormat(format); + public OllamaChatRequestBuilder withGetJsonResponse(){ + this.request.setReturnFormatJson(true); return this; } diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestModel.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestModel.java index 2f947a2..e55bf6a 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestModel.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestModel.java @@ -1,47 +1,39 @@ package io.github.amithkoujalgi.ollama4j.core.models.chat; import java.util.List; - -import com.fasterxml.jackson.core.JsonProcessingException; - +import io.github.amithkoujalgi.ollama4j.core.models.OllamaCommonRequestModel; import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody; -import io.github.amithkoujalgi.ollama4j.core.utils.Options; -import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper; - -import lombok.AllArgsConstructor; -import lombok.Data; -import lombok.NonNull; -import lombok.RequiredArgsConstructor; +import lombok.Getter; +import lombok.Setter; /** * Defines a Request to use against the ollama /api/chat endpoint. * - * @see Generate - * Chat Completion + * @see Generate + * Chat Completion */ -@Data -@AllArgsConstructor -@RequiredArgsConstructor -public class OllamaChatRequestModel implements OllamaRequestBody { +@Getter +@Setter +public class OllamaChatRequestModel extends OllamaCommonRequestModel implements OllamaRequestBody { - @NonNull private String model; + private List messages; - @NonNull private List messages; + public OllamaChatRequestModel() {} - private String format; - private Options options; - private String template; - private boolean stream; - private String keepAlive; + public OllamaChatRequestModel(String model, List messages) { + this.model = model; + this.messages = messages; + } @Override - public String toString() { - try { - return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this); - } catch (JsonProcessingException e) { - throw new RuntimeException(e); + public boolean equals(Object o) { + if (!(o instanceof OllamaChatRequestModel)) { + return false; } + + return this.toString().equals(o.toString()); } + } diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateRequestBuilder.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateRequestBuilder.java new file mode 100644 index 0000000..48b4d18 --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateRequestBuilder.java @@ -0,0 +1,55 @@ +package io.github.amithkoujalgi.ollama4j.core.models.generate; + +import io.github.amithkoujalgi.ollama4j.core.utils.Options; + +/** + * Helper class for creating {@link io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel} + * objects using the builder-pattern. + */ +public class OllamaGenerateRequestBuilder { + + private OllamaGenerateRequestBuilder(String model, String prompt){ + request = new OllamaGenerateRequestModel(model, prompt); + } + + private OllamaGenerateRequestModel request; + + public static OllamaGenerateRequestBuilder getInstance(String model){ + return new OllamaGenerateRequestBuilder(model,""); + } + + public OllamaGenerateRequestModel build(){ + return request; + } + + public OllamaGenerateRequestBuilder withPrompt(String prompt){ + request.setPrompt(prompt); + return this; + } + + public OllamaGenerateRequestBuilder withGetJsonResponse(){ + this.request.setReturnFormatJson(true); + return this; + } + + public OllamaGenerateRequestBuilder withOptions(Options options){ + this.request.setOptions(options.getOptionsMap()); + return this; + } + + public OllamaGenerateRequestBuilder withTemplate(String template){ + this.request.setTemplate(template); + return this; + } + + public OllamaGenerateRequestBuilder withStreaming(){ + this.request.setStream(true); + return this; + } + + public OllamaGenerateRequestBuilder withKeepAlive(String keepAlive){ + this.request.setKeepAlive(keepAlive); + return this; + } + +} diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateRequestModel.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateRequestModel.java new file mode 100644 index 0000000..b060a4c --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateRequestModel.java @@ -0,0 +1,46 @@ +package io.github.amithkoujalgi.ollama4j.core.models.generate; + + +import io.github.amithkoujalgi.ollama4j.core.models.OllamaCommonRequestModel; +import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody; + +import java.util.List; + +import lombok.Getter; +import lombok.Setter; + +@Getter +@Setter +public class OllamaGenerateRequestModel extends OllamaCommonRequestModel implements OllamaRequestBody{ + + private String prompt; + private List images; + + private String system; + private String context; + private boolean raw; + + public OllamaGenerateRequestModel() { + } + + public OllamaGenerateRequestModel(String model, String prompt) { + this.model = model; + this.prompt = prompt; + } + + public OllamaGenerateRequestModel(String model, String prompt, List images) { + this.model = model; + this.prompt = prompt; + this.images = images; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof OllamaGenerateRequestModel)) { + return false; + } + + return this.toString().equals(o.toString()); + } + +} diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaResponseModel.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateResponseModel.java similarity index 88% rename from src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaResponseModel.java rename to src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateResponseModel.java index 9481224..a575a7a 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaResponseModel.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateResponseModel.java @@ -1,4 +1,4 @@ -package io.github.amithkoujalgi.ollama4j.core.models; +package io.github.amithkoujalgi.ollama4j.core.models.generate; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; @@ -8,7 +8,7 @@ import lombok.Data; @Data @JsonIgnoreProperties(ignoreUnknown = true) -public class OllamaResponseModel { +public class OllamaGenerateResponseModel { private String model; private @JsonProperty("created_at") String createdAt; private String response; diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java index 8d54db3..ba55159 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java @@ -6,7 +6,7 @@ import org.slf4j.LoggerFactory; import com.fasterxml.jackson.core.JsonProcessingException; import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth; -import io.github.amithkoujalgi.ollama4j.core.models.OllamaResponseModel; +import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateResponseModel; import io.github.amithkoujalgi.ollama4j.core.utils.Utils; public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{ @@ -25,7 +25,7 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{ @Override protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) { try { - OllamaResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaResponseModel.class); + OllamaGenerateResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class); responseBuffer.append(ollamaResponseModel.getResponse()); return ollamaResponseModel.isDone(); } catch (JsonProcessingException e) { diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/BooleanToJsonFormatFlagSerializer.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/BooleanToJsonFormatFlagSerializer.java new file mode 100644 index 0000000..f4d4ab3 --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/BooleanToJsonFormatFlagSerializer.java @@ -0,0 +1,21 @@ +package io.github.amithkoujalgi.ollama4j.core.utils; + +import java.io.IOException; + +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.databind.JsonSerializer; +import com.fasterxml.jackson.databind.SerializerProvider; + +public class BooleanToJsonFormatFlagSerializer extends JsonSerializer{ + + @Override + public void serialize(Boolean value, JsonGenerator gen, SerializerProvider serializers) throws IOException { + gen.writeString("json"); + } + + @Override + public boolean isEmpty(SerializerProvider provider,Boolean value){ + return !value; + } + +} diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/FileToBase64Serializer.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/FileToBase64Serializer.java index 680635b..8e862ab 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/FileToBase64Serializer.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/FileToBase64Serializer.java @@ -1,8 +1,6 @@ package io.github.amithkoujalgi.ollama4j.core.utils; -import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.io.ObjectOutputStream; import java.util.Base64; import java.util.Collection; @@ -20,11 +18,4 @@ public class FileToBase64Serializer extends JsonSerializer> { } jsonGenerator.writeEndArray(); } - - public static byte[] serialize(Object obj) throws IOException { - ByteArrayOutputStream out = new ByteArrayOutputStream(); - ObjectOutputStream os = new ObjectOutputStream(out); - os.writeObject(obj); - return out.toByteArray(); - } } \ No newline at end of file diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestChatRequestSerialization.java b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestChatRequestSerialization.java new file mode 100644 index 0000000..f5fa5c9 --- /dev/null +++ b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestChatRequestSerialization.java @@ -0,0 +1,106 @@ +package io.github.amithkoujalgi.ollama4j.unittests.jackson; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; + +import java.io.File; +import java.util.List; + +import org.json.JSONObject; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessageRole; +import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder; +import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel; +import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder; +import io.github.amithkoujalgi.ollama4j.core.utils.Utils; + +public class TestChatRequestSerialization { + + private OllamaChatRequestBuilder builder; + + private ObjectMapper mapper = Utils.getObjectMapper(); + + @BeforeEach + public void init() { + builder = OllamaChatRequestBuilder.getInstance("DummyModel"); + } + + @Test + public void testRequestOnlyMandatoryFields() { + OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt", + List.of(new File("src/test/resources/dog-on-a-boat.jpg"))).build(); + String jsonRequest = serializeRequest(req); + assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req); + } + + @Test + public void testRequestMultipleMessages() { + OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.SYSTEM, "System prompt") + .withMessage(OllamaChatMessageRole.USER, "Some prompt") + .build(); + String jsonRequest = serializeRequest(req); + assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req); + } + + @Test + public void testRequestWithMessageAndImage() { + OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt", + List.of(new File("src/test/resources/dog-on-a-boat.jpg"))).build(); + String jsonRequest = serializeRequest(req); + assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req); + } + + @Test + public void testRequestWithOptions() { + OptionsBuilder b = new OptionsBuilder(); + OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt") + .withOptions(b.setMirostat(1).build()).build(); + + String jsonRequest = serializeRequest(req); + OllamaChatRequestModel deserializeRequest = deserializeRequest(jsonRequest); + assertEqualsAfterUnmarshalling(deserializeRequest, req); + assertEquals(1, deserializeRequest.getOptions().get("mirostat")); + } + + @Test + public void testWithJsonFormat() { + OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt") + .withGetJsonResponse().build(); + + String jsonRequest = serializeRequest(req); + // no jackson deserialization as format property is not boolean ==> omit as deserialization + // of request is never used in real code anyways + JSONObject jsonObject = new JSONObject(jsonRequest); + String requestFormatProperty = jsonObject.getString("format"); + assertEquals("json", requestFormatProperty); + } + + private String serializeRequest(OllamaChatRequestModel req) { + try { + return mapper.writeValueAsString(req); + } catch (JsonProcessingException e) { + fail("Could not serialize request!", e); + return null; + } + } + + private OllamaChatRequestModel deserializeRequest(String jsonRequest) { + try { + return mapper.readValue(jsonRequest, OllamaChatRequestModel.class); + } catch (JsonProcessingException e) { + fail("Could not deserialize jsonRequest!", e); + return null; + } + } + + private void assertEqualsAfterUnmarshalling(OllamaChatRequestModel unmarshalledRequest, + OllamaChatRequestModel req) { + assertEquals(req, unmarshalledRequest); + } + +} diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestGenerateRequestSerialization.java b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestGenerateRequestSerialization.java new file mode 100644 index 0000000..7cf0513 --- /dev/null +++ b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestGenerateRequestSerialization.java @@ -0,0 +1,85 @@ +package io.github.amithkoujalgi.ollama4j.unittests.jackson; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; + +import org.json.JSONObject; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestBuilder; +import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel; +import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder; +import io.github.amithkoujalgi.ollama4j.core.utils.Utils; + +public class TestGenerateRequestSerialization { + + private OllamaGenerateRequestBuilder builder; + + private ObjectMapper mapper = Utils.getObjectMapper(); + + @BeforeEach + public void init() { + builder = OllamaGenerateRequestBuilder.getInstance("DummyModel"); + } + + @Test + public void testRequestOnlyMandatoryFields() { + OllamaGenerateRequestModel req = builder.withPrompt("Some prompt").build(); + + String jsonRequest = serializeRequest(req); + assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req); + } + + @Test + public void testRequestWithOptions() { + OptionsBuilder b = new OptionsBuilder(); + OllamaGenerateRequestModel req = + builder.withPrompt("Some prompt").withOptions(b.setMirostat(1).build()).build(); + + String jsonRequest = serializeRequest(req); + OllamaGenerateRequestModel deserializeRequest = deserializeRequest(jsonRequest); + assertEqualsAfterUnmarshalling(deserializeRequest, req); + assertEquals(1, deserializeRequest.getOptions().get("mirostat")); + } + + @Test + public void testWithJsonFormat() { + OllamaGenerateRequestModel req = + builder.withPrompt("Some prompt").withGetJsonResponse().build(); + + String jsonRequest = serializeRequest(req); + // no jackson deserialization as format property is not boolean ==> omit as deserialization + // of request is never used in real code anyways + JSONObject jsonObject = new JSONObject(jsonRequest); + String requestFormatProperty = jsonObject.getString("format"); + assertEquals("json", requestFormatProperty); + } + + private String serializeRequest(OllamaGenerateRequestModel req) { + try { + return mapper.writeValueAsString(req); + } catch (JsonProcessingException e) { + fail("Could not serialize request!", e); + return null; + } + } + + private OllamaGenerateRequestModel deserializeRequest(String jsonRequest) { + try { + return mapper.readValue(jsonRequest, OllamaGenerateRequestModel.class); + } catch (JsonProcessingException e) { + fail("Could not deserialize jsonRequest!", e); + return null; + } + } + + private void assertEqualsAfterUnmarshalling(OllamaGenerateRequestModel unmarshalledRequest, + OllamaGenerateRequestModel req) { + assertEquals(req, unmarshalledRequest); + } + +}