Merge pull request #28 from AgentSchmecker/feature/advanced_params_for_generate

Adds advanced request parameters to chat/generate requests
This commit is contained in:
Amith Koujalgi 2024-02-18 10:51:04 +05:30 committed by GitHub
commit 6487756764
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 395 additions and 94 deletions

View File

@ -174,6 +174,12 @@
<version>4.1.0</version> <version>4.1.0</version>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency>
<groupId>org.json</groupId>
<artifactId>json</artifactId>
<version>20240205</version>
<scope>test</scope>
</dependency>
</dependencies> </dependencies>
<distributionManagement> <distributionManagement>

View File

@ -6,6 +6,7 @@ import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessage;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResult; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResult;
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFileContentsRequest; import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFileContentsRequest;
import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFilePathRequest; import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFilePathRequest;
import io.github.amithkoujalgi.ollama4j.core.models.request.ModelEmbeddingsRequest; import io.github.amithkoujalgi.ollama4j.core.models.request.ModelEmbeddingsRequest;
@ -345,7 +346,7 @@ public class OllamaAPI {
*/ */
public OllamaResult generate(String model, String prompt, Options options) public OllamaResult generate(String model, String prompt, Options options)
throws OllamaBaseException, IOException, InterruptedException { throws OllamaBaseException, IOException, InterruptedException {
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt); OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt);
ollamaRequestModel.setOptions(options.getOptionsMap()); ollamaRequestModel.setOptions(options.getOptionsMap());
return generateSyncForOllamaRequestModel(ollamaRequestModel); return generateSyncForOllamaRequestModel(ollamaRequestModel);
} }
@ -360,7 +361,7 @@ public class OllamaAPI {
* @return the ollama async result callback handle * @return the ollama async result callback handle
*/ */
public OllamaAsyncResultCallback generateAsync(String model, String prompt) { public OllamaAsyncResultCallback generateAsync(String model, String prompt) {
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt); OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt);
URI uri = URI.create(this.host + "/api/generate"); URI uri = URI.create(this.host + "/api/generate");
OllamaAsyncResultCallback ollamaAsyncResultCallback = OllamaAsyncResultCallback ollamaAsyncResultCallback =
@ -389,7 +390,7 @@ public class OllamaAPI {
for (File imageFile : imageFiles) { for (File imageFile : imageFiles) {
images.add(encodeFileToBase64(imageFile)); images.add(encodeFileToBase64(imageFile));
} }
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt, images); OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt, images);
ollamaRequestModel.setOptions(options.getOptionsMap()); ollamaRequestModel.setOptions(options.getOptionsMap());
return generateSyncForOllamaRequestModel(ollamaRequestModel); return generateSyncForOllamaRequestModel(ollamaRequestModel);
} }
@ -413,7 +414,7 @@ public class OllamaAPI {
for (String imageURL : imageURLs) { for (String imageURL : imageURLs) {
images.add(encodeByteArrayToBase64(Utils.loadImageBytesFromUrl(imageURL))); images.add(encodeByteArrayToBase64(Utils.loadImageBytesFromUrl(imageURL)));
} }
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt, images); OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt, images);
ollamaRequestModel.setOptions(options.getOptionsMap()); ollamaRequestModel.setOptions(options.getOptionsMap());
return generateSyncForOllamaRequestModel(ollamaRequestModel); return generateSyncForOllamaRequestModel(ollamaRequestModel);
} }
@ -448,7 +449,7 @@ public class OllamaAPI {
* @throws InterruptedException in case the server is not reachable or network issues happen * @throws InterruptedException in case the server is not reachable or network issues happen
*/ */
public OllamaChatResult chat(OllamaChatRequestModel request) throws OllamaBaseException, IOException, InterruptedException{ public OllamaChatResult chat(OllamaChatRequestModel request) throws OllamaBaseException, IOException, InterruptedException{
return chat(request); return chat(request,null);
} }
/** /**
@ -486,7 +487,7 @@ public class OllamaAPI {
return Base64.getEncoder().encodeToString(bytes); return Base64.getEncoder().encodeToString(bytes);
} }
private OllamaResult generateSyncForOllamaRequestModel(OllamaRequestModel ollamaRequestModel) private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequestModel ollamaRequestModel)
throws OllamaBaseException, IOException, InterruptedException { throws OllamaBaseException, IOException, InterruptedException {
OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose); OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
return requestCaller.callSync(ollamaRequestModel); return requestCaller.callSync(ollamaRequestModel);

View File

@ -1,6 +1,8 @@
package io.github.amithkoujalgi.ollama4j.core.models; package io.github.amithkoujalgi.ollama4j.core.models;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateResponseModel;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils; import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
import java.io.BufferedReader; import java.io.BufferedReader;
import java.io.IOException; import java.io.IOException;
@ -22,7 +24,7 @@ import lombok.Getter;
@SuppressWarnings("unused") @SuppressWarnings("unused")
public class OllamaAsyncResultCallback extends Thread { public class OllamaAsyncResultCallback extends Thread {
private final HttpRequest.Builder requestBuilder; private final HttpRequest.Builder requestBuilder;
private final OllamaRequestModel ollamaRequestModel; private final OllamaGenerateRequestModel ollamaRequestModel;
private final Queue<String> queue = new LinkedList<>(); private final Queue<String> queue = new LinkedList<>();
private String result; private String result;
private boolean isDone; private boolean isDone;
@ -47,7 +49,7 @@ public class OllamaAsyncResultCallback extends Thread {
public OllamaAsyncResultCallback( public OllamaAsyncResultCallback(
HttpRequest.Builder requestBuilder, HttpRequest.Builder requestBuilder,
OllamaRequestModel ollamaRequestModel, OllamaGenerateRequestModel ollamaRequestModel,
long requestTimeoutSeconds) { long requestTimeoutSeconds) {
this.requestBuilder = requestBuilder; this.requestBuilder = requestBuilder;
this.ollamaRequestModel = ollamaRequestModel; this.ollamaRequestModel = ollamaRequestModel;
@ -87,8 +89,8 @@ public class OllamaAsyncResultCallback extends Thread {
queue.add(ollamaResponseModel.getError()); queue.add(ollamaResponseModel.getError());
responseBuffer.append(ollamaResponseModel.getError()); responseBuffer.append(ollamaResponseModel.getError());
} else { } else {
OllamaResponseModel ollamaResponseModel = OllamaGenerateResponseModel ollamaResponseModel =
Utils.getObjectMapper().readValue(line, OllamaResponseModel.class); Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class);
queue.add(ollamaResponseModel.getResponse()); queue.add(ollamaResponseModel.getResponse());
if (!ollamaResponseModel.isDone()) { if (!ollamaResponseModel.isDone()) {
responseBuffer.append(ollamaResponseModel.getResponse()); responseBuffer.append(ollamaResponseModel.getResponse());

View File

@ -0,0 +1,35 @@
package io.github.amithkoujalgi.ollama4j.core.models;
import java.util.Map;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import io.github.amithkoujalgi.ollama4j.core.utils.BooleanToJsonFormatFlagSerializer;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
import lombok.Data;
@Data
@JsonInclude(JsonInclude.Include.NON_NULL)
public abstract class OllamaCommonRequestModel {
protected String model;
@JsonSerialize(using = BooleanToJsonFormatFlagSerializer.class)
@JsonProperty(value = "format")
protected Boolean returnFormatJson;
protected Map<String, Object> options;
protected String template;
protected boolean stream;
@JsonProperty(value = "keep_alive")
protected String keepAlive;
public String toString() {
try {
return Utils.getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
}

View File

@ -1,39 +0,0 @@
package io.github.amithkoujalgi.ollama4j.core.models;
import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper;
import com.fasterxml.jackson.core.JsonProcessingException;
import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
import java.util.List;
import java.util.Map;
import lombok.Data;
@Data
public class OllamaRequestModel implements OllamaRequestBody{
private String model;
private String prompt;
private List<String> images;
private Map<String, Object> options;
public OllamaRequestModel(String model, String prompt) {
this.model = model;
this.prompt = prompt;
}
public OllamaRequestModel(String model, String prompt, List<String> images) {
this.model = model;
this.prompt = prompt;
this.images = images;
}
public String toString() {
try {
return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
}

View File

@ -83,12 +83,12 @@ public class OllamaChatRequestBuilder {
} }
public OllamaChatRequestBuilder withOptions(Options options){ public OllamaChatRequestBuilder withOptions(Options options){
this.request.setOptions(options); this.request.setOptions(options.getOptionsMap());
return this; return this;
} }
public OllamaChatRequestBuilder withFormat(String format){ public OllamaChatRequestBuilder withGetJsonResponse(){
this.request.setFormat(format); this.request.setReturnFormatJson(true);
return this; return this;
} }

View File

@ -1,47 +1,39 @@
package io.github.amithkoujalgi.ollama4j.core.models.chat; package io.github.amithkoujalgi.ollama4j.core.models.chat;
import java.util.List; import java.util.List;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaCommonRequestModel;
import com.fasterxml.jackson.core.JsonProcessingException;
import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody; import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
import io.github.amithkoujalgi.ollama4j.core.utils.Options;
import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper; import lombok.Getter;
import lombok.Setter;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
/** /**
* Defines a Request to use against the ollama /api/chat endpoint. * Defines a Request to use against the ollama /api/chat endpoint.
* *
* @see <a * @see <a href=
* href="https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion">Generate * "https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion">Generate
* Chat Completion</a> * Chat Completion</a>
*/ */
@Data @Getter
@AllArgsConstructor @Setter
@RequiredArgsConstructor public class OllamaChatRequestModel extends OllamaCommonRequestModel implements OllamaRequestBody {
public class OllamaChatRequestModel implements OllamaRequestBody {
@NonNull private String model; private List<OllamaChatMessage> messages;
@NonNull private List<OllamaChatMessage> messages; public OllamaChatRequestModel() {}
private String format; public OllamaChatRequestModel(String model, List<OllamaChatMessage> messages) {
private Options options; this.model = model;
private String template; this.messages = messages;
private boolean stream; }
private String keepAlive;
@Override @Override
public String toString() { public boolean equals(Object o) {
try { if (!(o instanceof OllamaChatRequestModel)) {
return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this); return false;
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
} }
return this.toString().equals(o.toString());
} }
} }

View File

@ -0,0 +1,55 @@
package io.github.amithkoujalgi.ollama4j.core.models.generate;
import io.github.amithkoujalgi.ollama4j.core.utils.Options;
/**
* Helper class for creating {@link io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel}
* objects using the builder-pattern.
*/
public class OllamaGenerateRequestBuilder {
private OllamaGenerateRequestBuilder(String model, String prompt){
request = new OllamaGenerateRequestModel(model, prompt);
}
private OllamaGenerateRequestModel request;
public static OllamaGenerateRequestBuilder getInstance(String model){
return new OllamaGenerateRequestBuilder(model,"");
}
public OllamaGenerateRequestModel build(){
return request;
}
public OllamaGenerateRequestBuilder withPrompt(String prompt){
request.setPrompt(prompt);
return this;
}
public OllamaGenerateRequestBuilder withGetJsonResponse(){
this.request.setReturnFormatJson(true);
return this;
}
public OllamaGenerateRequestBuilder withOptions(Options options){
this.request.setOptions(options.getOptionsMap());
return this;
}
public OllamaGenerateRequestBuilder withTemplate(String template){
this.request.setTemplate(template);
return this;
}
public OllamaGenerateRequestBuilder withStreaming(){
this.request.setStream(true);
return this;
}
public OllamaGenerateRequestBuilder withKeepAlive(String keepAlive){
this.request.setKeepAlive(keepAlive);
return this;
}
}

View File

@ -0,0 +1,46 @@
package io.github.amithkoujalgi.ollama4j.core.models.generate;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaCommonRequestModel;
import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
import java.util.List;
import lombok.Getter;
import lombok.Setter;
@Getter
@Setter
public class OllamaGenerateRequestModel extends OllamaCommonRequestModel implements OllamaRequestBody{
private String prompt;
private List<String> images;
private String system;
private String context;
private boolean raw;
public OllamaGenerateRequestModel() {
}
public OllamaGenerateRequestModel(String model, String prompt) {
this.model = model;
this.prompt = prompt;
}
public OllamaGenerateRequestModel(String model, String prompt, List<String> images) {
this.model = model;
this.prompt = prompt;
this.images = images;
}
@Override
public boolean equals(Object o) {
if (!(o instanceof OllamaGenerateRequestModel)) {
return false;
}
return this.toString().equals(o.toString());
}
}

View File

@ -1,4 +1,4 @@
package io.github.amithkoujalgi.ollama4j.core.models; package io.github.amithkoujalgi.ollama4j.core.models.generate;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
@ -8,7 +8,7 @@ import lombok.Data;
@Data @Data
@JsonIgnoreProperties(ignoreUnknown = true) @JsonIgnoreProperties(ignoreUnknown = true)
public class OllamaResponseModel { public class OllamaGenerateResponseModel {
private String model; private String model;
private @JsonProperty("created_at") String createdAt; private @JsonProperty("created_at") String createdAt;
private String response; private String response;

View File

@ -6,7 +6,7 @@ import org.slf4j.LoggerFactory;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth; import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaResponseModel; import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateResponseModel;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils; import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{
@ -25,7 +25,7 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{
@Override @Override
protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) { protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) {
try { try {
OllamaResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaResponseModel.class); OllamaGenerateResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class);
responseBuffer.append(ollamaResponseModel.getResponse()); responseBuffer.append(ollamaResponseModel.getResponse());
return ollamaResponseModel.isDone(); return ollamaResponseModel.isDone();
} catch (JsonProcessingException e) { } catch (JsonProcessingException e) {

View File

@ -0,0 +1,21 @@
package io.github.amithkoujalgi.ollama4j.core.utils;
import java.io.IOException;
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.databind.JsonSerializer;
import com.fasterxml.jackson.databind.SerializerProvider;
public class BooleanToJsonFormatFlagSerializer extends JsonSerializer<Boolean>{
@Override
public void serialize(Boolean value, JsonGenerator gen, SerializerProvider serializers) throws IOException {
gen.writeString("json");
}
@Override
public boolean isEmpty(SerializerProvider provider,Boolean value){
return !value;
}
}

View File

@ -1,8 +1,6 @@
package io.github.amithkoujalgi.ollama4j.core.utils; package io.github.amithkoujalgi.ollama4j.core.utils;
import java.io.ByteArrayOutputStream;
import java.io.IOException; import java.io.IOException;
import java.io.ObjectOutputStream;
import java.util.Base64; import java.util.Base64;
import java.util.Collection; import java.util.Collection;
@ -20,11 +18,4 @@ public class FileToBase64Serializer extends JsonSerializer<Collection<byte[]>> {
} }
jsonGenerator.writeEndArray(); jsonGenerator.writeEndArray();
} }
public static byte[] serialize(Object obj) throws IOException {
ByteArrayOutputStream out = new ByteArrayOutputStream();
ObjectOutputStream os = new ObjectOutputStream(out);
os.writeObject(obj);
return out.toByteArray();
}
} }

View File

@ -0,0 +1,106 @@
package io.github.amithkoujalgi.ollama4j.unittests.jackson;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
import java.io.File;
import java.util.List;
import org.json.JSONObject;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessageRole;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel;
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
public class TestChatRequestSerialization {
private OllamaChatRequestBuilder builder;
private ObjectMapper mapper = Utils.getObjectMapper();
@BeforeEach
public void init() {
builder = OllamaChatRequestBuilder.getInstance("DummyModel");
}
@Test
public void testRequestOnlyMandatoryFields() {
OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt",
List.of(new File("src/test/resources/dog-on-a-boat.jpg"))).build();
String jsonRequest = serializeRequest(req);
assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req);
}
@Test
public void testRequestMultipleMessages() {
OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.SYSTEM, "System prompt")
.withMessage(OllamaChatMessageRole.USER, "Some prompt")
.build();
String jsonRequest = serializeRequest(req);
assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req);
}
@Test
public void testRequestWithMessageAndImage() {
OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt",
List.of(new File("src/test/resources/dog-on-a-boat.jpg"))).build();
String jsonRequest = serializeRequest(req);
assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req);
}
@Test
public void testRequestWithOptions() {
OptionsBuilder b = new OptionsBuilder();
OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt")
.withOptions(b.setMirostat(1).build()).build();
String jsonRequest = serializeRequest(req);
OllamaChatRequestModel deserializeRequest = deserializeRequest(jsonRequest);
assertEqualsAfterUnmarshalling(deserializeRequest, req);
assertEquals(1, deserializeRequest.getOptions().get("mirostat"));
}
@Test
public void testWithJsonFormat() {
OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt")
.withGetJsonResponse().build();
String jsonRequest = serializeRequest(req);
// no jackson deserialization as format property is not boolean ==> omit as deserialization
// of request is never used in real code anyways
JSONObject jsonObject = new JSONObject(jsonRequest);
String requestFormatProperty = jsonObject.getString("format");
assertEquals("json", requestFormatProperty);
}
private String serializeRequest(OllamaChatRequestModel req) {
try {
return mapper.writeValueAsString(req);
} catch (JsonProcessingException e) {
fail("Could not serialize request!", e);
return null;
}
}
private OllamaChatRequestModel deserializeRequest(String jsonRequest) {
try {
return mapper.readValue(jsonRequest, OllamaChatRequestModel.class);
} catch (JsonProcessingException e) {
fail("Could not deserialize jsonRequest!", e);
return null;
}
}
private void assertEqualsAfterUnmarshalling(OllamaChatRequestModel unmarshalledRequest,
OllamaChatRequestModel req) {
assertEquals(req, unmarshalledRequest);
}
}

View File

@ -0,0 +1,85 @@
package io.github.amithkoujalgi.ollama4j.unittests.jackson;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
import org.json.JSONObject;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestBuilder;
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel;
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
public class TestGenerateRequestSerialization {
private OllamaGenerateRequestBuilder builder;
private ObjectMapper mapper = Utils.getObjectMapper();
@BeforeEach
public void init() {
builder = OllamaGenerateRequestBuilder.getInstance("DummyModel");
}
@Test
public void testRequestOnlyMandatoryFields() {
OllamaGenerateRequestModel req = builder.withPrompt("Some prompt").build();
String jsonRequest = serializeRequest(req);
assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req);
}
@Test
public void testRequestWithOptions() {
OptionsBuilder b = new OptionsBuilder();
OllamaGenerateRequestModel req =
builder.withPrompt("Some prompt").withOptions(b.setMirostat(1).build()).build();
String jsonRequest = serializeRequest(req);
OllamaGenerateRequestModel deserializeRequest = deserializeRequest(jsonRequest);
assertEqualsAfterUnmarshalling(deserializeRequest, req);
assertEquals(1, deserializeRequest.getOptions().get("mirostat"));
}
@Test
public void testWithJsonFormat() {
OllamaGenerateRequestModel req =
builder.withPrompt("Some prompt").withGetJsonResponse().build();
String jsonRequest = serializeRequest(req);
// no jackson deserialization as format property is not boolean ==> omit as deserialization
// of request is never used in real code anyways
JSONObject jsonObject = new JSONObject(jsonRequest);
String requestFormatProperty = jsonObject.getString("format");
assertEquals("json", requestFormatProperty);
}
private String serializeRequest(OllamaGenerateRequestModel req) {
try {
return mapper.writeValueAsString(req);
} catch (JsonProcessingException e) {
fail("Could not serialize request!", e);
return null;
}
}
private OllamaGenerateRequestModel deserializeRequest(String jsonRequest) {
try {
return mapper.readValue(jsonRequest, OllamaGenerateRequestModel.class);
} catch (JsonProcessingException e) {
fail("Could not deserialize jsonRequest!", e);
return null;
}
}
private void assertEqualsAfterUnmarshalling(OllamaGenerateRequestModel unmarshalledRequest,
OllamaGenerateRequestModel req) {
assertEquals(req, unmarshalledRequest);
}
}