From 63d4de4e246e8aa6b0871fd7ca93b5424ff1d101 Mon Sep 17 00:00:00 2001 From: Markus Klenke Date: Sun, 25 Feb 2024 20:53:45 +0000 Subject: [PATCH] Adds options to EmbeddingsRequest Additionally refactors the Embedding Models and Tests --- README.md | 12 ++--- pom.xml | 2 +- .../ollama4j/core/OllamaAPI.java | 19 +++++-- .../OllamaEmbeddingResponseModel.java} | 4 +- .../OllamaEmbeddingsRequestBuilder.java | 31 ++++++++++++ .../OllamaEmbeddingsRequestModel.java | 33 +++++++++++++ .../request/ModelEmbeddingsRequest.java | 23 --------- .../integrationtests/TestRealAPIs.java | 49 +++++++++++++------ .../AbstractRequestSerializationTest.java | 35 +++++++++++++ .../jackson/TestChatRequestSerialization.java | 44 +++-------------- .../TestEmbeddingsRequestSerialization.java | 37 ++++++++++++++ .../TestGenerateRequestSerialization.java | 35 ++----------- 12 files changed, 203 insertions(+), 121 deletions(-) rename src/main/java/io/github/amithkoujalgi/ollama4j/core/models/{EmbeddingResponse.java => embeddings/OllamaEmbeddingResponseModel.java} (65%) create mode 100644 src/main/java/io/github/amithkoujalgi/ollama4j/core/models/embeddings/OllamaEmbeddingsRequestBuilder.java create mode 100644 src/main/java/io/github/amithkoujalgi/ollama4j/core/models/embeddings/OllamaEmbeddingsRequestModel.java delete mode 100644 src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/ModelEmbeddingsRequest.java create mode 100644 src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/AbstractRequestSerializationTest.java create mode 100644 src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestEmbeddingsRequestSerialization.java diff --git a/README.md b/README.md index 42270f3..6cdfc25 100644 --- a/README.md +++ b/README.md @@ -67,7 +67,7 @@ In your Maven project, add this dependency: io.github.amithkoujalgi ollama4j - 1.0.47 + 1.0.57 ``` @@ -125,15 +125,15 @@ Actions CI workflow. - [x] Update request body creation with Java objects - [ ] Async APIs for images - [ ] Add custom headers to requests -- [ ] Add additional params for `ask` APIs such as: +- [x] Add additional params for `ask` APIs such as: - [x] `options`: additional model parameters for the Modelfile such as `temperature` - Supported [params](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values). - - [ ] `system`: system prompt to (overrides what is defined in the Modelfile) - - [ ] `template`: the full prompt or prompt template (overrides what is defined in the Modelfile) - - [ ] `context`: the context parameter returned from a previous request, which can be used to keep a + - [x] `system`: system prompt to (overrides what is defined in the Modelfile) + - [x] `template`: the full prompt or prompt template (overrides what is defined in the Modelfile) + - [x] `context`: the context parameter returned from a previous request, which can be used to keep a short conversational memory - - [ ] `stream`: Add support for streaming responses from the model + - [x] `stream`: Add support for streaming responses from the model - [ ] Add test cases - [ ] Handle exceptions better (maybe throw more appropriate exceptions) diff --git a/pom.xml b/pom.xml index d375a54..496c817 100644 --- a/pom.xml +++ b/pom.xml @@ -99,7 +99,7 @@ ${skipUnitTests} - **/unittests/*.java + **/unittests/**/*.java diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java index ec772f1..25b3a37 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java @@ -6,10 +6,11 @@ import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessage; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResult; +import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingResponseModel; +import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestModel; import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel; import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFileContentsRequest; import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFilePathRequest; -import io.github.amithkoujalgi.ollama4j.core.models.request.ModelEmbeddingsRequest; import io.github.amithkoujalgi.ollama4j.core.models.request.ModelRequest; import io.github.amithkoujalgi.ollama4j.core.models.request.OllamaChatEndpointCaller; import io.github.amithkoujalgi.ollama4j.core.models.request.OllamaGenerateEndpointCaller; @@ -313,8 +314,18 @@ public class OllamaAPI { */ public List generateEmbeddings(String model, String prompt) throws IOException, InterruptedException, OllamaBaseException { + return generateEmbeddings(new OllamaEmbeddingsRequestModel(model, prompt)); + } + + /** + * Generate embeddings using a {@link OllamaEmbeddingsRequestModel}. + * + * @param modelRequest request for '/api/embeddings' endpoint + * @return embeddings + */ + public List generateEmbeddings(OllamaEmbeddingsRequestModel modelRequest) throws IOException, InterruptedException, OllamaBaseException{ URI uri = URI.create(this.host + "/api/embeddings"); - String jsonData = new ModelEmbeddingsRequest(model, prompt).toString(); + String jsonData = modelRequest.toString(); HttpClient httpClient = HttpClient.newHttpClient(); HttpRequest.Builder requestBuilder = getRequestBuilderDefault(uri) @@ -325,8 +336,8 @@ public class OllamaAPI { int statusCode = response.statusCode(); String responseBody = response.body(); if (statusCode == 200) { - EmbeddingResponse embeddingResponse = - Utils.getObjectMapper().readValue(responseBody, EmbeddingResponse.class); + OllamaEmbeddingResponseModel embeddingResponse = + Utils.getObjectMapper().readValue(responseBody, OllamaEmbeddingResponseModel.class); return embeddingResponse.getEmbedding(); } else { throw new OllamaBaseException(statusCode + " - " + responseBody); diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/EmbeddingResponse.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/embeddings/OllamaEmbeddingResponseModel.java similarity index 65% rename from src/main/java/io/github/amithkoujalgi/ollama4j/core/models/EmbeddingResponse.java rename to src/main/java/io/github/amithkoujalgi/ollama4j/core/models/embeddings/OllamaEmbeddingResponseModel.java index e3040a2..85dba31 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/EmbeddingResponse.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/embeddings/OllamaEmbeddingResponseModel.java @@ -1,4 +1,4 @@ -package io.github.amithkoujalgi.ollama4j.core.models; +package io.github.amithkoujalgi.ollama4j.core.models.embeddings; import com.fasterxml.jackson.annotation.JsonProperty; @@ -7,7 +7,7 @@ import lombok.Data; @SuppressWarnings("unused") @Data -public class EmbeddingResponse { +public class OllamaEmbeddingResponseModel { @JsonProperty("embedding") private List embedding; } diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/embeddings/OllamaEmbeddingsRequestBuilder.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/embeddings/OllamaEmbeddingsRequestBuilder.java new file mode 100644 index 0000000..ef7a84e --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/embeddings/OllamaEmbeddingsRequestBuilder.java @@ -0,0 +1,31 @@ +package io.github.amithkoujalgi.ollama4j.core.models.embeddings; + +import io.github.amithkoujalgi.ollama4j.core.utils.Options; + +public class OllamaEmbeddingsRequestBuilder { + + private OllamaEmbeddingsRequestBuilder(String model, String prompt){ + request = new OllamaEmbeddingsRequestModel(model, prompt); + } + + private OllamaEmbeddingsRequestModel request; + + public static OllamaEmbeddingsRequestBuilder getInstance(String model, String prompt){ + return new OllamaEmbeddingsRequestBuilder(model, prompt); + } + + public OllamaEmbeddingsRequestModel build(){ + return request; + } + + public OllamaEmbeddingsRequestBuilder withOptions(Options options){ + this.request.setOptions(options.getOptionsMap()); + return this; + } + + public OllamaEmbeddingsRequestBuilder withKeepAlive(String keepAlive){ + this.request.setKeepAlive(keepAlive); + return this; + } + +} diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/embeddings/OllamaEmbeddingsRequestModel.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/embeddings/OllamaEmbeddingsRequestModel.java new file mode 100644 index 0000000..a369124 --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/embeddings/OllamaEmbeddingsRequestModel.java @@ -0,0 +1,33 @@ +package io.github.amithkoujalgi.ollama4j.core.models.embeddings; + +import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper; +import java.util.Map; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.core.JsonProcessingException; +import lombok.Data; +import lombok.NoArgsConstructor; +import lombok.NonNull; +import lombok.RequiredArgsConstructor; + +@Data +@RequiredArgsConstructor +@NoArgsConstructor +public class OllamaEmbeddingsRequestModel { + @NonNull + private String model; + @NonNull + private String prompt; + + protected Map options; + @JsonProperty(value = "keep_alive") + private String keepAlive; + + @Override + public String toString() { + try { + return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } +} diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/ModelEmbeddingsRequest.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/ModelEmbeddingsRequest.java deleted file mode 100644 index 1455a94..0000000 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/ModelEmbeddingsRequest.java +++ /dev/null @@ -1,23 +0,0 @@ -package io.github.amithkoujalgi.ollama4j.core.models.request; - -import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper; - -import com.fasterxml.jackson.core.JsonProcessingException; -import lombok.AllArgsConstructor; -import lombok.Data; - -@Data -@AllArgsConstructor -public class ModelEmbeddingsRequest { - private String model; - private String prompt; - - @Override - public String toString() { - try { - return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this); - } catch (JsonProcessingException e) { - throw new RuntimeException(e); - } - } -} diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java b/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java index dc91287..d822077 100644 --- a/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java +++ b/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java @@ -10,6 +10,8 @@ import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessageRole; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResult; +import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestModel; +import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestBuilder; import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder; import java.io.File; import java.io.IOException; @@ -61,7 +63,7 @@ class TestRealAPIs { } catch (HttpConnectTimeoutException e) { fail(e.getMessage()); } catch (Exception e) { - throw new RuntimeException(e); + fail(e); } } @@ -73,7 +75,7 @@ class TestRealAPIs { assertNotNull(ollamaAPI.listModels()); ollamaAPI.listModels().forEach(System.out::println); } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { - throw new RuntimeException(e); + fail(e); } } @@ -88,7 +90,7 @@ class TestRealAPIs { .anyMatch(model -> model.getModel().equalsIgnoreCase(config.getModel())); assertTrue(found); } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { - throw new RuntimeException(e); + fail(e); } } @@ -101,7 +103,7 @@ class TestRealAPIs { assertNotNull(modelDetails); System.out.println(modelDetails); } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { - throw new RuntimeException(e); + fail(e); } } @@ -119,7 +121,7 @@ class TestRealAPIs { assertNotNull(result.getResponse()); assertFalse(result.getResponse().isEmpty()); } catch (IOException | OllamaBaseException | InterruptedException e) { - throw new RuntimeException(e); + fail(e); } } @@ -145,7 +147,7 @@ class TestRealAPIs { assertFalse(result.getResponse().isEmpty()); assertEquals(sb.toString().trim(), result.getResponse().trim()); } catch (IOException | OllamaBaseException | InterruptedException e) { - throw new RuntimeException(e); + fail(e); } } @@ -163,7 +165,7 @@ class TestRealAPIs { assertNotNull(result.getResponse()); assertFalse(result.getResponse().isEmpty()); } catch (IOException | OllamaBaseException | InterruptedException e) { - throw new RuntimeException(e); + fail(e); } } @@ -183,7 +185,7 @@ class TestRealAPIs { assertFalse(chatResult.getResponse().isBlank()); assertEquals(4,chatResult.getChatHistory().size()); } catch (IOException | OllamaBaseException | InterruptedException e) { - throw new RuntimeException(e); + fail(e); } } @@ -205,7 +207,7 @@ class TestRealAPIs { assertTrue(chatResult.getResponse().startsWith("NI")); assertEquals(3, chatResult.getChatHistory().size()); } catch (IOException | OllamaBaseException | InterruptedException e) { - throw new RuntimeException(e); + fail(e); } } @@ -230,7 +232,7 @@ class TestRealAPIs { assertNotNull(chatResult); assertEquals(sb.toString().trim(), chatResult.getResponse().trim()); } catch (IOException | OllamaBaseException | InterruptedException e) { - throw new RuntimeException(e); + fail(e); } } @@ -261,7 +263,7 @@ class TestRealAPIs { } catch (IOException | OllamaBaseException | InterruptedException e) { - throw new RuntimeException(e); + fail(e); } } @@ -278,7 +280,7 @@ class TestRealAPIs { OllamaChatResult chatResult = ollamaAPI.chat(requestModel); assertNotNull(chatResult); } catch (IOException | OllamaBaseException | InterruptedException e) { - throw new RuntimeException(e); + fail(e); } } @@ -298,7 +300,7 @@ class TestRealAPIs { assertNotNull(result.getResponse()); assertFalse(result.getResponse().isEmpty()); } catch (IOException | OllamaBaseException | InterruptedException e) { - throw new RuntimeException(e); + fail(e); } } @@ -322,7 +324,7 @@ class TestRealAPIs { assertFalse(result.getResponse().isEmpty()); assertEquals(sb.toString().trim(), result.getResponse().trim()); } catch (IOException | OllamaBaseException | InterruptedException e) { - throw new RuntimeException(e); + fail(e); } } @@ -342,7 +344,24 @@ class TestRealAPIs { assertNotNull(result.getResponse()); assertFalse(result.getResponse().isEmpty()); } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { - throw new RuntimeException(e); + fail(e); + } + } + + @Test + @Order(3) + public void testEmbedding() { + testEndpointReachability(); + try { + OllamaEmbeddingsRequestModel request = OllamaEmbeddingsRequestBuilder + .getInstance(config.getModel(), "What is the capital of France?").build(); + + List embeddings = ollamaAPI.generateEmbeddings(request); + + assertNotNull(embeddings); + assertFalse(embeddings.isEmpty()); + } catch (IOException | OllamaBaseException | InterruptedException e) { + fail(e); } } } diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/AbstractRequestSerializationTest.java b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/AbstractRequestSerializationTest.java new file mode 100644 index 0000000..c6b2ff5 --- /dev/null +++ b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/AbstractRequestSerializationTest.java @@ -0,0 +1,35 @@ +package io.github.amithkoujalgi.ollama4j.unittests.jackson; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.github.amithkoujalgi.ollama4j.core.utils.Utils; + +public abstract class AbstractRequestSerializationTest { + + protected ObjectMapper mapper = Utils.getObjectMapper(); + + protected String serializeRequest(T req) { + try { + return mapper.writeValueAsString(req); + } catch (JsonProcessingException e) { + fail("Could not serialize request!", e); + return null; + } + } + + protected T deserializeRequest(String jsonRequest, Class requestClass) { + try { + return mapper.readValue(jsonRequest, requestClass); + } catch (JsonProcessingException e) { + fail("Could not deserialize jsonRequest!", e); + return null; + } + } + + protected void assertEqualsAfterUnmarshalling(T unmarshalledRequest, + T req) { + assertEquals(req, unmarshalledRequest); + } +} diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestChatRequestSerialization.java b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestChatRequestSerialization.java index f5fa5c9..c5a7060 100644 --- a/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestChatRequestSerialization.java +++ b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestChatRequestSerialization.java @@ -1,7 +1,6 @@ package io.github.amithkoujalgi.ollama4j.unittests.jackson; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.fail; import java.io.File; import java.util.List; @@ -10,21 +9,15 @@ import org.json.JSONObject; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.ObjectMapper; - import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessageRole; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel; import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder; -import io.github.amithkoujalgi.ollama4j.core.utils.Utils; -public class TestChatRequestSerialization { +public class TestChatRequestSerialization extends AbstractRequestSerializationTest{ private OllamaChatRequestBuilder builder; - private ObjectMapper mapper = Utils.getObjectMapper(); - @BeforeEach public void init() { builder = OllamaChatRequestBuilder.getInstance("DummyModel"); @@ -32,10 +25,9 @@ public class TestChatRequestSerialization { @Test public void testRequestOnlyMandatoryFields() { - OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt", - List.of(new File("src/test/resources/dog-on-a-boat.jpg"))).build(); + OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt").build(); String jsonRequest = serializeRequest(req); - assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req); + assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest,OllamaChatRequestModel.class), req); } @Test @@ -44,7 +36,7 @@ public class TestChatRequestSerialization { .withMessage(OllamaChatMessageRole.USER, "Some prompt") .build(); String jsonRequest = serializeRequest(req); - assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req); + assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest,OllamaChatRequestModel.class), req); } @Test @@ -52,7 +44,7 @@ public class TestChatRequestSerialization { OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt", List.of(new File("src/test/resources/dog-on-a-boat.jpg"))).build(); String jsonRequest = serializeRequest(req); - assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req); + assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest,OllamaChatRequestModel.class), req); } @Test @@ -62,7 +54,7 @@ public class TestChatRequestSerialization { .withOptions(b.setMirostat(1).build()).build(); String jsonRequest = serializeRequest(req); - OllamaChatRequestModel deserializeRequest = deserializeRequest(jsonRequest); + OllamaChatRequestModel deserializeRequest = deserializeRequest(jsonRequest,OllamaChatRequestModel.class); assertEqualsAfterUnmarshalling(deserializeRequest, req); assertEquals(1, deserializeRequest.getOptions().get("mirostat")); } @@ -79,28 +71,4 @@ public class TestChatRequestSerialization { String requestFormatProperty = jsonObject.getString("format"); assertEquals("json", requestFormatProperty); } - - private String serializeRequest(OllamaChatRequestModel req) { - try { - return mapper.writeValueAsString(req); - } catch (JsonProcessingException e) { - fail("Could not serialize request!", e); - return null; - } - } - - private OllamaChatRequestModel deserializeRequest(String jsonRequest) { - try { - return mapper.readValue(jsonRequest, OllamaChatRequestModel.class); - } catch (JsonProcessingException e) { - fail("Could not deserialize jsonRequest!", e); - return null; - } - } - - private void assertEqualsAfterUnmarshalling(OllamaChatRequestModel unmarshalledRequest, - OllamaChatRequestModel req) { - assertEquals(req, unmarshalledRequest); - } - } diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestEmbeddingsRequestSerialization.java b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestEmbeddingsRequestSerialization.java new file mode 100644 index 0000000..ff1e308 --- /dev/null +++ b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestEmbeddingsRequestSerialization.java @@ -0,0 +1,37 @@ +package io.github.amithkoujalgi.ollama4j.unittests.jackson; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestModel; +import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestBuilder; +import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder; + +public class TestEmbeddingsRequestSerialization extends AbstractRequestSerializationTest{ + + private OllamaEmbeddingsRequestBuilder builder; + + @BeforeEach + public void init() { + builder = OllamaEmbeddingsRequestBuilder.getInstance("DummyModel","DummyPrompt"); + } + + @Test + public void testRequestOnlyMandatoryFields() { + OllamaEmbeddingsRequestModel req = builder.build(); + String jsonRequest = serializeRequest(req); + assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest,OllamaEmbeddingsRequestModel.class), req); + } + + @Test + public void testRequestWithOptions() { + OptionsBuilder b = new OptionsBuilder(); + OllamaEmbeddingsRequestModel req = builder + .withOptions(b.setMirostat(1).build()).build(); + + String jsonRequest = serializeRequest(req); + OllamaEmbeddingsRequestModel deserializeRequest = deserializeRequest(jsonRequest,OllamaEmbeddingsRequestModel.class); + assertEqualsAfterUnmarshalling(deserializeRequest, req); + assertEquals(1, deserializeRequest.getOptions().get("mirostat")); + } +} diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestGenerateRequestSerialization.java b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestGenerateRequestSerialization.java index 7cf0513..03610f7 100644 --- a/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestGenerateRequestSerialization.java +++ b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestGenerateRequestSerialization.java @@ -1,26 +1,20 @@ package io.github.amithkoujalgi.ollama4j.unittests.jackson; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.fail; import org.json.JSONObject; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.ObjectMapper; import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestBuilder; import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel; import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder; -import io.github.amithkoujalgi.ollama4j.core.utils.Utils; -public class TestGenerateRequestSerialization { +public class TestGenerateRequestSerialization extends AbstractRequestSerializationTest{ private OllamaGenerateRequestBuilder builder; - private ObjectMapper mapper = Utils.getObjectMapper(); - @BeforeEach public void init() { builder = OllamaGenerateRequestBuilder.getInstance("DummyModel"); @@ -31,7 +25,7 @@ public class TestGenerateRequestSerialization { OllamaGenerateRequestModel req = builder.withPrompt("Some prompt").build(); String jsonRequest = serializeRequest(req); - assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req); + assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest, OllamaGenerateRequestModel.class), req); } @Test @@ -41,7 +35,7 @@ public class TestGenerateRequestSerialization { builder.withPrompt("Some prompt").withOptions(b.setMirostat(1).build()).build(); String jsonRequest = serializeRequest(req); - OllamaGenerateRequestModel deserializeRequest = deserializeRequest(jsonRequest); + OllamaGenerateRequestModel deserializeRequest = deserializeRequest(jsonRequest, OllamaGenerateRequestModel.class); assertEqualsAfterUnmarshalling(deserializeRequest, req); assertEquals(1, deserializeRequest.getOptions().get("mirostat")); } @@ -59,27 +53,4 @@ public class TestGenerateRequestSerialization { assertEquals("json", requestFormatProperty); } - private String serializeRequest(OllamaGenerateRequestModel req) { - try { - return mapper.writeValueAsString(req); - } catch (JsonProcessingException e) { - fail("Could not serialize request!", e); - return null; - } - } - - private OllamaGenerateRequestModel deserializeRequest(String jsonRequest) { - try { - return mapper.readValue(jsonRequest, OllamaGenerateRequestModel.class); - } catch (JsonProcessingException e) { - fail("Could not deserialize jsonRequest!", e); - return null; - } - } - - private void assertEqualsAfterUnmarshalling(OllamaGenerateRequestModel unmarshalledRequest, - OllamaGenerateRequestModel req) { - assertEquals(req, unmarshalledRequest); - } - }