Adds additional properties to chat and generate requests

This commit is contained in:
Markus Klenke 2024-02-16 15:57:48 +00:00
parent 91aab6cbd1
commit 06c5daa253
5 changed files with 240 additions and 43 deletions

View File

@ -1,26 +1,35 @@
package io.github.amithkoujalgi.ollama4j.core.models; package io.github.amithkoujalgi.ollama4j.core.models;
import java.util.Map; import java.util.Map;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.annotation.JsonSerialize; import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import io.github.amithkoujalgi.ollama4j.core.utils.BooleanToJsonFormatFlagSerializer; import io.github.amithkoujalgi.ollama4j.core.utils.BooleanToJsonFormatFlagSerializer;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
import lombok.Data; import lombok.Data;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
@Data @Data
@RequiredArgsConstructor @JsonInclude(JsonInclude.Include.NON_NULL)
public abstract class OllamaCommonRequestModel { public abstract class OllamaCommonRequestModel {
@NonNull
protected String model; protected String model;
@JsonSerialize(using = BooleanToJsonFormatFlagSerializer.class) @JsonSerialize(using = BooleanToJsonFormatFlagSerializer.class)
protected boolean returnFormatJson; @JsonProperty(value = "format")
protected Boolean returnFormatJson;
protected Map<String, Object> options; protected Map<String, Object> options;
protected String template; protected String template;
protected boolean stream; protected boolean stream;
@JsonProperty(value = "keep_alive") @JsonProperty(value = "keep_alive")
protected String keepAlive; protected String keepAlive;
public String toString() {
try {
return Utils.getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
} }

View File

@ -1,42 +1,39 @@
package io.github.amithkoujalgi.ollama4j.core.models.chat; package io.github.amithkoujalgi.ollama4j.core.models.chat;
import java.util.List; import java.util.List;
import com.fasterxml.jackson.core.JsonProcessingException;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaCommonRequestModel; import io.github.amithkoujalgi.ollama4j.core.models.OllamaCommonRequestModel;
import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody; import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper; import lombok.Getter;
import lombok.Setter;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NonNull;
/** /**
* Defines a Request to use against the ollama /api/chat endpoint. * Defines a Request to use against the ollama /api/chat endpoint.
* *
* @see <a * @see <a href=
* href="https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion">Generate * "https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion">Generate
* Chat Completion</a> * Chat Completion</a>
*/ */
@Data @Getter
@EqualsAndHashCode(callSuper = true) @Setter
public class OllamaChatRequestModel extends OllamaCommonRequestModel implements OllamaRequestBody { public class OllamaChatRequestModel extends OllamaCommonRequestModel implements OllamaRequestBody {
@NonNull private List<OllamaChatMessage> messages; private List<OllamaChatMessage> messages;
public OllamaChatRequestModel() {}
public OllamaChatRequestModel(String model, List<OllamaChatMessage> messages) { public OllamaChatRequestModel(String model, List<OllamaChatMessage> messages) {
super(model); this.model = model;
this.messages = messages; this.messages = messages;
} }
@Override @Override
public String toString() { public boolean equals(Object o) {
try { if (!(o instanceof OllamaChatRequestModel)) {
return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this); return false;
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
} }
return this.toString().equals(o.toString());
} }
} }

View File

@ -1,22 +1,18 @@
package io.github.amithkoujalgi.ollama4j.core.models.generate; package io.github.amithkoujalgi.ollama4j.core.models.generate;
import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper;
import com.fasterxml.jackson.core.JsonProcessingException;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaCommonRequestModel; import io.github.amithkoujalgi.ollama4j.core.models.OllamaCommonRequestModel;
import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody; import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
import java.util.List; import java.util.List;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NonNull;
@Data import lombok.Getter;
@EqualsAndHashCode(callSuper = true) import lombok.Setter;
@Getter
@Setter
public class OllamaGenerateRequestModel extends OllamaCommonRequestModel implements OllamaRequestBody{ public class OllamaGenerateRequestModel extends OllamaCommonRequestModel implements OllamaRequestBody{
@NonNull
private String prompt; private String prompt;
private List<String> images; private List<String> images;
@ -24,23 +20,27 @@ public class OllamaGenerateRequestModel extends OllamaCommonRequestModel impleme
private String context; private String context;
private boolean raw; private boolean raw;
public OllamaGenerateRequestModel() {
}
public OllamaGenerateRequestModel(String model, String prompt) { public OllamaGenerateRequestModel(String model, String prompt) {
super(model); this.model = model;
this.prompt = prompt; this.prompt = prompt;
} }
public OllamaGenerateRequestModel(String model, String prompt, List<String> images) { public OllamaGenerateRequestModel(String model, String prompt, List<String> images) {
super(model); this.model = model;
this.prompt = prompt; this.prompt = prompt;
this.images = images; this.images = images;
} }
public String toString() { @Override
try { public boolean equals(Object o) {
return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this); if (!(o instanceof OllamaGenerateRequestModel)) {
} catch (JsonProcessingException e) { return false;
throw new RuntimeException(e);
} }
return this.toString().equals(o.toString());
} }
} }

View File

@ -0,0 +1,106 @@
package io.github.amithkoujalgi.ollama4j.unittests.jackson;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
import java.io.File;
import java.util.List;
import org.json.JSONObject;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessageRole;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel;
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
public class TestChatRequestSerialization {
private OllamaChatRequestBuilder builder;
private ObjectMapper mapper = Utils.getObjectMapper();
@BeforeEach
public void init() {
builder = OllamaChatRequestBuilder.getInstance("DummyModel");
}
@Test
public void testRequestOnlyMandatoryFields() {
OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt",
List.of(new File("src/test/resources/dog-on-a-boat.jpg"))).build();
String jsonRequest = serializeRequest(req);
assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req);
}
@Test
public void testRequestMultipleMessages() {
OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.SYSTEM, "System prompt")
.withMessage(OllamaChatMessageRole.USER, "Some prompt")
.build();
String jsonRequest = serializeRequest(req);
assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req);
}
@Test
public void testRequestWithMessageAndImage() {
OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt",
List.of(new File("src/test/resources/dog-on-a-boat.jpg"))).build();
String jsonRequest = serializeRequest(req);
assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req);
}
@Test
public void testRequestWithOptions() {
OptionsBuilder b = new OptionsBuilder();
OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt")
.withOptions(b.setMirostat(1).build()).build();
String jsonRequest = serializeRequest(req);
OllamaChatRequestModel deserializeRequest = deserializeRequest(jsonRequest);
assertEqualsAfterUnmarshalling(deserializeRequest, req);
assertEquals(1, deserializeRequest.getOptions().get("mirostat"));
}
@Test
public void testWithJsonFormat() {
OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt")
.withGetJsonResponse().build();
String jsonRequest = serializeRequest(req);
// no jackson deserialization as format property is not boolean ==> omit as deserialization
// of request is never used in real code anyways
JSONObject jsonObject = new JSONObject(jsonRequest);
String requestFormatProperty = jsonObject.getString("format");
assertEquals("json", requestFormatProperty);
}
private String serializeRequest(OllamaChatRequestModel req) {
try {
return mapper.writeValueAsString(req);
} catch (JsonProcessingException e) {
fail("Could not serialize request!", e);
return null;
}
}
private OllamaChatRequestModel deserializeRequest(String jsonRequest) {
try {
return mapper.readValue(jsonRequest, OllamaChatRequestModel.class);
} catch (JsonProcessingException e) {
fail("Could not deserialize jsonRequest!", e);
return null;
}
}
private void assertEqualsAfterUnmarshalling(OllamaChatRequestModel unmarshalledRequest,
OllamaChatRequestModel req) {
assertEquals(req, unmarshalledRequest);
}
}

View File

@ -0,0 +1,85 @@
package io.github.amithkoujalgi.ollama4j.unittests.jackson;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
import org.json.JSONObject;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestBuilder;
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel;
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
public class TestGenerateRequestSerialization {
private OllamaGenerateRequestBuilder builder;
private ObjectMapper mapper = Utils.getObjectMapper();
@BeforeEach
public void init() {
builder = OllamaGenerateRequestBuilder.getInstance("DummyModel");
}
@Test
public void testRequestOnlyMandatoryFields() {
OllamaGenerateRequestModel req = builder.withPrompt("Some prompt").build();
String jsonRequest = serializeRequest(req);
assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req);
}
@Test
public void testRequestWithOptions() {
OptionsBuilder b = new OptionsBuilder();
OllamaGenerateRequestModel req =
builder.withPrompt("Some prompt").withOptions(b.setMirostat(1).build()).build();
String jsonRequest = serializeRequest(req);
OllamaGenerateRequestModel deserializeRequest = deserializeRequest(jsonRequest);
assertEqualsAfterUnmarshalling(deserializeRequest, req);
assertEquals(1, deserializeRequest.getOptions().get("mirostat"));
}
@Test
public void testWithJsonFormat() {
OllamaGenerateRequestModel req =
builder.withPrompt("Some prompt").withGetJsonResponse().build();
String jsonRequest = serializeRequest(req);
// no jackson deserialization as format property is not boolean ==> omit as deserialization
// of request is never used in real code anyways
JSONObject jsonObject = new JSONObject(jsonRequest);
String requestFormatProperty = jsonObject.getString("format");
assertEquals("json", requestFormatProperty);
}
private String serializeRequest(OllamaGenerateRequestModel req) {
try {
return mapper.writeValueAsString(req);
} catch (JsonProcessingException e) {
fail("Could not serialize request!", e);
return null;
}
}
private OllamaGenerateRequestModel deserializeRequest(String jsonRequest) {
try {
return mapper.readValue(jsonRequest, OllamaGenerateRequestModel.class);
} catch (JsonProcessingException e) {
fail("Could not deserialize jsonRequest!", e);
return null;
}
}
private void assertEqualsAfterUnmarshalling(OllamaGenerateRequestModel unmarshalledRequest,
OllamaGenerateRequestModel req) {
assertEquals(req, unmarshalledRequest);
}
}