From bd1a57c7e035c04ff9300bc21d508da89c084a8a Mon Sep 17 00:00:00 2001
From: Amith Koujalgi <koujalgi.amith@gmail.com>
Date: Wed, 30 Oct 2024 00:03:49 +0530
Subject: [PATCH] Added support for new embed API `/api/embed`

---
 .../docs/apis-generate/generate-embeddings.md | 80 +++++++++++++++++--
 .../java/io/github/ollama4j/OllamaAPI.java    | 47 +++++++++++
 .../embeddings/OllamaEmbedRequestModel.java   | 41 ++++++++++
 .../embeddings/OllamaEmbedResponseModel.java  | 25 ++++++
 .../ollama4j/types/OllamaModelType.java       |  4 -
 5 files changed, 187 insertions(+), 10 deletions(-)
 create mode 100644 src/main/java/io/github/ollama4j/models/embeddings/OllamaEmbedRequestModel.java
 create mode 100644 src/main/java/io/github/ollama4j/models/embeddings/OllamaEmbedResponseModel.java

diff --git a/docs/docs/apis-generate/generate-embeddings.md b/docs/docs/apis-generate/generate-embeddings.md
index 586b215..e465581 100644
--- a/docs/docs/apis-generate/generate-embeddings.md
+++ b/docs/docs/apis-generate/generate-embeddings.md
@@ -8,12 +8,85 @@ Generate embeddings from a model.
 
 Parameters:
 
+- `model`: name of model to generate embeddings from
+- `input`: one or more texts to generate embeddings for
+
+```java
+import io.github.ollama4j.OllamaAPI;
+import io.github.ollama4j.types.OllamaModelType;
+import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel;
+import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+public class Main {
+
+    public static void main(String[] args) {
+
+        String host = "http://localhost:11434/";
+
+        OllamaAPI ollamaAPI = new OllamaAPI(host);
+
+        OllamaEmbedResponseModel embeddings = ollamaAPI.embed("all-minilm", Arrays.asList("Why is the sky blue?", "Why is the grass green?"));
+
+        System.out.println(embeddings);
+    }
+}
+```
+
+Or, using the `OllamaEmbedRequestModel`:
+
+```java
+import io.github.ollama4j.OllamaAPI;
+import io.github.ollama4j.types.OllamaModelType;
+import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel;
+import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+public class Main {
+    public static void main(String[] args) {
+
+        String host = "http://localhost:11434/";
+
+        OllamaAPI ollamaAPI = new OllamaAPI(host);
+
+        OllamaEmbedResponseModel embeddings = ollamaAPI.embed(new OllamaEmbedRequestModel("all-minilm", Arrays.asList("Why is the sky blue?", "Why is the grass green?")));
+
+        System.out.println(embeddings);
+    }
+}
+```
+
+You will get a response similar to:
+
+```json
+{
+    "model": "all-minilm",
+    "embeddings": [[-0.034674067, 0.030984823, 0.0067988685]],
+    "total_duration": 14173700,
+    "load_duration": 1198800,
+    "prompt_eval_count": 2
+}
+```
+
+:::note
+
+The API described below is deprecated; prefer `embed` (shown above), which calls the `/api/embed` endpoint.
+
+:::
+
+Parameters:
+
 - `model`: name of model to generate embeddings from
 - `prompt`: text to generate embeddings for
 
 ```java
 import io.github.ollama4j.OllamaAPI;
 import io.github.ollama4j.types.OllamaModelType;
+
 import java.util.List;
 
 public class Main {
@@ -40,11 +113,6 @@ You will get a response similar to:
     0.009260174818336964,
     0.23178744316101074,
     -0.2916173040866852,
-    -0.8924556970596313,
-    0.8785552978515625,
-    -0.34576427936553955,
-    0.5742510557174683,
-    -0.04222835972905159,
-    -0.137906014919281
+    -0.8924556970596313
 ]
 ```
\ No newline at end of file
diff --git a/src/main/java/io/github/ollama4j/OllamaAPI.java b/src/main/java/io/github/ollama4j/OllamaAPI.java
index d3c68a2..f32c069 100644
--- a/src/main/java/io/github/ollama4j/OllamaAPI.java
+++ b/src/main/java/io/github/ollama4j/OllamaAPI.java
@@ -7,8 +7,10 @@ import io.github.ollama4j.models.chat.OllamaChatMessage;
 import io.github.ollama4j.models.chat.OllamaChatRequest;
 import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
 import io.github.ollama4j.models.chat.OllamaChatResult;
+import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel;
 import io.github.ollama4j.models.embeddings.OllamaEmbeddingResponseModel;
 import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
+import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel;
 import io.github.ollama4j.models.generate.OllamaGenerateRequest;
 import io.github.ollama4j.models.generate.OllamaStreamHandler;
 import io.github.ollama4j.models.ps.ModelsProcessResponse;
@@ -342,7 +344,9 @@ public class OllamaAPI {
      * @param model  name of model to generate embeddings from
      * @param prompt text to generate embeddings for
      * @return embeddings
+     * @deprecated Use {@link #embed(String, List)} instead.
      */
+    @Deprecated
     public List<Double> generateEmbeddings(String model, String prompt)
             throws IOException, InterruptedException, OllamaBaseException {
         return generateEmbeddings(new OllamaEmbeddingsRequestModel(model, prompt));
@@ -353,7 +357,9 @@ public class OllamaAPI {
      *
      * @param modelRequest request for '/api/embeddings' endpoint
      * @return embeddings
+     * @deprecated Use {@link #embed(OllamaEmbedRequestModel)} instead.
      */
+    @Deprecated
     public List<Double> generateEmbeddings(OllamaEmbeddingsRequestModel modelRequest) throws IOException, InterruptedException, OllamaBaseException {
         URI uri = URI.create(this.host + "/api/embeddings");
         String jsonData = modelRequest.toString();
@@ -375,6 +381,47 @@ public class OllamaAPI {
         }
     }
 
+    /**
+     * Generate embeddings for one or more input texts using the given model.
+     *
+     * @param model  name of model to generate embeddings from
+     * @param inputs one or more texts to generate embeddings for
+     * @return the response parsed from the '/api/embed' endpoint
+     */
+    public OllamaEmbedResponseModel embed(String model, List<String> inputs)
+            throws IOException, InterruptedException, OllamaBaseException {
+        return embed(new OllamaEmbedRequestModel(model, inputs));
+    }
+
+    /**
+     * Generate embeddings using a {@link OllamaEmbedRequestModel}.
+     *
+     * @param modelRequest request for '/api/embed' endpoint
+     * @return embeddings
+     */
+    public OllamaEmbedResponseModel embed(OllamaEmbedRequestModel modelRequest)
+            throws IOException, InterruptedException, OllamaBaseException {
+        URI uri = URI.create(this.host + "/api/embed");
+        String jsonData = Utils.getObjectMapper().writeValueAsString(modelRequest);
+        HttpClient httpClient = HttpClient.newHttpClient();
+
+        HttpRequest request = HttpRequest.newBuilder(uri)
+                .header("Accept", "application/json")
+                .POST(HttpRequest.BodyPublishers.ofString(jsonData))
+                .build();
+
+        HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
+        int statusCode = response.statusCode();
+        String responseBody = response.body();
+
+        if (statusCode == 200) {
+            OllamaEmbedResponseModel embeddingResponse =
+                    Utils.getObjectMapper().readValue(responseBody, OllamaEmbedResponseModel.class);
+            return embeddingResponse;
+        } else {
+            throw new OllamaBaseException(statusCode + " - " + responseBody);
+        }
+    }
 
     /**
      * Generate response for a question to a model running on Ollama server. This is a sync/blocking
diff --git a/src/main/java/io/github/ollama4j/models/embeddings/OllamaEmbedRequestModel.java b/src/main/java/io/github/ollama4j/models/embeddings/OllamaEmbedRequestModel.java
new file mode 100644
index 0000000..8cb2002
--- /dev/null
+++ b/src/main/java/io/github/ollama4j/models/embeddings/OllamaEmbedRequestModel.java
@@ -0,0 +1,41 @@
+package io.github.ollama4j.models.embeddings;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.core.JsonProcessingException;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+import lombok.NonNull;
+import lombok.RequiredArgsConstructor;
+
+import java.util.List;
+import java.util.Map;
+
+import static io.github.ollama4j.utils.Utils.getObjectMapper;
+
+@Data
+@RequiredArgsConstructor
+@NoArgsConstructor
+public class OllamaEmbedRequestModel {
+    @NonNull
+    private String model;
+
+    @NonNull
+    private List<String> input;
+
+    private Map<String, Object> options;
+
+    @JsonProperty(value = "keep_alive")
+    private String keepAlive;
+
+    @JsonProperty(value = "truncate")
+    private Boolean truncate = true;
+
+    @Override
+    public String toString() {
+        try {
+            return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
+        } catch (JsonProcessingException e) {
+            throw new RuntimeException(e);
+        }
+    }
+}
diff --git a/src/main/java/io/github/ollama4j/models/embeddings/OllamaEmbedResponseModel.java b/src/main/java/io/github/ollama4j/models/embeddings/OllamaEmbedResponseModel.java
new file mode 100644
index 0000000..b4f808c
--- /dev/null
+++ b/src/main/java/io/github/ollama4j/models/embeddings/OllamaEmbedResponseModel.java
@@ -0,0 +1,25 @@
+package io.github.ollama4j.models.embeddings;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+import lombok.Data;
+
+import java.util.List;
+
+@SuppressWarnings("unused")
+@Data
+public class OllamaEmbedResponseModel {
+    @JsonProperty("model")
+    private String model;
+
+    @JsonProperty("embeddings")
+    private List<List<Double>> embeddings;
+
+    @JsonProperty("total_duration")
+    private long totalDuration;
+
+    @JsonProperty("load_duration")
+    private long loadDuration;
+
+    @JsonProperty("prompt_eval_count")
+    private int promptEvalCount;
+}
diff --git a/src/main/java/io/github/ollama4j/types/OllamaModelType.java b/src/main/java/io/github/ollama4j/types/OllamaModelType.java
index ccd866f..36fafb4 100644
--- a/src/main/java/io/github/ollama4j/types/OllamaModelType.java
+++ b/src/main/java/io/github/ollama4j/types/OllamaModelType.java
@@ -10,12 +10,9 @@ package io.github.ollama4j.types;
 public class OllamaModelType {
     public static final String GEMMA = "gemma";
     public static final String GEMMA2 = "gemma2";
-
-
     public static final String LLAMA2 = "llama2";
     public static final String LLAMA3 = "llama3";
     public static final String LLAMA3_1 = "llama3.1";
-
     public static final String MISTRAL = "mistral";
     public static final String MIXTRAL = "mixtral";
     public static final String LLAVA = "llava";
@@ -35,7 +32,6 @@ public class OllamaModelType {
     public static final String ZEPHYR = "zephyr";
     public static final String OPENHERMES = "openhermes";
     public static final String QWEN = "qwen";
-
     public static final String QWEN2 = "qwen2";
     public static final String WIZARDCODER = "wizardcoder";
     public static final String LLAMA2_CHINESE = "llama2-chinese";