forked from Mirror/ollama4j
Merge pull request #30 from AgentSchmecker/feature/options_for_embedding_request
Adds options to EmbeddingsRequest
This commit is contained in:
commit
9d887b60a8
12
README.md
12
README.md
@ -67,7 +67,7 @@ In your Maven project, add this dependency:
|
||||
<dependency>
|
||||
<groupId>io.github.amithkoujalgi</groupId>
|
||||
<artifactId>ollama4j</artifactId>
|
||||
<version>1.0.47</version>
|
||||
<version>1.0.57</version>
|
||||
</dependency>
|
||||
```
|
||||
|
||||
@ -125,15 +125,15 @@ Actions CI workflow.
|
||||
- [x] Update request body creation with Java objects
|
||||
- [ ] Async APIs for images
|
||||
- [ ] Add custom headers to requests
|
||||
- [ ] Add additional params for `ask` APIs such as:
|
||||
- [x] Add additional params for `ask` APIs such as:
|
||||
- [x] `options`: additional model parameters for the Modelfile such as `temperature` -
|
||||
Supported [params](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
|
||||
- [ ] `system`: system prompt to (overrides what is defined in the Modelfile)
|
||||
- [ ] `template`: the full prompt or prompt template (overrides what is defined in the Modelfile)
|
||||
- [ ] `context`: the context parameter returned from a previous request, which can be used to keep a
|
||||
- [x] `system`: system prompt to (overrides what is defined in the Modelfile)
|
||||
- [x] `template`: the full prompt or prompt template (overrides what is defined in the Modelfile)
|
||||
- [x] `context`: the context parameter returned from a previous request, which can be used to keep a
|
||||
short
|
||||
conversational memory
|
||||
- [ ] `stream`: Add support for streaming responses from the model
|
||||
- [x] `stream`: Add support for streaming responses from the model
|
||||
- [ ] Add test cases
|
||||
- [ ] Handle exceptions better (maybe throw more appropriate exceptions)
|
||||
|
||||
|
2
pom.xml
2
pom.xml
@ -99,7 +99,7 @@
|
||||
<configuration>
|
||||
<skipTests>${skipUnitTests}</skipTests>
|
||||
<includes>
|
||||
<include>**/unittests/*.java</include>
|
||||
<include>**/unittests/**/*.java</include>
|
||||
</includes>
|
||||
</configuration>
|
||||
</plugin>
|
||||
|
@ -6,10 +6,11 @@ import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessage;
|
||||
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder;
|
||||
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel;
|
||||
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResult;
|
||||
import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingResponseModel;
|
||||
import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestModel;
|
||||
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel;
|
||||
import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFileContentsRequest;
|
||||
import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFilePathRequest;
|
||||
import io.github.amithkoujalgi.ollama4j.core.models.request.ModelEmbeddingsRequest;
|
||||
import io.github.amithkoujalgi.ollama4j.core.models.request.ModelRequest;
|
||||
import io.github.amithkoujalgi.ollama4j.core.models.request.OllamaChatEndpointCaller;
|
||||
import io.github.amithkoujalgi.ollama4j.core.models.request.OllamaGenerateEndpointCaller;
|
||||
@ -313,8 +314,18 @@ public class OllamaAPI {
|
||||
*/
|
||||
public List<Double> generateEmbeddings(String model, String prompt)
|
||||
throws IOException, InterruptedException, OllamaBaseException {
|
||||
return generateEmbeddings(new OllamaEmbeddingsRequestModel(model, prompt));
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate embeddings using a {@link OllamaEmbeddingsRequestModel}.
|
||||
*
|
||||
* @param modelRequest request for '/api/embeddings' endpoint
|
||||
* @return embeddings
|
||||
*/
|
||||
public List<Double> generateEmbeddings(OllamaEmbeddingsRequestModel modelRequest) throws IOException, InterruptedException, OllamaBaseException{
|
||||
URI uri = URI.create(this.host + "/api/embeddings");
|
||||
String jsonData = new ModelEmbeddingsRequest(model, prompt).toString();
|
||||
String jsonData = modelRequest.toString();
|
||||
HttpClient httpClient = HttpClient.newHttpClient();
|
||||
HttpRequest.Builder requestBuilder =
|
||||
getRequestBuilderDefault(uri)
|
||||
@ -325,8 +336,8 @@ public class OllamaAPI {
|
||||
int statusCode = response.statusCode();
|
||||
String responseBody = response.body();
|
||||
if (statusCode == 200) {
|
||||
EmbeddingResponse embeddingResponse =
|
||||
Utils.getObjectMapper().readValue(responseBody, EmbeddingResponse.class);
|
||||
OllamaEmbeddingResponseModel embeddingResponse =
|
||||
Utils.getObjectMapper().readValue(responseBody, OllamaEmbeddingResponseModel.class);
|
||||
return embeddingResponse.getEmbedding();
|
||||
} else {
|
||||
throw new OllamaBaseException(statusCode + " - " + responseBody);
|
||||
|
@ -1,4 +1,4 @@
|
||||
package io.github.amithkoujalgi.ollama4j.core.models;
|
||||
package io.github.amithkoujalgi.ollama4j.core.models.embeddings;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
|
||||
@ -7,7 +7,7 @@ import lombok.Data;
|
||||
|
||||
@SuppressWarnings("unused")
|
||||
@Data
|
||||
public class EmbeddingResponse {
|
||||
public class OllamaEmbeddingResponseModel {
|
||||
@JsonProperty("embedding")
|
||||
private List<Double> embedding;
|
||||
}
|
@ -0,0 +1,31 @@
|
||||
package io.github.amithkoujalgi.ollama4j.core.models.embeddings;
|
||||
|
||||
import io.github.amithkoujalgi.ollama4j.core.utils.Options;
|
||||
|
||||
public class OllamaEmbeddingsRequestBuilder {
|
||||
|
||||
private OllamaEmbeddingsRequestBuilder(String model, String prompt){
|
||||
request = new OllamaEmbeddingsRequestModel(model, prompt);
|
||||
}
|
||||
|
||||
private OllamaEmbeddingsRequestModel request;
|
||||
|
||||
public static OllamaEmbeddingsRequestBuilder getInstance(String model, String prompt){
|
||||
return new OllamaEmbeddingsRequestBuilder(model, prompt);
|
||||
}
|
||||
|
||||
public OllamaEmbeddingsRequestModel build(){
|
||||
return request;
|
||||
}
|
||||
|
||||
public OllamaEmbeddingsRequestBuilder withOptions(Options options){
|
||||
this.request.setOptions(options.getOptionsMap());
|
||||
return this;
|
||||
}
|
||||
|
||||
public OllamaEmbeddingsRequestBuilder withKeepAlive(String keepAlive){
|
||||
this.request.setKeepAlive(keepAlive);
|
||||
return this;
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,33 @@
|
||||
package io.github.amithkoujalgi.ollama4j.core.models.embeddings;
|
||||
|
||||
import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper;
|
||||
import java.util.Map;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
|
||||
@Data
|
||||
@RequiredArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class OllamaEmbeddingsRequestModel {
|
||||
@NonNull
|
||||
private String model;
|
||||
@NonNull
|
||||
private String prompt;
|
||||
|
||||
protected Map<String, Object> options;
|
||||
@JsonProperty(value = "keep_alive")
|
||||
private String keepAlive;
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
try {
|
||||
return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
@ -1,23 +0,0 @@
|
||||
package io.github.amithkoujalgi.ollama4j.core.models.request;
|
||||
|
||||
import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
public class ModelEmbeddingsRequest {
|
||||
private String model;
|
||||
private String prompt;
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
try {
|
||||
return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
@ -10,6 +10,8 @@ import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessageRole;
|
||||
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder;
|
||||
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel;
|
||||
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResult;
|
||||
import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestModel;
|
||||
import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestBuilder;
|
||||
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
@ -61,7 +63,7 @@ class TestRealAPIs {
|
||||
} catch (HttpConnectTimeoutException e) {
|
||||
fail(e.getMessage());
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
fail(e);
|
||||
}
|
||||
}
|
||||
|
||||
@ -73,7 +75,7 @@ class TestRealAPIs {
|
||||
assertNotNull(ollamaAPI.listModels());
|
||||
ollamaAPI.listModels().forEach(System.out::println);
|
||||
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
|
||||
throw new RuntimeException(e);
|
||||
fail(e);
|
||||
}
|
||||
}
|
||||
|
||||
@ -88,7 +90,7 @@ class TestRealAPIs {
|
||||
.anyMatch(model -> model.getModel().equalsIgnoreCase(config.getModel()));
|
||||
assertTrue(found);
|
||||
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
|
||||
throw new RuntimeException(e);
|
||||
fail(e);
|
||||
}
|
||||
}
|
||||
|
||||
@ -101,7 +103,7 @@ class TestRealAPIs {
|
||||
assertNotNull(modelDetails);
|
||||
System.out.println(modelDetails);
|
||||
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
|
||||
throw new RuntimeException(e);
|
||||
fail(e);
|
||||
}
|
||||
}
|
||||
|
||||
@ -119,7 +121,7 @@ class TestRealAPIs {
|
||||
assertNotNull(result.getResponse());
|
||||
assertFalse(result.getResponse().isEmpty());
|
||||
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
fail(e);
|
||||
}
|
||||
}
|
||||
|
||||
@ -145,7 +147,7 @@ class TestRealAPIs {
|
||||
assertFalse(result.getResponse().isEmpty());
|
||||
assertEquals(sb.toString().trim(), result.getResponse().trim());
|
||||
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
fail(e);
|
||||
}
|
||||
}
|
||||
|
||||
@ -163,7 +165,7 @@ class TestRealAPIs {
|
||||
assertNotNull(result.getResponse());
|
||||
assertFalse(result.getResponse().isEmpty());
|
||||
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
fail(e);
|
||||
}
|
||||
}
|
||||
|
||||
@ -183,7 +185,7 @@ class TestRealAPIs {
|
||||
assertFalse(chatResult.getResponse().isBlank());
|
||||
assertEquals(4,chatResult.getChatHistory().size());
|
||||
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
fail(e);
|
||||
}
|
||||
}
|
||||
|
||||
@ -205,7 +207,7 @@ class TestRealAPIs {
|
||||
assertTrue(chatResult.getResponse().startsWith("NI"));
|
||||
assertEquals(3, chatResult.getChatHistory().size());
|
||||
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
fail(e);
|
||||
}
|
||||
}
|
||||
|
||||
@ -230,7 +232,7 @@ class TestRealAPIs {
|
||||
assertNotNull(chatResult);
|
||||
assertEquals(sb.toString().trim(), chatResult.getResponse().trim());
|
||||
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
fail(e);
|
||||
}
|
||||
}
|
||||
|
||||
@ -261,7 +263,7 @@ class TestRealAPIs {
|
||||
|
||||
|
||||
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
fail(e);
|
||||
}
|
||||
}
|
||||
|
||||
@ -278,7 +280,7 @@ class TestRealAPIs {
|
||||
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
|
||||
assertNotNull(chatResult);
|
||||
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
fail(e);
|
||||
}
|
||||
}
|
||||
|
||||
@ -298,7 +300,7 @@ class TestRealAPIs {
|
||||
assertNotNull(result.getResponse());
|
||||
assertFalse(result.getResponse().isEmpty());
|
||||
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
fail(e);
|
||||
}
|
||||
}
|
||||
|
||||
@ -322,7 +324,7 @@ class TestRealAPIs {
|
||||
assertFalse(result.getResponse().isEmpty());
|
||||
assertEquals(sb.toString().trim(), result.getResponse().trim());
|
||||
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
fail(e);
|
||||
}
|
||||
}
|
||||
|
||||
@ -342,7 +344,24 @@ class TestRealAPIs {
|
||||
assertNotNull(result.getResponse());
|
||||
assertFalse(result.getResponse().isEmpty());
|
||||
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
|
||||
throw new RuntimeException(e);
|
||||
fail(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(3)
|
||||
public void testEmbedding() {
|
||||
testEndpointReachability();
|
||||
try {
|
||||
OllamaEmbeddingsRequestModel request = OllamaEmbeddingsRequestBuilder
|
||||
.getInstance(config.getModel(), "What is the capital of France?").build();
|
||||
|
||||
List<Double> embeddings = ollamaAPI.generateEmbeddings(request);
|
||||
|
||||
assertNotNull(embeddings);
|
||||
assertFalse(embeddings.isEmpty());
|
||||
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||
fail(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,35 @@
|
||||
package io.github.amithkoujalgi.ollama4j.unittests.jackson;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.fail;
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
|
||||
|
||||
public abstract class AbstractRequestSerializationTest<T> {
|
||||
|
||||
protected ObjectMapper mapper = Utils.getObjectMapper();
|
||||
|
||||
protected String serializeRequest(T req) {
|
||||
try {
|
||||
return mapper.writeValueAsString(req);
|
||||
} catch (JsonProcessingException e) {
|
||||
fail("Could not serialize request!", e);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
protected T deserializeRequest(String jsonRequest, Class<T> requestClass) {
|
||||
try {
|
||||
return mapper.readValue(jsonRequest, requestClass);
|
||||
} catch (JsonProcessingException e) {
|
||||
fail("Could not deserialize jsonRequest!", e);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
protected void assertEqualsAfterUnmarshalling(T unmarshalledRequest,
|
||||
T req) {
|
||||
assertEquals(req, unmarshalledRequest);
|
||||
}
|
||||
}
|
@ -1,7 +1,6 @@
|
||||
package io.github.amithkoujalgi.ollama4j.unittests.jackson;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.fail;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.List;
|
||||
@ -10,21 +9,15 @@ import org.json.JSONObject;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
|
||||
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessageRole;
|
||||
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder;
|
||||
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel;
|
||||
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
|
||||
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
|
||||
|
||||
public class TestChatRequestSerialization {
|
||||
public class TestChatRequestSerialization extends AbstractRequestSerializationTest<OllamaChatRequestModel>{
|
||||
|
||||
private OllamaChatRequestBuilder builder;
|
||||
|
||||
private ObjectMapper mapper = Utils.getObjectMapper();
|
||||
|
||||
@BeforeEach
|
||||
public void init() {
|
||||
builder = OllamaChatRequestBuilder.getInstance("DummyModel");
|
||||
@ -32,10 +25,9 @@ public class TestChatRequestSerialization {
|
||||
|
||||
@Test
|
||||
public void testRequestOnlyMandatoryFields() {
|
||||
OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt",
|
||||
List.of(new File("src/test/resources/dog-on-a-boat.jpg"))).build();
|
||||
OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt").build();
|
||||
String jsonRequest = serializeRequest(req);
|
||||
assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req);
|
||||
assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest,OllamaChatRequestModel.class), req);
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -44,7 +36,7 @@ public class TestChatRequestSerialization {
|
||||
.withMessage(OllamaChatMessageRole.USER, "Some prompt")
|
||||
.build();
|
||||
String jsonRequest = serializeRequest(req);
|
||||
assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req);
|
||||
assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest,OllamaChatRequestModel.class), req);
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -52,7 +44,7 @@ public class TestChatRequestSerialization {
|
||||
OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt",
|
||||
List.of(new File("src/test/resources/dog-on-a-boat.jpg"))).build();
|
||||
String jsonRequest = serializeRequest(req);
|
||||
assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req);
|
||||
assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest,OllamaChatRequestModel.class), req);
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -62,7 +54,7 @@ public class TestChatRequestSerialization {
|
||||
.withOptions(b.setMirostat(1).build()).build();
|
||||
|
||||
String jsonRequest = serializeRequest(req);
|
||||
OllamaChatRequestModel deserializeRequest = deserializeRequest(jsonRequest);
|
||||
OllamaChatRequestModel deserializeRequest = deserializeRequest(jsonRequest,OllamaChatRequestModel.class);
|
||||
assertEqualsAfterUnmarshalling(deserializeRequest, req);
|
||||
assertEquals(1, deserializeRequest.getOptions().get("mirostat"));
|
||||
}
|
||||
@ -79,28 +71,4 @@ public class TestChatRequestSerialization {
|
||||
String requestFormatProperty = jsonObject.getString("format");
|
||||
assertEquals("json", requestFormatProperty);
|
||||
}
|
||||
|
||||
private String serializeRequest(OllamaChatRequestModel req) {
|
||||
try {
|
||||
return mapper.writeValueAsString(req);
|
||||
} catch (JsonProcessingException e) {
|
||||
fail("Could not serialize request!", e);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private OllamaChatRequestModel deserializeRequest(String jsonRequest) {
|
||||
try {
|
||||
return mapper.readValue(jsonRequest, OllamaChatRequestModel.class);
|
||||
} catch (JsonProcessingException e) {
|
||||
fail("Could not deserialize jsonRequest!", e);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private void assertEqualsAfterUnmarshalling(OllamaChatRequestModel unmarshalledRequest,
|
||||
OllamaChatRequestModel req) {
|
||||
assertEquals(req, unmarshalledRequest);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -0,0 +1,37 @@
|
||||
package io.github.amithkoujalgi.ollama4j.unittests.jackson;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestModel;
|
||||
import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestBuilder;
|
||||
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
|
||||
|
||||
public class TestEmbeddingsRequestSerialization extends AbstractRequestSerializationTest<OllamaEmbeddingsRequestModel>{
|
||||
|
||||
private OllamaEmbeddingsRequestBuilder builder;
|
||||
|
||||
@BeforeEach
|
||||
public void init() {
|
||||
builder = OllamaEmbeddingsRequestBuilder.getInstance("DummyModel","DummyPrompt");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRequestOnlyMandatoryFields() {
|
||||
OllamaEmbeddingsRequestModel req = builder.build();
|
||||
String jsonRequest = serializeRequest(req);
|
||||
assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest,OllamaEmbeddingsRequestModel.class), req);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRequestWithOptions() {
|
||||
OptionsBuilder b = new OptionsBuilder();
|
||||
OllamaEmbeddingsRequestModel req = builder
|
||||
.withOptions(b.setMirostat(1).build()).build();
|
||||
|
||||
String jsonRequest = serializeRequest(req);
|
||||
OllamaEmbeddingsRequestModel deserializeRequest = deserializeRequest(jsonRequest,OllamaEmbeddingsRequestModel.class);
|
||||
assertEqualsAfterUnmarshalling(deserializeRequest, req);
|
||||
assertEquals(1, deserializeRequest.getOptions().get("mirostat"));
|
||||
}
|
||||
}
|
@ -1,26 +1,20 @@
|
||||
package io.github.amithkoujalgi.ollama4j.unittests.jackson;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.fail;
|
||||
|
||||
import org.json.JSONObject;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
|
||||
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestBuilder;
|
||||
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel;
|
||||
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
|
||||
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
|
||||
|
||||
public class TestGenerateRequestSerialization {
|
||||
public class TestGenerateRequestSerialization extends AbstractRequestSerializationTest<OllamaGenerateRequestModel>{
|
||||
|
||||
private OllamaGenerateRequestBuilder builder;
|
||||
|
||||
private ObjectMapper mapper = Utils.getObjectMapper();
|
||||
|
||||
@BeforeEach
|
||||
public void init() {
|
||||
builder = OllamaGenerateRequestBuilder.getInstance("DummyModel");
|
||||
@ -31,7 +25,7 @@ public class TestGenerateRequestSerialization {
|
||||
OllamaGenerateRequestModel req = builder.withPrompt("Some prompt").build();
|
||||
|
||||
String jsonRequest = serializeRequest(req);
|
||||
assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req);
|
||||
assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest, OllamaGenerateRequestModel.class), req);
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -41,7 +35,7 @@ public class TestGenerateRequestSerialization {
|
||||
builder.withPrompt("Some prompt").withOptions(b.setMirostat(1).build()).build();
|
||||
|
||||
String jsonRequest = serializeRequest(req);
|
||||
OllamaGenerateRequestModel deserializeRequest = deserializeRequest(jsonRequest);
|
||||
OllamaGenerateRequestModel deserializeRequest = deserializeRequest(jsonRequest, OllamaGenerateRequestModel.class);
|
||||
assertEqualsAfterUnmarshalling(deserializeRequest, req);
|
||||
assertEquals(1, deserializeRequest.getOptions().get("mirostat"));
|
||||
}
|
||||
@ -59,27 +53,4 @@ public class TestGenerateRequestSerialization {
|
||||
assertEquals("json", requestFormatProperty);
|
||||
}
|
||||
|
||||
private String serializeRequest(OllamaGenerateRequestModel req) {
|
||||
try {
|
||||
return mapper.writeValueAsString(req);
|
||||
} catch (JsonProcessingException e) {
|
||||
fail("Could not serialize request!", e);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private OllamaGenerateRequestModel deserializeRequest(String jsonRequest) {
|
||||
try {
|
||||
return mapper.readValue(jsonRequest, OllamaGenerateRequestModel.class);
|
||||
} catch (JsonProcessingException e) {
|
||||
fail("Could not deserialize jsonRequest!", e);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private void assertEqualsAfterUnmarshalling(OllamaGenerateRequestModel unmarshalledRequest,
|
||||
OllamaGenerateRequestModel req) {
|
||||
assertEquals(req, unmarshalledRequest);
|
||||
}
|
||||
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user