diff --git a/docs/docs/apis-generate/generate-embeddings.md b/docs/docs/apis-generate/generate-embeddings.md
index 586b215..e465581 100644
--- a/docs/docs/apis-generate/generate-embeddings.md
+++ b/docs/docs/apis-generate/generate-embeddings.md
@@ -8,12 +8,85 @@ Generate embeddings from a model.
 
 Parameters:
 
+- `model`: name of model to generate embeddings from
+- `input`: text(s) to generate embeddings for
+
+```java
+import io.github.ollama4j.OllamaAPI;
+import io.github.ollama4j.types.OllamaModelType;
+import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel;
+import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+public class Main {
+
+    public static void main(String[] args) {
+
+        String host = "http://localhost:11434/";
+
+        OllamaAPI ollamaAPI = new OllamaAPI(host);
+
+        OllamaEmbedResponseModel embeddings = ollamaAPI.embed("all-minilm", Arrays.asList("Why is the sky blue?", "Why is the grass green?"));
+
+        System.out.println(embeddings);
+    }
+}
+```
+
+Or, using an `OllamaEmbedRequestModel`:
+
+```java
+import io.github.ollama4j.OllamaAPI;
+import io.github.ollama4j.types.OllamaModelType;
+import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel;
+import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+public class Main {
+
+    public static void main(String[] args) {
+
+        String host = "http://localhost:11434/";
+
+        OllamaAPI ollamaAPI = new OllamaAPI(host);
+
+        OllamaEmbedResponseModel embeddings = ollamaAPI.embed(new OllamaEmbedRequestModel("all-minilm", Arrays.asList("Why is the sky blue?", "Why is the grass green?")));
+
+        System.out.println(embeddings);
+    }
+}
+```
+
+You will get a response similar to:
+
+```json
+{
+  "model": "all-minilm",
+  "embeddings": [[-0.034674067, 0.030984823, 0.0067988685]],
+  "total_duration": 14173700,
+  "load_duration": 1198800,
+  "prompt_eval_count": 2
+}
+```
+
+:::note
+
+This is a deprecated API. Use the `embed` APIs described above instead.
+
+:::
+
+Parameters:
+
 - `model`: name of model to generate embeddings from
 - `prompt`: text to generate embeddings for
 
 ```java
 import io.github.ollama4j.OllamaAPI;
 import io.github.ollama4j.types.OllamaModelType;
+
 import java.util.List;
 
 public class Main {
@@ -40,11 +113,6 @@ You will get a response similar to:
   0.009260174818336964,
   0.23178744316101074,
   -0.2916173040866852,
-  -0.8924556970596313,
-  0.8785552978515625,
-  -0.34576427936553955,
-  0.5742510557174683,
-  -0.04222835972905159,
-  -0.137906014919281
+  -0.8924556970596313
 ]
 ```
\ No newline at end of file
diff --git a/src/main/java/io/github/ollama4j/OllamaAPI.java b/src/main/java/io/github/ollama4j/OllamaAPI.java
index d3c68a2..f32c069 100644
--- a/src/main/java/io/github/ollama4j/OllamaAPI.java
+++ b/src/main/java/io/github/ollama4j/OllamaAPI.java
@@ -7,8 +7,10 @@ import io.github.ollama4j.models.chat.OllamaChatMessage;
 import io.github.ollama4j.models.chat.OllamaChatRequest;
 import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
 import io.github.ollama4j.models.chat.OllamaChatResult;
+import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel;
 import io.github.ollama4j.models.embeddings.OllamaEmbeddingResponseModel;
 import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
+import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel;
 import io.github.ollama4j.models.generate.OllamaGenerateRequest;
 import io.github.ollama4j.models.generate.OllamaStreamHandler;
 import io.github.ollama4j.models.ps.ModelsProcessResponse;
@@ -342,7 +344,9 @@ public class OllamaAPI {
      * @param model  name of model to generate embeddings from
      * @param prompt text to generate embeddings for
      * @return embeddings
+     * @deprecated Use {@link #embed(String, List)} instead.
      */
+    @Deprecated
     public List<Double> generateEmbeddings(String model, String prompt)
             throws IOException, InterruptedException, OllamaBaseException {
         return generateEmbeddings(new OllamaEmbeddingsRequestModel(model, prompt));
@@ -353,7 +357,9 @@ public class OllamaAPI {
      *
      * @param modelRequest request for '/api/embeddings' endpoint
      * @return embeddings
+     * @deprecated Use {@link #embed(OllamaEmbedRequestModel)} instead.
      */
+    @Deprecated
     public List<Double> generateEmbeddings(OllamaEmbeddingsRequestModel modelRequest)
             throws IOException, InterruptedException, OllamaBaseException {
         URI uri = URI.create(this.host + "/api/embeddings");
         String jsonData = modelRequest.toString();
@@ -375,6 +381,47 @@ public class OllamaAPI {
         }
     }
 
+    /**
+     * Generate embeddings for a given text from a model
+     *
+     * @param model  name of model to generate embeddings from
+     * @param inputs text(s) to generate embeddings for
+     * @return embeddings
+     */
+    public OllamaEmbedResponseModel embed(String model, List<String> inputs)
+            throws IOException, InterruptedException, OllamaBaseException {
+        return embed(new OllamaEmbedRequestModel(model, inputs));
+    }
+
+    /**
+     * Generate embeddings using a {@link OllamaEmbedRequestModel}.
+     *
+     * @param modelRequest request for '/api/embed' endpoint
+     * @return embeddings
+     */
+    public OllamaEmbedResponseModel embed(OllamaEmbedRequestModel modelRequest)
+            throws IOException, InterruptedException, OllamaBaseException {
+        URI uri = URI.create(this.host + "/api/embed");
+        String jsonData = Utils.getObjectMapper().writeValueAsString(modelRequest);
+        HttpClient httpClient = HttpClient.newHttpClient();
+
+        HttpRequest request = HttpRequest.newBuilder(uri)
+                .header("Accept", "application/json")
+                .POST(HttpRequest.BodyPublishers.ofString(jsonData))
+                .build();
+
+        HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
+        int statusCode = response.statusCode();
+        String responseBody = response.body();
+
+        if (statusCode == 200) {
+            OllamaEmbedResponseModel embeddingResponse =
+                    Utils.getObjectMapper().readValue(responseBody, OllamaEmbedResponseModel.class);
+            return embeddingResponse;
+        } else {
+            throw new OllamaBaseException(statusCode + " - " + responseBody);
+        }
+    }
 
     /**
      * Generate response for a question to a model running on Ollama server.
      * This is a sync/blocking
diff --git a/src/main/java/io/github/ollama4j/models/embeddings/OllamaEmbedRequestModel.java b/src/main/java/io/github/ollama4j/models/embeddings/OllamaEmbedRequestModel.java
new file mode 100644
index 0000000..8cb2002
--- /dev/null
+++ b/src/main/java/io/github/ollama4j/models/embeddings/OllamaEmbedRequestModel.java
@@ -0,0 +1,41 @@
+package io.github.ollama4j.models.embeddings;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.core.JsonProcessingException;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+import lombok.NonNull;
+import lombok.RequiredArgsConstructor;
+
+import java.util.List;
+import java.util.Map;
+
+import static io.github.ollama4j.utils.Utils.getObjectMapper;
+
+@Data
+@RequiredArgsConstructor
+@NoArgsConstructor
+public class OllamaEmbedRequestModel {
+    @NonNull
+    private String model;
+
+    @NonNull
+    private List<String> input;
+
+    private Map<String, Object> options;
+
+    @JsonProperty(value = "keep_alive")
+    private String keepAlive;
+
+    @JsonProperty(value = "truncate")
+    private Boolean truncate = true;
+
+    @Override
+    public String toString() {
+        try {
+            return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
+        } catch (JsonProcessingException e) {
+            throw new RuntimeException(e);
+        }
+    }
+}
diff --git a/src/main/java/io/github/ollama4j/models/embeddings/OllamaEmbedResponseModel.java b/src/main/java/io/github/ollama4j/models/embeddings/OllamaEmbedResponseModel.java
new file mode 100644
index 0000000..b4f808c
--- /dev/null
+++ b/src/main/java/io/github/ollama4j/models/embeddings/OllamaEmbedResponseModel.java
@@ -0,0 +1,25 @@
+package io.github.ollama4j.models.embeddings;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+import lombok.Data;
+
+import java.util.List;
+
+@SuppressWarnings("unused")
+@Data
+public class OllamaEmbedResponseModel {
+    @JsonProperty("model")
+    private String model;
+
+    @JsonProperty("embeddings")
+    private List<List<Double>> embeddings;
+
+    @JsonProperty("total_duration")
+    private long totalDuration;
+
+    @JsonProperty("load_duration")
+    private long loadDuration;
+
+    @JsonProperty("prompt_eval_count")
+    private int promptEvalCount;
+}
diff --git a/src/main/java/io/github/ollama4j/types/OllamaModelType.java b/src/main/java/io/github/ollama4j/types/OllamaModelType.java
index ccd866f..36fafb4 100644
--- a/src/main/java/io/github/ollama4j/types/OllamaModelType.java
+++ b/src/main/java/io/github/ollama4j/types/OllamaModelType.java
@@ -10,12 +10,9 @@ package io.github.ollama4j.types;
 public class OllamaModelType {
     public static final String GEMMA = "gemma";
     public static final String GEMMA2 = "gemma2";
-
-
     public static final String LLAMA2 = "llama2";
     public static final String LLAMA3 = "llama3";
     public static final String LLAMA3_1 = "llama3.1";
-
     public static final String MISTRAL = "mistral";
     public static final String MIXTRAL = "mixtral";
     public static final String LLAVA = "llava";
@@ -35,7 +32,6 @@ public class OllamaModelType {
     public static final String ZEPHYR = "zephyr";
     public static final String OPENHERMES = "openhermes";
    public static final String QWEN = "qwen";
-
     public static final String QWEN2 = "qwen2";
     public static final String WIZARDCODER = "wizardcoder";
     public static final String LLAMA2_CHINESE = "llama2-chinese";
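For reference, here is a minimal caller-side sketch of the new `/api/embed` support added by this diff. It is not part of the patch: the host URL, model name, and keep-alive value are illustrative, and the setters/getters used come from Lombok's `@Data` on the new `OllamaEmbedRequestModel` and `OllamaEmbedResponseModel` classes above.

```java
import io.github.ollama4j.OllamaAPI;
import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel;
import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel;

import java.util.Arrays;
import java.util.List;

public class EmbedUsageSketch {

    public static void main(String[] args) throws Exception {
        // Assumes an Ollama server running locally; adjust the host as needed.
        OllamaAPI ollamaAPI = new OllamaAPI("http://localhost:11434/");

        // Build the request explicitly to exercise the optional fields of
        // OllamaEmbedRequestModel (keep_alive, truncate) in addition to model/input.
        OllamaEmbedRequestModel request = new OllamaEmbedRequestModel(
                "all-minilm",
                Arrays.asList("Why is the sky blue?", "Why is the grass green?"));
        request.setKeepAlive("5m"); // serialized as "keep_alive" (illustrative value)
        request.setTruncate(true);  // serialized as "truncate"

        // embed(...) POSTs the request to /api/embed and maps the JSON body
        // onto OllamaEmbedResponseModel.
        OllamaEmbedResponseModel response = ollamaAPI.embed(request);

        // One embedding vector is returned per input string.
        List<List<Double>> vectors = response.getEmbeddings();
        System.out.println("Model: " + response.getModel());
        System.out.println("Vectors returned: " + vectors.size());
        System.out.println("First vector length: " + vectors.get(0).size());
    }
}
```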