Adds Builder for EmbedRequests and deprecates old Embedding Models

This commit is contained in:
Markus Klenke 2024-12-02 22:48:33 +01:00
parent 4ef0821932
commit a09f1362e9
6 changed files with 56 additions and 12 deletions

View File

@ -45,7 +45,7 @@ public class OllamaChatRequestBuilder {
try { try {
return Files.readAllBytes(file.toPath()); return Files.readAllBytes(file.toPath());
} catch (IOException e) { } catch (IOException e) {
LOG.warn(String.format("File '%s' could not be accessed, will not add to message!", file.toPath()), e); LOG.warn("File '{}' could not be accessed, will not add to message!", file.toPath(), e);
return new byte[0]; return new byte[0];
} }
}).collect(Collectors.toList()); }).collect(Collectors.toList());
@ -63,9 +63,9 @@ public class OllamaChatRequestBuilder {
try { try {
binaryImages.add(Utils.loadImageBytesFromUrl(imageUrl)); binaryImages.add(Utils.loadImageBytesFromUrl(imageUrl));
} catch (URISyntaxException e) { } catch (URISyntaxException e) {
LOG.warn(String.format("URL '%s' could not be accessed, will not add to message!", imageUrl), e); LOG.warn("URL '{}' could not be accessed, will not add to message!", imageUrl, e);
} catch (IOException e) { } catch (IOException e) {
LOG.warn(String.format("Content of URL '%s' could not be read, will not add to message!", imageUrl), e); LOG.warn("Content of URL '{}' could not be read, will not add to message!", imageUrl, e);
} }
} }
} }

View File

@ -0,0 +1,40 @@
package io.github.ollama4j.models.embeddings;
import io.github.ollama4j.utils.Options;
import java.util.List;
/**
* Builderclass to easily create Requests for Embedding models using ollama.
*/
public class OllamaEmbedRequestBuilder {
private final OllamaEmbedRequestModel request;
private OllamaEmbedRequestBuilder(String model, List<String> input) {
this.request = new OllamaEmbedRequestModel(model,input);
}
public static OllamaEmbedRequestBuilder getInstance(String model, String... input){
return new OllamaEmbedRequestBuilder(model, List.of(input));
}
public OllamaEmbedRequestBuilder withOptions(Options options){
this.request.setOptions(options.getOptionsMap());
return this;
}
public OllamaEmbedRequestBuilder withKeepAlive(String keepAlive){
this.request.setKeepAlive(keepAlive);
return this;
}
public OllamaEmbedRequestBuilder withoutTruncate(){
this.request.setTruncate(false);
return this;
}
public OllamaEmbedRequestModel build() {
return this.request;
}
}

View File

@ -7,6 +7,7 @@ import lombok.Data;
@SuppressWarnings("unused") @SuppressWarnings("unused")
@Data @Data
@Deprecated(since="1.0.90")
public class OllamaEmbeddingResponseModel { public class OllamaEmbeddingResponseModel {
@JsonProperty("embedding") @JsonProperty("embedding")
private List<Double> embedding; private List<Double> embedding;

View File

@ -2,6 +2,7 @@ package io.github.ollama4j.models.embeddings;
import io.github.ollama4j.utils.Options; import io.github.ollama4j.utils.Options;
@Deprecated(since="1.0.90")
public class OllamaEmbeddingsRequestBuilder { public class OllamaEmbeddingsRequestBuilder {
private OllamaEmbeddingsRequestBuilder(String model, String prompt){ private OllamaEmbeddingsRequestBuilder(String model, String prompt){

View File

@ -12,6 +12,7 @@ import lombok.RequiredArgsConstructor;
@Data @Data
@RequiredArgsConstructor @RequiredArgsConstructor
@NoArgsConstructor @NoArgsConstructor
@Deprecated(since="1.0.90")
public class OllamaEmbeddingsRequestModel { public class OllamaEmbeddingsRequestModel {
@NonNull @NonNull
private String model; private String model;

View File

@ -1,36 +1,37 @@
package io.github.ollama4j.unittests.jackson; package io.github.ollama4j.unittests.jackson;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import io.github.ollama4j.models.embeddings.OllamaEmbedRequestBuilder;
import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestBuilder;
import io.github.ollama4j.utils.OptionsBuilder; import io.github.ollama4j.utils.OptionsBuilder;
public class TestEmbeddingsRequestSerialization extends AbstractSerializationTest<OllamaEmbeddingsRequestModel> { public class TestEmbedRequestSerialization extends AbstractSerializationTest<OllamaEmbedRequestModel> {
private OllamaEmbeddingsRequestBuilder builder; private OllamaEmbedRequestBuilder builder;
@BeforeEach @BeforeEach
public void init() { public void init() {
builder = OllamaEmbeddingsRequestBuilder.getInstance("DummyModel","DummyPrompt"); builder = OllamaEmbedRequestBuilder.getInstance("DummyModel","DummyPrompt");
} }
@Test @Test
public void testRequestOnlyMandatoryFields() { public void testRequestOnlyMandatoryFields() {
OllamaEmbeddingsRequestModel req = builder.build(); OllamaEmbedRequestModel req = builder.build();
String jsonRequest = serialize(req); String jsonRequest = serialize(req);
assertEqualsAfterUnmarshalling(deserialize(jsonRequest,OllamaEmbeddingsRequestModel.class), req); assertEqualsAfterUnmarshalling(deserialize(jsonRequest,OllamaEmbedRequestModel.class), req);
} }
@Test @Test
public void testRequestWithOptions() { public void testRequestWithOptions() {
OptionsBuilder b = new OptionsBuilder(); OptionsBuilder b = new OptionsBuilder();
OllamaEmbeddingsRequestModel req = builder OllamaEmbedRequestModel req = builder
.withOptions(b.setMirostat(1).build()).build(); .withOptions(b.setMirostat(1).build()).build();
String jsonRequest = serialize(req); String jsonRequest = serialize(req);
OllamaEmbeddingsRequestModel deserializeRequest = deserialize(jsonRequest,OllamaEmbeddingsRequestModel.class); OllamaEmbedRequestModel deserializeRequest = deserialize(jsonRequest,OllamaEmbedRequestModel.class);
assertEqualsAfterUnmarshalling(deserializeRequest, req); assertEqualsAfterUnmarshalling(deserializeRequest, req);
assertEquals(1, deserializeRequest.getOptions().get("mirostat")); assertEquals(1, deserializeRequest.getOptions().get("mirostat"));
} }