diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaCommonRequestModel.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaCommonRequestModel.java index f554638..6f985ab 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaCommonRequestModel.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaCommonRequestModel.java @@ -1,26 +1,35 @@ package io.github.amithkoujalgi.ollama4j.core.models; import java.util.Map; - +import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.annotation.JsonSerialize; import io.github.amithkoujalgi.ollama4j.core.utils.BooleanToJsonFormatFlagSerializer; +import io.github.amithkoujalgi.ollama4j.core.utils.Utils; import lombok.Data; -import lombok.NonNull; -import lombok.RequiredArgsConstructor; @Data -@RequiredArgsConstructor +@JsonInclude(JsonInclude.Include.NON_NULL) public abstract class OllamaCommonRequestModel { - @NonNull protected String model; @JsonSerialize(using = BooleanToJsonFormatFlagSerializer.class) - protected boolean returnFormatJson; + @JsonProperty(value = "format") + protected Boolean returnFormatJson; protected Map options; protected String template; protected boolean stream; @JsonProperty(value = "keep_alive") protected String keepAlive; + + + public String toString() { + try { + return Utils.getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } } diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestModel.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestModel.java index 82c9010..e55bf6a 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestModel.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestModel.java @@ -1,42 +1,39 @@ package io.github.amithkoujalgi.ollama4j.core.models.chat; import java.util.List; - -import com.fasterxml.jackson.core.JsonProcessingException; - import io.github.amithkoujalgi.ollama4j.core.models.OllamaCommonRequestModel; import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody; -import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper; - -import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.NonNull; +import lombok.Getter; +import lombok.Setter; /** * Defines a Request to use against the ollama /api/chat endpoint. * - * @see Generate - * Chat Completion + * @see Generate + * Chat Completion */ -@Data -@EqualsAndHashCode(callSuper = true) +@Getter +@Setter public class OllamaChatRequestModel extends OllamaCommonRequestModel implements OllamaRequestBody { - @NonNull private List messages; + private List messages; - public OllamaChatRequestModel(String model,List messages){ - super(model); + public OllamaChatRequestModel() {} + + public OllamaChatRequestModel(String model, List messages) { + this.model = model; this.messages = messages; } @Override - public String toString() { - try { - return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this); - } catch (JsonProcessingException e) { - throw new RuntimeException(e); + public boolean equals(Object o) { + if (!(o instanceof OllamaChatRequestModel)) { + return false; } + + return this.toString().equals(o.toString()); } + } diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateRequestModel.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateRequestModel.java index 861be1a..b060a4c 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateRequestModel.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateRequestModel.java @@ -1,22 +1,18 @@ package io.github.amithkoujalgi.ollama4j.core.models.generate; -import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper; - -import com.fasterxml.jackson.core.JsonProcessingException; import io.github.amithkoujalgi.ollama4j.core.models.OllamaCommonRequestModel; import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody; import java.util.List; -import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.NonNull; -@Data -@EqualsAndHashCode(callSuper = true) +import lombok.Getter; +import lombok.Setter; + +@Getter +@Setter public class OllamaGenerateRequestModel extends OllamaCommonRequestModel implements OllamaRequestBody{ - @NonNull private String prompt; private List images; @@ -24,23 +20,27 @@ public class OllamaGenerateRequestModel extends OllamaCommonRequestModel impleme private String context; private boolean raw; + public OllamaGenerateRequestModel() { + } public OllamaGenerateRequestModel(String model, String prompt) { - super(model); + this.model = model; this.prompt = prompt; } public OllamaGenerateRequestModel(String model, String prompt, List images) { - super(model); + this.model = model; this.prompt = prompt; this.images = images; } - public String toString() { - try { - return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this); - } catch (JsonProcessingException e) { - throw new RuntimeException(e); + @Override + public boolean equals(Object o) { + if (!(o instanceof OllamaGenerateRequestModel)) { + return false; } + + return this.toString().equals(o.toString()); } + } diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestChatRequestSerialization.java b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestChatRequestSerialization.java new file mode 100644 index 0000000..f5fa5c9 --- /dev/null +++ b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestChatRequestSerialization.java @@ -0,0 +1,106 @@ +package io.github.amithkoujalgi.ollama4j.unittests.jackson; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; + +import java.io.File; +import java.util.List; + +import org.json.JSONObject; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessageRole; +import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder; +import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel; +import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder; +import io.github.amithkoujalgi.ollama4j.core.utils.Utils; + +public class TestChatRequestSerialization { + + private OllamaChatRequestBuilder builder; + + private ObjectMapper mapper = Utils.getObjectMapper(); + + @BeforeEach + public void init() { + builder = OllamaChatRequestBuilder.getInstance("DummyModel"); + } + + @Test + public void testRequestOnlyMandatoryFields() { + OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt", + List.of(new File("src/test/resources/dog-on-a-boat.jpg"))).build(); + String jsonRequest = serializeRequest(req); + assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req); + } + + @Test + public void testRequestMultipleMessages() { + OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.SYSTEM, "System prompt") + .withMessage(OllamaChatMessageRole.USER, "Some prompt") + .build(); + String jsonRequest = serializeRequest(req); + assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req); + } + + @Test + public void testRequestWithMessageAndImage() { + OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt", + List.of(new File("src/test/resources/dog-on-a-boat.jpg"))).build(); + String jsonRequest = serializeRequest(req); + assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req); + } + + @Test + public void testRequestWithOptions() { + OptionsBuilder b = new OptionsBuilder(); + OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt") + .withOptions(b.setMirostat(1).build()).build(); + + String jsonRequest = serializeRequest(req); + OllamaChatRequestModel deserializeRequest = deserializeRequest(jsonRequest); + assertEqualsAfterUnmarshalling(deserializeRequest, req); + assertEquals(1, deserializeRequest.getOptions().get("mirostat")); + } + + @Test + public void testWithJsonFormat() { + OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt") + .withGetJsonResponse().build(); + + String jsonRequest = serializeRequest(req); + // no jackson deserialization as format property is not boolean ==> omit as deserialization + // of request is never used in real code anyways + JSONObject jsonObject = new JSONObject(jsonRequest); + String requestFormatProperty = jsonObject.getString("format"); + assertEquals("json", requestFormatProperty); + } + + private String serializeRequest(OllamaChatRequestModel req) { + try { + return mapper.writeValueAsString(req); + } catch (JsonProcessingException e) { + fail("Could not serialize request!", e); + return null; + } + } + + private OllamaChatRequestModel deserializeRequest(String jsonRequest) { + try { + return mapper.readValue(jsonRequest, OllamaChatRequestModel.class); + } catch (JsonProcessingException e) { + fail("Could not deserialize jsonRequest!", e); + return null; + } + } + + private void assertEqualsAfterUnmarshalling(OllamaChatRequestModel unmarshalledRequest, + OllamaChatRequestModel req) { + assertEquals(req, unmarshalledRequest); + } + +} diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestGenerateRequestSerialization.java b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestGenerateRequestSerialization.java new file mode 100644 index 0000000..7cf0513 --- /dev/null +++ b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestGenerateRequestSerialization.java @@ -0,0 +1,85 @@ +package io.github.amithkoujalgi.ollama4j.unittests.jackson; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; + +import org.json.JSONObject; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestBuilder; +import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel; +import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder; +import io.github.amithkoujalgi.ollama4j.core.utils.Utils; + +public class TestGenerateRequestSerialization { + + private OllamaGenerateRequestBuilder builder; + + private ObjectMapper mapper = Utils.getObjectMapper(); + + @BeforeEach + public void init() { + builder = OllamaGenerateRequestBuilder.getInstance("DummyModel"); + } + + @Test + public void testRequestOnlyMandatoryFields() { + OllamaGenerateRequestModel req = builder.withPrompt("Some prompt").build(); + + String jsonRequest = serializeRequest(req); + assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req); + } + + @Test + public void testRequestWithOptions() { + OptionsBuilder b = new OptionsBuilder(); + OllamaGenerateRequestModel req = + builder.withPrompt("Some prompt").withOptions(b.setMirostat(1).build()).build(); + + String jsonRequest = serializeRequest(req); + OllamaGenerateRequestModel deserializeRequest = deserializeRequest(jsonRequest); + assertEqualsAfterUnmarshalling(deserializeRequest, req); + assertEquals(1, deserializeRequest.getOptions().get("mirostat")); + } + + @Test + public void testWithJsonFormat() { + OllamaGenerateRequestModel req = + builder.withPrompt("Some prompt").withGetJsonResponse().build(); + + String jsonRequest = serializeRequest(req); + // no jackson deserialization as format property is not boolean ==> omit as deserialization + // of request is never used in real code anyways + JSONObject jsonObject = new JSONObject(jsonRequest); + String requestFormatProperty = jsonObject.getString("format"); + assertEquals("json", requestFormatProperty); + } + + private String serializeRequest(OllamaGenerateRequestModel req) { + try { + return mapper.writeValueAsString(req); + } catch (JsonProcessingException e) { + fail("Could not serialize request!", e); + return null; + } + } + + private OllamaGenerateRequestModel deserializeRequest(String jsonRequest) { + try { + return mapper.readValue(jsonRequest, OllamaGenerateRequestModel.class); + } catch (JsonProcessingException e) { + fail("Could not deserialize jsonRequest!", e); + return null; + } + } + + private void assertEqualsAfterUnmarshalling(OllamaGenerateRequestModel unmarshalledRequest, + OllamaGenerateRequestModel req) { + assertEquals(req, unmarshalledRequest); + } + +}