diff --git a/README.md b/README.md
index 42270f3..6cdfc25 100644
--- a/README.md
+++ b/README.md
@@ -67,7 +67,7 @@ In your Maven project, add this dependency:
io.github.amithkoujalgi
ollama4j
- 1.0.47
+ 1.0.57
```
@@ -125,15 +125,15 @@ Actions CI workflow.
- [x] Update request body creation with Java objects
- [ ] Async APIs for images
- [ ] Add custom headers to requests
-- [ ] Add additional params for `ask` APIs such as:
+- [x] Add additional params for `ask` APIs such as:
- [x] `options`: additional model parameters for the Modelfile such as `temperature` -
Supported [params](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
- - [ ] `system`: system prompt to (overrides what is defined in the Modelfile)
- - [ ] `template`: the full prompt or prompt template (overrides what is defined in the Modelfile)
- - [ ] `context`: the context parameter returned from a previous request, which can be used to keep a
+ - [x] `system`: system prompt to (overrides what is defined in the Modelfile)
+ - [x] `template`: the full prompt or prompt template (overrides what is defined in the Modelfile)
+ - [x] `context`: the context parameter returned from a previous request, which can be used to keep a
short
conversational memory
- - [ ] `stream`: Add support for streaming responses from the model
+ - [x] `stream`: Add support for streaming responses from the model
- [ ] Add test cases
- [ ] Handle exceptions better (maybe throw more appropriate exceptions)
diff --git a/pom.xml b/pom.xml
index d375a54..496c817 100644
--- a/pom.xml
+++ b/pom.xml
@@ -99,7 +99,7 @@
${skipUnitTests}
- **/unittests/*.java
+ **/unittests/**/*.java
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java
index ec772f1..25b3a37 100644
--- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java
@@ -6,10 +6,11 @@ import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessage;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResult;
+import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingResponseModel;
+import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFileContentsRequest;
import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFilePathRequest;
-import io.github.amithkoujalgi.ollama4j.core.models.request.ModelEmbeddingsRequest;
import io.github.amithkoujalgi.ollama4j.core.models.request.ModelRequest;
import io.github.amithkoujalgi.ollama4j.core.models.request.OllamaChatEndpointCaller;
import io.github.amithkoujalgi.ollama4j.core.models.request.OllamaGenerateEndpointCaller;
@@ -313,8 +314,18 @@ public class OllamaAPI {
*/
public List generateEmbeddings(String model, String prompt)
throws IOException, InterruptedException, OllamaBaseException {
+ return generateEmbeddings(new OllamaEmbeddingsRequestModel(model, prompt));
+ }
+
+ /**
+ * Generate embeddings using a {@link OllamaEmbeddingsRequestModel}.
+ *
+ * @param modelRequest request for '/api/embeddings' endpoint
+ * @return embeddings
+ */
+ public List generateEmbeddings(OllamaEmbeddingsRequestModel modelRequest) throws IOException, InterruptedException, OllamaBaseException{
URI uri = URI.create(this.host + "/api/embeddings");
- String jsonData = new ModelEmbeddingsRequest(model, prompt).toString();
+ String jsonData = modelRequest.toString();
HttpClient httpClient = HttpClient.newHttpClient();
HttpRequest.Builder requestBuilder =
getRequestBuilderDefault(uri)
@@ -325,8 +336,8 @@ public class OllamaAPI {
int statusCode = response.statusCode();
String responseBody = response.body();
if (statusCode == 200) {
- EmbeddingResponse embeddingResponse =
- Utils.getObjectMapper().readValue(responseBody, EmbeddingResponse.class);
+ OllamaEmbeddingResponseModel embeddingResponse =
+ Utils.getObjectMapper().readValue(responseBody, OllamaEmbeddingResponseModel.class);
return embeddingResponse.getEmbedding();
} else {
throw new OllamaBaseException(statusCode + " - " + responseBody);
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/EmbeddingResponse.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/embeddings/OllamaEmbeddingResponseModel.java
similarity index 65%
rename from src/main/java/io/github/amithkoujalgi/ollama4j/core/models/EmbeddingResponse.java
rename to src/main/java/io/github/amithkoujalgi/ollama4j/core/models/embeddings/OllamaEmbeddingResponseModel.java
index e3040a2..85dba31 100644
--- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/EmbeddingResponse.java
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/embeddings/OllamaEmbeddingResponseModel.java
@@ -1,4 +1,4 @@
-package io.github.amithkoujalgi.ollama4j.core.models;
+package io.github.amithkoujalgi.ollama4j.core.models.embeddings;
import com.fasterxml.jackson.annotation.JsonProperty;
@@ -7,7 +7,7 @@ import lombok.Data;
@SuppressWarnings("unused")
@Data
-public class EmbeddingResponse {
+public class OllamaEmbeddingResponseModel {
@JsonProperty("embedding")
private List embedding;
}
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/embeddings/OllamaEmbeddingsRequestBuilder.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/embeddings/OllamaEmbeddingsRequestBuilder.java
new file mode 100644
index 0000000..ef7a84e
--- /dev/null
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/embeddings/OllamaEmbeddingsRequestBuilder.java
@@ -0,0 +1,31 @@
+package io.github.amithkoujalgi.ollama4j.core.models.embeddings;
+
+import io.github.amithkoujalgi.ollama4j.core.utils.Options;
+
+public class OllamaEmbeddingsRequestBuilder {
+
+ private OllamaEmbeddingsRequestBuilder(String model, String prompt){
+ request = new OllamaEmbeddingsRequestModel(model, prompt);
+ }
+
+ private OllamaEmbeddingsRequestModel request;
+
+ public static OllamaEmbeddingsRequestBuilder getInstance(String model, String prompt){
+ return new OllamaEmbeddingsRequestBuilder(model, prompt);
+ }
+
+ public OllamaEmbeddingsRequestModel build(){
+ return request;
+ }
+
+ public OllamaEmbeddingsRequestBuilder withOptions(Options options){
+ this.request.setOptions(options.getOptionsMap());
+ return this;
+ }
+
+ public OllamaEmbeddingsRequestBuilder withKeepAlive(String keepAlive){
+ this.request.setKeepAlive(keepAlive);
+ return this;
+ }
+
+}
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/embeddings/OllamaEmbeddingsRequestModel.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/embeddings/OllamaEmbeddingsRequestModel.java
new file mode 100644
index 0000000..a369124
--- /dev/null
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/embeddings/OllamaEmbeddingsRequestModel.java
@@ -0,0 +1,33 @@
+package io.github.amithkoujalgi.ollama4j.core.models.embeddings;
+
+import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper;
+import java.util.Map;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.core.JsonProcessingException;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+import lombok.NonNull;
+import lombok.RequiredArgsConstructor;
+
+@Data
+@RequiredArgsConstructor
+@NoArgsConstructor
+public class OllamaEmbeddingsRequestModel {
+ @NonNull
+ private String model;
+ @NonNull
+ private String prompt;
+
+ protected Map options;
+ @JsonProperty(value = "keep_alive")
+ private String keepAlive;
+
+ @Override
+ public String toString() {
+ try {
+ return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
+ } catch (JsonProcessingException e) {
+ throw new RuntimeException(e);
+ }
+ }
+}
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/ModelEmbeddingsRequest.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/ModelEmbeddingsRequest.java
deleted file mode 100644
index 1455a94..0000000
--- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/ModelEmbeddingsRequest.java
+++ /dev/null
@@ -1,23 +0,0 @@
-package io.github.amithkoujalgi.ollama4j.core.models.request;
-
-import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper;
-
-import com.fasterxml.jackson.core.JsonProcessingException;
-import lombok.AllArgsConstructor;
-import lombok.Data;
-
-@Data
-@AllArgsConstructor
-public class ModelEmbeddingsRequest {
- private String model;
- private String prompt;
-
- @Override
- public String toString() {
- try {
- return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
- } catch (JsonProcessingException e) {
- throw new RuntimeException(e);
- }
- }
-}
diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java b/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java
index dc91287..d822077 100644
--- a/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java
+++ b/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java
@@ -10,6 +10,8 @@ import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessageRole;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResult;
+import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestModel;
+import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestBuilder;
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
import java.io.File;
import java.io.IOException;
@@ -61,7 +63,7 @@ class TestRealAPIs {
} catch (HttpConnectTimeoutException e) {
fail(e.getMessage());
} catch (Exception e) {
- throw new RuntimeException(e);
+ fail(e);
}
}
@@ -73,7 +75,7 @@ class TestRealAPIs {
assertNotNull(ollamaAPI.listModels());
ollamaAPI.listModels().forEach(System.out::println);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
- throw new RuntimeException(e);
+ fail(e);
}
}
@@ -88,7 +90,7 @@ class TestRealAPIs {
.anyMatch(model -> model.getModel().equalsIgnoreCase(config.getModel()));
assertTrue(found);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
- throw new RuntimeException(e);
+ fail(e);
}
}
@@ -101,7 +103,7 @@ class TestRealAPIs {
assertNotNull(modelDetails);
System.out.println(modelDetails);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
- throw new RuntimeException(e);
+ fail(e);
}
}
@@ -119,7 +121,7 @@ class TestRealAPIs {
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
} catch (IOException | OllamaBaseException | InterruptedException e) {
- throw new RuntimeException(e);
+ fail(e);
}
}
@@ -145,7 +147,7 @@ class TestRealAPIs {
assertFalse(result.getResponse().isEmpty());
assertEquals(sb.toString().trim(), result.getResponse().trim());
} catch (IOException | OllamaBaseException | InterruptedException e) {
- throw new RuntimeException(e);
+ fail(e);
}
}
@@ -163,7 +165,7 @@ class TestRealAPIs {
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
} catch (IOException | OllamaBaseException | InterruptedException e) {
- throw new RuntimeException(e);
+ fail(e);
}
}
@@ -183,7 +185,7 @@ class TestRealAPIs {
assertFalse(chatResult.getResponse().isBlank());
assertEquals(4,chatResult.getChatHistory().size());
} catch (IOException | OllamaBaseException | InterruptedException e) {
- throw new RuntimeException(e);
+ fail(e);
}
}
@@ -205,7 +207,7 @@ class TestRealAPIs {
assertTrue(chatResult.getResponse().startsWith("NI"));
assertEquals(3, chatResult.getChatHistory().size());
} catch (IOException | OllamaBaseException | InterruptedException e) {
- throw new RuntimeException(e);
+ fail(e);
}
}
@@ -230,7 +232,7 @@ class TestRealAPIs {
assertNotNull(chatResult);
assertEquals(sb.toString().trim(), chatResult.getResponse().trim());
} catch (IOException | OllamaBaseException | InterruptedException e) {
- throw new RuntimeException(e);
+ fail(e);
}
}
@@ -261,7 +263,7 @@ class TestRealAPIs {
} catch (IOException | OllamaBaseException | InterruptedException e) {
- throw new RuntimeException(e);
+ fail(e);
}
}
@@ -278,7 +280,7 @@ class TestRealAPIs {
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
assertNotNull(chatResult);
} catch (IOException | OllamaBaseException | InterruptedException e) {
- throw new RuntimeException(e);
+ fail(e);
}
}
@@ -298,7 +300,7 @@ class TestRealAPIs {
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
} catch (IOException | OllamaBaseException | InterruptedException e) {
- throw new RuntimeException(e);
+ fail(e);
}
}
@@ -322,7 +324,7 @@ class TestRealAPIs {
assertFalse(result.getResponse().isEmpty());
assertEquals(sb.toString().trim(), result.getResponse().trim());
} catch (IOException | OllamaBaseException | InterruptedException e) {
- throw new RuntimeException(e);
+ fail(e);
}
}
@@ -342,7 +344,24 @@ class TestRealAPIs {
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
- throw new RuntimeException(e);
+ fail(e);
+ }
+ }
+
+ @Test
+ @Order(3)
+ public void testEmbedding() {
+ testEndpointReachability();
+ try {
+ OllamaEmbeddingsRequestModel request = OllamaEmbeddingsRequestBuilder
+ .getInstance(config.getModel(), "What is the capital of France?").build();
+
+ List embeddings = ollamaAPI.generateEmbeddings(request);
+
+ assertNotNull(embeddings);
+ assertFalse(embeddings.isEmpty());
+ } catch (IOException | OllamaBaseException | InterruptedException e) {
+ fail(e);
}
}
}
diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/AbstractRequestSerializationTest.java b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/AbstractRequestSerializationTest.java
new file mode 100644
index 0000000..c6b2ff5
--- /dev/null
+++ b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/AbstractRequestSerializationTest.java
@@ -0,0 +1,35 @@
+package io.github.amithkoujalgi.ollama4j.unittests.jackson;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.fail;
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
+
+public abstract class AbstractRequestSerializationTest {
+
+ protected ObjectMapper mapper = Utils.getObjectMapper();
+
+ protected String serializeRequest(T req) {
+ try {
+ return mapper.writeValueAsString(req);
+ } catch (JsonProcessingException e) {
+ fail("Could not serialize request!", e);
+ return null;
+ }
+ }
+
+ protected T deserializeRequest(String jsonRequest, Class requestClass) {
+ try {
+ return mapper.readValue(jsonRequest, requestClass);
+ } catch (JsonProcessingException e) {
+ fail("Could not deserialize jsonRequest!", e);
+ return null;
+ }
+ }
+
+ protected void assertEqualsAfterUnmarshalling(T unmarshalledRequest,
+ T req) {
+ assertEquals(req, unmarshalledRequest);
+ }
+}
diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestChatRequestSerialization.java b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestChatRequestSerialization.java
index f5fa5c9..c5a7060 100644
--- a/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestChatRequestSerialization.java
+++ b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestChatRequestSerialization.java
@@ -1,7 +1,6 @@
package io.github.amithkoujalgi.ollama4j.unittests.jackson;
import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.fail;
import java.io.File;
import java.util.List;
@@ -10,21 +9,15 @@ import org.json.JSONObject;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
-import com.fasterxml.jackson.core.JsonProcessingException;
-import com.fasterxml.jackson.databind.ObjectMapper;
-
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessageRole;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel;
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
-import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
-public class TestChatRequestSerialization {
+public class TestChatRequestSerialization extends AbstractRequestSerializationTest{
private OllamaChatRequestBuilder builder;
- private ObjectMapper mapper = Utils.getObjectMapper();
-
@BeforeEach
public void init() {
builder = OllamaChatRequestBuilder.getInstance("DummyModel");
@@ -32,10 +25,9 @@ public class TestChatRequestSerialization {
@Test
public void testRequestOnlyMandatoryFields() {
- OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt",
- List.of(new File("src/test/resources/dog-on-a-boat.jpg"))).build();
+ OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt").build();
String jsonRequest = serializeRequest(req);
- assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req);
+ assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest,OllamaChatRequestModel.class), req);
}
@Test
@@ -44,7 +36,7 @@ public class TestChatRequestSerialization {
.withMessage(OllamaChatMessageRole.USER, "Some prompt")
.build();
String jsonRequest = serializeRequest(req);
- assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req);
+ assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest,OllamaChatRequestModel.class), req);
}
@Test
@@ -52,7 +44,7 @@ public class TestChatRequestSerialization {
OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt",
List.of(new File("src/test/resources/dog-on-a-boat.jpg"))).build();
String jsonRequest = serializeRequest(req);
- assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req);
+ assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest,OllamaChatRequestModel.class), req);
}
@Test
@@ -62,7 +54,7 @@ public class TestChatRequestSerialization {
.withOptions(b.setMirostat(1).build()).build();
String jsonRequest = serializeRequest(req);
- OllamaChatRequestModel deserializeRequest = deserializeRequest(jsonRequest);
+ OllamaChatRequestModel deserializeRequest = deserializeRequest(jsonRequest,OllamaChatRequestModel.class);
assertEqualsAfterUnmarshalling(deserializeRequest, req);
assertEquals(1, deserializeRequest.getOptions().get("mirostat"));
}
@@ -79,28 +71,4 @@ public class TestChatRequestSerialization {
String requestFormatProperty = jsonObject.getString("format");
assertEquals("json", requestFormatProperty);
}
-
- private String serializeRequest(OllamaChatRequestModel req) {
- try {
- return mapper.writeValueAsString(req);
- } catch (JsonProcessingException e) {
- fail("Could not serialize request!", e);
- return null;
- }
- }
-
- private OllamaChatRequestModel deserializeRequest(String jsonRequest) {
- try {
- return mapper.readValue(jsonRequest, OllamaChatRequestModel.class);
- } catch (JsonProcessingException e) {
- fail("Could not deserialize jsonRequest!", e);
- return null;
- }
- }
-
- private void assertEqualsAfterUnmarshalling(OllamaChatRequestModel unmarshalledRequest,
- OllamaChatRequestModel req) {
- assertEquals(req, unmarshalledRequest);
- }
-
}
diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestEmbeddingsRequestSerialization.java b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestEmbeddingsRequestSerialization.java
new file mode 100644
index 0000000..ff1e308
--- /dev/null
+++ b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestEmbeddingsRequestSerialization.java
@@ -0,0 +1,37 @@
+package io.github.amithkoujalgi.ollama4j.unittests.jackson;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestModel;
+import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestBuilder;
+import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
+
+public class TestEmbeddingsRequestSerialization extends AbstractRequestSerializationTest{
+
+ private OllamaEmbeddingsRequestBuilder builder;
+
+ @BeforeEach
+ public void init() {
+ builder = OllamaEmbeddingsRequestBuilder.getInstance("DummyModel","DummyPrompt");
+ }
+
+ @Test
+ public void testRequestOnlyMandatoryFields() {
+ OllamaEmbeddingsRequestModel req = builder.build();
+ String jsonRequest = serializeRequest(req);
+ assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest,OllamaEmbeddingsRequestModel.class), req);
+ }
+
+ @Test
+ public void testRequestWithOptions() {
+ OptionsBuilder b = new OptionsBuilder();
+ OllamaEmbeddingsRequestModel req = builder
+ .withOptions(b.setMirostat(1).build()).build();
+
+ String jsonRequest = serializeRequest(req);
+ OllamaEmbeddingsRequestModel deserializeRequest = deserializeRequest(jsonRequest,OllamaEmbeddingsRequestModel.class);
+ assertEqualsAfterUnmarshalling(deserializeRequest, req);
+ assertEquals(1, deserializeRequest.getOptions().get("mirostat"));
+ }
+}
diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestGenerateRequestSerialization.java b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestGenerateRequestSerialization.java
index 7cf0513..03610f7 100644
--- a/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestGenerateRequestSerialization.java
+++ b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestGenerateRequestSerialization.java
@@ -1,26 +1,20 @@
package io.github.amithkoujalgi.ollama4j.unittests.jackson;
import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.fail;
import org.json.JSONObject;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
-import com.fasterxml.jackson.core.JsonProcessingException;
-import com.fasterxml.jackson.databind.ObjectMapper;
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestBuilder;
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel;
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
-import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
-public class TestGenerateRequestSerialization {
+public class TestGenerateRequestSerialization extends AbstractRequestSerializationTest{
private OllamaGenerateRequestBuilder builder;
- private ObjectMapper mapper = Utils.getObjectMapper();
-
@BeforeEach
public void init() {
builder = OllamaGenerateRequestBuilder.getInstance("DummyModel");
@@ -31,7 +25,7 @@ public class TestGenerateRequestSerialization {
OllamaGenerateRequestModel req = builder.withPrompt("Some prompt").build();
String jsonRequest = serializeRequest(req);
- assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req);
+ assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest, OllamaGenerateRequestModel.class), req);
}
@Test
@@ -41,7 +35,7 @@ public class TestGenerateRequestSerialization {
builder.withPrompt("Some prompt").withOptions(b.setMirostat(1).build()).build();
String jsonRequest = serializeRequest(req);
- OllamaGenerateRequestModel deserializeRequest = deserializeRequest(jsonRequest);
+ OllamaGenerateRequestModel deserializeRequest = deserializeRequest(jsonRequest, OllamaGenerateRequestModel.class);
assertEqualsAfterUnmarshalling(deserializeRequest, req);
assertEquals(1, deserializeRequest.getOptions().get("mirostat"));
}
@@ -59,27 +53,4 @@ public class TestGenerateRequestSerialization {
assertEquals("json", requestFormatProperty);
}
- private String serializeRequest(OllamaGenerateRequestModel req) {
- try {
- return mapper.writeValueAsString(req);
- } catch (JsonProcessingException e) {
- fail("Could not serialize request!", e);
- return null;
- }
- }
-
- private OllamaGenerateRequestModel deserializeRequest(String jsonRequest) {
- try {
- return mapper.readValue(jsonRequest, OllamaGenerateRequestModel.class);
- } catch (JsonProcessingException e) {
- fail("Could not deserialize jsonRequest!", e);
- return null;
- }
- }
-
- private void assertEqualsAfterUnmarshalling(OllamaGenerateRequestModel unmarshalledRequest,
- OllamaGenerateRequestModel req) {
- assertEquals(req, unmarshalledRequest);
- }
-
}