Adds options to EmbeddingsRequest

Additionally refactors the Embedding Models and Tests
This commit is contained in:
Markus Klenke 2024-02-25 20:53:45 +00:00
parent a10692e2f1
commit 63d4de4e24
12 changed files with 203 additions and 121 deletions

View File

@ -67,7 +67,7 @@ In your Maven project, add this dependency:
<dependency> <dependency>
<groupId>io.github.amithkoujalgi</groupId> <groupId>io.github.amithkoujalgi</groupId>
<artifactId>ollama4j</artifactId> <artifactId>ollama4j</artifactId>
<version>1.0.47</version> <version>1.0.57</version>
</dependency> </dependency>
``` ```
@ -125,15 +125,15 @@ Actions CI workflow.
- [x] Update request body creation with Java objects - [x] Update request body creation with Java objects
- [ ] Async APIs for images - [ ] Async APIs for images
- [ ] Add custom headers to requests - [ ] Add custom headers to requests
- [ ] Add additional params for `ask` APIs such as: - [x] Add additional params for `ask` APIs such as:
- [x] `options`: additional model parameters for the Modelfile such as `temperature` - - [x] `options`: additional model parameters for the Modelfile such as `temperature` -
Supported [params](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values). Supported [params](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
- [ ] `system`: system prompt to (overrides what is defined in the Modelfile) - [x] `system`: system prompt to (overrides what is defined in the Modelfile)
- [ ] `template`: the full prompt or prompt template (overrides what is defined in the Modelfile) - [x] `template`: the full prompt or prompt template (overrides what is defined in the Modelfile)
- [ ] `context`: the context parameter returned from a previous request, which can be used to keep a - [x] `context`: the context parameter returned from a previous request, which can be used to keep a
short short
conversational memory conversational memory
- [ ] `stream`: Add support for streaming responses from the model - [x] `stream`: Add support for streaming responses from the model
- [ ] Add test cases - [ ] Add test cases
- [ ] Handle exceptions better (maybe throw more appropriate exceptions) - [ ] Handle exceptions better (maybe throw more appropriate exceptions)

View File

@ -99,7 +99,7 @@
<configuration> <configuration>
<skipTests>${skipUnitTests}</skipTests> <skipTests>${skipUnitTests}</skipTests>
<includes> <includes>
<include>**/unittests/*.java</include> <include>**/unittests/**/*.java</include>
</includes> </includes>
</configuration> </configuration>
</plugin> </plugin>

View File

@ -6,10 +6,11 @@ import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessage;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResult; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResult;
import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingResponseModel;
import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel; import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFileContentsRequest; import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFileContentsRequest;
import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFilePathRequest; import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFilePathRequest;
import io.github.amithkoujalgi.ollama4j.core.models.request.ModelEmbeddingsRequest;
import io.github.amithkoujalgi.ollama4j.core.models.request.ModelRequest; import io.github.amithkoujalgi.ollama4j.core.models.request.ModelRequest;
import io.github.amithkoujalgi.ollama4j.core.models.request.OllamaChatEndpointCaller; import io.github.amithkoujalgi.ollama4j.core.models.request.OllamaChatEndpointCaller;
import io.github.amithkoujalgi.ollama4j.core.models.request.OllamaGenerateEndpointCaller; import io.github.amithkoujalgi.ollama4j.core.models.request.OllamaGenerateEndpointCaller;
@ -313,8 +314,18 @@ public class OllamaAPI {
*/ */
public List<Double> generateEmbeddings(String model, String prompt) public List<Double> generateEmbeddings(String model, String prompt)
throws IOException, InterruptedException, OllamaBaseException { throws IOException, InterruptedException, OllamaBaseException {
return generateEmbeddings(new OllamaEmbeddingsRequestModel(model, prompt));
}
/**
* Generate embeddings using a {@link OllamaEmbeddingsRequestModel}.
*
* @param modelRequest request for '/api/embeddings' endpoint
* @return embeddings
*/
public List<Double> generateEmbeddings(OllamaEmbeddingsRequestModel modelRequest) throws IOException, InterruptedException, OllamaBaseException{
URI uri = URI.create(this.host + "/api/embeddings"); URI uri = URI.create(this.host + "/api/embeddings");
String jsonData = new ModelEmbeddingsRequest(model, prompt).toString(); String jsonData = modelRequest.toString();
HttpClient httpClient = HttpClient.newHttpClient(); HttpClient httpClient = HttpClient.newHttpClient();
HttpRequest.Builder requestBuilder = HttpRequest.Builder requestBuilder =
getRequestBuilderDefault(uri) getRequestBuilderDefault(uri)
@ -325,8 +336,8 @@ public class OllamaAPI {
int statusCode = response.statusCode(); int statusCode = response.statusCode();
String responseBody = response.body(); String responseBody = response.body();
if (statusCode == 200) { if (statusCode == 200) {
EmbeddingResponse embeddingResponse = OllamaEmbeddingResponseModel embeddingResponse =
Utils.getObjectMapper().readValue(responseBody, EmbeddingResponse.class); Utils.getObjectMapper().readValue(responseBody, OllamaEmbeddingResponseModel.class);
return embeddingResponse.getEmbedding(); return embeddingResponse.getEmbedding();
} else { } else {
throw new OllamaBaseException(statusCode + " - " + responseBody); throw new OllamaBaseException(statusCode + " - " + responseBody);

View File

@ -1,4 +1,4 @@
package io.github.amithkoujalgi.ollama4j.core.models; package io.github.amithkoujalgi.ollama4j.core.models.embeddings;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
@ -7,7 +7,7 @@ import lombok.Data;
@SuppressWarnings("unused") @SuppressWarnings("unused")
@Data @Data
public class EmbeddingResponse { public class OllamaEmbeddingResponseModel {
@JsonProperty("embedding") @JsonProperty("embedding")
private List<Double> embedding; private List<Double> embedding;
} }

View File

@ -0,0 +1,31 @@
package io.github.amithkoujalgi.ollama4j.core.models.embeddings;
import io.github.amithkoujalgi.ollama4j.core.utils.Options;
public class OllamaEmbeddingsRequestBuilder {
private OllamaEmbeddingsRequestBuilder(String model, String prompt){
request = new OllamaEmbeddingsRequestModel(model, prompt);
}
private OllamaEmbeddingsRequestModel request;
public static OllamaEmbeddingsRequestBuilder getInstance(String model, String prompt){
return new OllamaEmbeddingsRequestBuilder(model, prompt);
}
public OllamaEmbeddingsRequestModel build(){
return request;
}
public OllamaEmbeddingsRequestBuilder withOptions(Options options){
this.request.setOptions(options.getOptionsMap());
return this;
}
public OllamaEmbeddingsRequestBuilder withKeepAlive(String keepAlive){
this.request.setKeepAlive(keepAlive);
return this;
}
}

View File

@ -0,0 +1,33 @@
package io.github.amithkoujalgi.ollama4j.core.models.embeddings;
import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper;
import java.util.Map;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
@Data
@RequiredArgsConstructor
@NoArgsConstructor
public class OllamaEmbeddingsRequestModel {
@NonNull
private String model;
@NonNull
private String prompt;
protected Map<String, Object> options;
@JsonProperty(value = "keep_alive")
private String keepAlive;
@Override
public String toString() {
try {
return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
}

View File

@ -1,23 +0,0 @@
package io.github.amithkoujalgi.ollama4j.core.models.request;
import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper;
import com.fasterxml.jackson.core.JsonProcessingException;
import lombok.AllArgsConstructor;
import lombok.Data;
@Data
@AllArgsConstructor
public class ModelEmbeddingsRequest {
private String model;
private String prompt;
@Override
public String toString() {
try {
return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
}

View File

@ -10,6 +10,8 @@ import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessageRole;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResult; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResult;
import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestBuilder;
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder; import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
@ -61,7 +63,7 @@ class TestRealAPIs {
} catch (HttpConnectTimeoutException e) { } catch (HttpConnectTimeoutException e) {
fail(e.getMessage()); fail(e.getMessage());
} catch (Exception e) { } catch (Exception e) {
throw new RuntimeException(e); fail(e);
} }
} }
@ -73,7 +75,7 @@ class TestRealAPIs {
assertNotNull(ollamaAPI.listModels()); assertNotNull(ollamaAPI.listModels());
ollamaAPI.listModels().forEach(System.out::println); ollamaAPI.listModels().forEach(System.out::println);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e); fail(e);
} }
} }
@ -88,7 +90,7 @@ class TestRealAPIs {
.anyMatch(model -> model.getModel().equalsIgnoreCase(config.getModel())); .anyMatch(model -> model.getModel().equalsIgnoreCase(config.getModel()));
assertTrue(found); assertTrue(found);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e); fail(e);
} }
} }
@ -101,7 +103,7 @@ class TestRealAPIs {
assertNotNull(modelDetails); assertNotNull(modelDetails);
System.out.println(modelDetails); System.out.println(modelDetails);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e); fail(e);
} }
} }
@ -119,7 +121,7 @@ class TestRealAPIs {
assertNotNull(result.getResponse()); assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty()); assertFalse(result.getResponse().isEmpty());
} catch (IOException | OllamaBaseException | InterruptedException e) { } catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e); fail(e);
} }
} }
@ -145,7 +147,7 @@ class TestRealAPIs {
assertFalse(result.getResponse().isEmpty()); assertFalse(result.getResponse().isEmpty());
assertEquals(sb.toString().trim(), result.getResponse().trim()); assertEquals(sb.toString().trim(), result.getResponse().trim());
} catch (IOException | OllamaBaseException | InterruptedException e) { } catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e); fail(e);
} }
} }
@ -163,7 +165,7 @@ class TestRealAPIs {
assertNotNull(result.getResponse()); assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty()); assertFalse(result.getResponse().isEmpty());
} catch (IOException | OllamaBaseException | InterruptedException e) { } catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e); fail(e);
} }
} }
@ -183,7 +185,7 @@ class TestRealAPIs {
assertFalse(chatResult.getResponse().isBlank()); assertFalse(chatResult.getResponse().isBlank());
assertEquals(4,chatResult.getChatHistory().size()); assertEquals(4,chatResult.getChatHistory().size());
} catch (IOException | OllamaBaseException | InterruptedException e) { } catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e); fail(e);
} }
} }
@ -205,7 +207,7 @@ class TestRealAPIs {
assertTrue(chatResult.getResponse().startsWith("NI")); assertTrue(chatResult.getResponse().startsWith("NI"));
assertEquals(3, chatResult.getChatHistory().size()); assertEquals(3, chatResult.getChatHistory().size());
} catch (IOException | OllamaBaseException | InterruptedException e) { } catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e); fail(e);
} }
} }
@ -230,7 +232,7 @@ class TestRealAPIs {
assertNotNull(chatResult); assertNotNull(chatResult);
assertEquals(sb.toString().trim(), chatResult.getResponse().trim()); assertEquals(sb.toString().trim(), chatResult.getResponse().trim());
} catch (IOException | OllamaBaseException | InterruptedException e) { } catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e); fail(e);
} }
} }
@ -261,7 +263,7 @@ class TestRealAPIs {
} catch (IOException | OllamaBaseException | InterruptedException e) { } catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e); fail(e);
} }
} }
@ -278,7 +280,7 @@ class TestRealAPIs {
OllamaChatResult chatResult = ollamaAPI.chat(requestModel); OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
assertNotNull(chatResult); assertNotNull(chatResult);
} catch (IOException | OllamaBaseException | InterruptedException e) { } catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e); fail(e);
} }
} }
@ -298,7 +300,7 @@ class TestRealAPIs {
assertNotNull(result.getResponse()); assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty()); assertFalse(result.getResponse().isEmpty());
} catch (IOException | OllamaBaseException | InterruptedException e) { } catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e); fail(e);
} }
} }
@ -322,7 +324,7 @@ class TestRealAPIs {
assertFalse(result.getResponse().isEmpty()); assertFalse(result.getResponse().isEmpty());
assertEquals(sb.toString().trim(), result.getResponse().trim()); assertEquals(sb.toString().trim(), result.getResponse().trim());
} catch (IOException | OllamaBaseException | InterruptedException e) { } catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e); fail(e);
} }
} }
@ -342,7 +344,24 @@ class TestRealAPIs {
assertNotNull(result.getResponse()); assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty()); assertFalse(result.getResponse().isEmpty());
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e); fail(e);
}
}
@Test
@Order(3)
public void testEmbedding() {
testEndpointReachability();
try {
OllamaEmbeddingsRequestModel request = OllamaEmbeddingsRequestBuilder
.getInstance(config.getModel(), "What is the capital of France?").build();
List<Double> embeddings = ollamaAPI.generateEmbeddings(request);
assertNotNull(embeddings);
assertFalse(embeddings.isEmpty());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
} }
} }
} }

View File

@ -0,0 +1,35 @@
package io.github.amithkoujalgi.ollama4j.unittests.jackson;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
public abstract class AbstractRequestSerializationTest<T> {
protected ObjectMapper mapper = Utils.getObjectMapper();
protected String serializeRequest(T req) {
try {
return mapper.writeValueAsString(req);
} catch (JsonProcessingException e) {
fail("Could not serialize request!", e);
return null;
}
}
protected T deserializeRequest(String jsonRequest, Class<T> requestClass) {
try {
return mapper.readValue(jsonRequest, requestClass);
} catch (JsonProcessingException e) {
fail("Could not deserialize jsonRequest!", e);
return null;
}
}
protected void assertEqualsAfterUnmarshalling(T unmarshalledRequest,
T req) {
assertEquals(req, unmarshalledRequest);
}
}

View File

@ -1,7 +1,6 @@
package io.github.amithkoujalgi.ollama4j.unittests.jackson; package io.github.amithkoujalgi.ollama4j.unittests.jackson;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
import java.io.File; import java.io.File;
import java.util.List; import java.util.List;
@ -10,21 +9,15 @@ import org.json.JSONObject;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessageRole; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessageRole;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel;
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder; import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
public class TestChatRequestSerialization { public class TestChatRequestSerialization extends AbstractRequestSerializationTest<OllamaChatRequestModel>{
private OllamaChatRequestBuilder builder; private OllamaChatRequestBuilder builder;
private ObjectMapper mapper = Utils.getObjectMapper();
@BeforeEach @BeforeEach
public void init() { public void init() {
builder = OllamaChatRequestBuilder.getInstance("DummyModel"); builder = OllamaChatRequestBuilder.getInstance("DummyModel");
@ -32,10 +25,9 @@ public class TestChatRequestSerialization {
@Test @Test
public void testRequestOnlyMandatoryFields() { public void testRequestOnlyMandatoryFields() {
OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt", OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt").build();
List.of(new File("src/test/resources/dog-on-a-boat.jpg"))).build();
String jsonRequest = serializeRequest(req); String jsonRequest = serializeRequest(req);
assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req); assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest,OllamaChatRequestModel.class), req);
} }
@Test @Test
@ -44,7 +36,7 @@ public class TestChatRequestSerialization {
.withMessage(OllamaChatMessageRole.USER, "Some prompt") .withMessage(OllamaChatMessageRole.USER, "Some prompt")
.build(); .build();
String jsonRequest = serializeRequest(req); String jsonRequest = serializeRequest(req);
assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req); assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest,OllamaChatRequestModel.class), req);
} }
@Test @Test
@ -52,7 +44,7 @@ public class TestChatRequestSerialization {
OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt", OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt",
List.of(new File("src/test/resources/dog-on-a-boat.jpg"))).build(); List.of(new File("src/test/resources/dog-on-a-boat.jpg"))).build();
String jsonRequest = serializeRequest(req); String jsonRequest = serializeRequest(req);
assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req); assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest,OllamaChatRequestModel.class), req);
} }
@Test @Test
@ -62,7 +54,7 @@ public class TestChatRequestSerialization {
.withOptions(b.setMirostat(1).build()).build(); .withOptions(b.setMirostat(1).build()).build();
String jsonRequest = serializeRequest(req); String jsonRequest = serializeRequest(req);
OllamaChatRequestModel deserializeRequest = deserializeRequest(jsonRequest); OllamaChatRequestModel deserializeRequest = deserializeRequest(jsonRequest,OllamaChatRequestModel.class);
assertEqualsAfterUnmarshalling(deserializeRequest, req); assertEqualsAfterUnmarshalling(deserializeRequest, req);
assertEquals(1, deserializeRequest.getOptions().get("mirostat")); assertEquals(1, deserializeRequest.getOptions().get("mirostat"));
} }
@ -79,28 +71,4 @@ public class TestChatRequestSerialization {
String requestFormatProperty = jsonObject.getString("format"); String requestFormatProperty = jsonObject.getString("format");
assertEquals("json", requestFormatProperty); assertEquals("json", requestFormatProperty);
} }
private String serializeRequest(OllamaChatRequestModel req) {
try {
return mapper.writeValueAsString(req);
} catch (JsonProcessingException e) {
fail("Could not serialize request!", e);
return null;
}
}
private OllamaChatRequestModel deserializeRequest(String jsonRequest) {
try {
return mapper.readValue(jsonRequest, OllamaChatRequestModel.class);
} catch (JsonProcessingException e) {
fail("Could not deserialize jsonRequest!", e);
return null;
}
}
private void assertEqualsAfterUnmarshalling(OllamaChatRequestModel unmarshalledRequest,
OllamaChatRequestModel req) {
assertEquals(req, unmarshalledRequest);
}
} }

View File

@ -0,0 +1,37 @@
package io.github.amithkoujalgi.ollama4j.unittests.jackson;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestBuilder;
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
public class TestEmbeddingsRequestSerialization extends AbstractRequestSerializationTest<OllamaEmbeddingsRequestModel>{
private OllamaEmbeddingsRequestBuilder builder;
@BeforeEach
public void init() {
builder = OllamaEmbeddingsRequestBuilder.getInstance("DummyModel","DummyPrompt");
}
@Test
public void testRequestOnlyMandatoryFields() {
OllamaEmbeddingsRequestModel req = builder.build();
String jsonRequest = serializeRequest(req);
assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest,OllamaEmbeddingsRequestModel.class), req);
}
@Test
public void testRequestWithOptions() {
OptionsBuilder b = new OptionsBuilder();
OllamaEmbeddingsRequestModel req = builder
.withOptions(b.setMirostat(1).build()).build();
String jsonRequest = serializeRequest(req);
OllamaEmbeddingsRequestModel deserializeRequest = deserializeRequest(jsonRequest,OllamaEmbeddingsRequestModel.class);
assertEqualsAfterUnmarshalling(deserializeRequest, req);
assertEquals(1, deserializeRequest.getOptions().get("mirostat"));
}
}

View File

@ -1,26 +1,20 @@
package io.github.amithkoujalgi.ollama4j.unittests.jackson; package io.github.amithkoujalgi.ollama4j.unittests.jackson;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
import org.json.JSONObject; import org.json.JSONObject;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestBuilder; import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestBuilder;
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel; import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel;
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder; import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
public class TestGenerateRequestSerialization { public class TestGenerateRequestSerialization extends AbstractRequestSerializationTest<OllamaGenerateRequestModel>{
private OllamaGenerateRequestBuilder builder; private OllamaGenerateRequestBuilder builder;
private ObjectMapper mapper = Utils.getObjectMapper();
@BeforeEach @BeforeEach
public void init() { public void init() {
builder = OllamaGenerateRequestBuilder.getInstance("DummyModel"); builder = OllamaGenerateRequestBuilder.getInstance("DummyModel");
@ -31,7 +25,7 @@ public class TestGenerateRequestSerialization {
OllamaGenerateRequestModel req = builder.withPrompt("Some prompt").build(); OllamaGenerateRequestModel req = builder.withPrompt("Some prompt").build();
String jsonRequest = serializeRequest(req); String jsonRequest = serializeRequest(req);
assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req); assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest, OllamaGenerateRequestModel.class), req);
} }
@Test @Test
@ -41,7 +35,7 @@ public class TestGenerateRequestSerialization {
builder.withPrompt("Some prompt").withOptions(b.setMirostat(1).build()).build(); builder.withPrompt("Some prompt").withOptions(b.setMirostat(1).build()).build();
String jsonRequest = serializeRequest(req); String jsonRequest = serializeRequest(req);
OllamaGenerateRequestModel deserializeRequest = deserializeRequest(jsonRequest); OllamaGenerateRequestModel deserializeRequest = deserializeRequest(jsonRequest, OllamaGenerateRequestModel.class);
assertEqualsAfterUnmarshalling(deserializeRequest, req); assertEqualsAfterUnmarshalling(deserializeRequest, req);
assertEquals(1, deserializeRequest.getOptions().get("mirostat")); assertEquals(1, deserializeRequest.getOptions().get("mirostat"));
} }
@ -59,27 +53,4 @@ public class TestGenerateRequestSerialization {
assertEquals("json", requestFormatProperty); assertEquals("json", requestFormatProperty);
} }
private String serializeRequest(OllamaGenerateRequestModel req) {
try {
return mapper.writeValueAsString(req);
} catch (JsonProcessingException e) {
fail("Could not serialize request!", e);
return null;
}
}
private OllamaGenerateRequestModel deserializeRequest(String jsonRequest) {
try {
return mapper.readValue(jsonRequest, OllamaGenerateRequestModel.class);
} catch (JsonProcessingException e) {
fail("Could not deserialize jsonRequest!", e);
return null;
}
}
private void assertEqualsAfterUnmarshalling(OllamaGenerateRequestModel unmarshalledRequest,
OllamaGenerateRequestModel req) {
assertEquals(req, unmarshalledRequest);
}
} }