diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java index 695c52e..920ed4e 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java @@ -287,4 +287,36 @@ public class OllamaAPI { ollamaAsyncResultCallback.start(); return ollamaAsyncResultCallback; } + + /** + * Generate embeddings for a given text from a model + * + * @param model name of model to generate embeddings from + * @param prompt text to generate embeddings for + * @return embeddings as double[] + * @throws IOException + * @throws ParseException + * @throws OllamaBaseException + */ + public double[] generateEmbeddings(String model, String prompt) throws IOException, ParseException, OllamaBaseException { + String url = this.host + "/api/embeddings"; + String jsonData = String.format("{\"model\": \"%s\", \"prompt\": \"%s\"}", model, prompt); + final HttpPost httpPost = new HttpPost(url); + final StringEntity entity = new StringEntity(jsonData); + httpPost.setEntity(entity); + httpPost.setHeader("Accept", "application/json"); + httpPost.setHeader("Content-type", "application/json"); + try (CloseableHttpClient client = HttpClients.createDefault(); CloseableHttpResponse response = client.execute(httpPost)) { + final int statusCode = response.getCode(); + HttpEntity responseEntity = response.getEntity(); + String responseString = ""; + if (responseEntity != null) { + responseString = EntityUtils.toString(responseEntity, "UTF-8"); + EmbeddingResponse embeddingResponse = objectMapper.readValue(responseString, EmbeddingResponse.class); + return embeddingResponse.getEmbedding(); + } else { + throw new OllamaBaseException(statusCode + " - " + responseString); + } + } + } } diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/EmbeddingResponse.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/EmbeddingResponse.java new file mode 100644 index 0000000..a4f6b82 --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/EmbeddingResponse.java @@ -0,0 +1,19 @@ +package io.github.amithkoujalgi.ollama4j.core.models; + +import com.fasterxml.jackson.annotation.JsonProperty; + +public class EmbeddingResponse { + @JsonProperty("embedding") + private double[] embedding; + + public EmbeddingResponse() { + } + + public double[] getEmbedding() { + return embedding; + } + + public void setEmbedding(double[] embedding) { + this.embedding = embedding; + } +}