Added support for new embed API /api/embed

This commit is contained in:
Amith Koujalgi 2024-10-30 00:03:49 +05:30
parent 7fabead249
commit bd1a57c7e0
No known key found for this signature in database
GPG Key ID: E29A37746AF94B70
5 changed files with 187 additions and 10 deletions

View File

@ -8,12 +8,85 @@ Generate embeddings from a model.
Parameters: Parameters:
- `model`: name of model to generate embeddings from
- `input`: text/s to generate embeddings for
```java
import io.github.ollama4j.OllamaAPI;
import io.github.ollama4j.types.OllamaModelType;
import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel;
import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
public class Main {
public static void main(String[] args) {
String host = "http://localhost:11434/";
OllamaAPI ollamaAPI = new OllamaAPI(host);
OllamaEmbedResponseModel embeddings = ollamaAPI.embed("all-minilm", Arrays.asList("Why is the sky blue?", "Why is the grass green?"));
System.out.println(embeddings);
}
}
```
Or, using the `OllamaEmbedResponseModel`:
```java
import io.github.ollama4j.OllamaAPI;
import io.github.ollama4j.types.OllamaModelType;
import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel;
import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel;import java.util.Arrays;
import java.util.Collections;
import java.util.List;
public class Main {
public static void main(String[] args) {
String host = "http://localhost:11434/";
OllamaAPI ollamaAPI = new OllamaAPI(host);
OllamaEmbedResponseModel embeddings = ollamaAPI.embed(new OllamaEmbedRequestModel("all-minilm", Arrays.asList("Why is the sky blue?", "Why is the grass green?")));
System.out.println(embeddings);
}
}
```
You will get a response similar to:
```json
{
"model": "all-minilm",
"embeddings": [[-0.034674067, 0.030984823, 0.0067988685]],
"total_duration": 14173700,
"load_duration": 1198800,
"prompt_eval_count": 2
}
````
:::note
This is a deprecated API
:::
Parameters:
- `model`: name of model to generate embeddings from - `model`: name of model to generate embeddings from
- `prompt`: text to generate embeddings for - `prompt`: text to generate embeddings for
```java ```java
import io.github.ollama4j.OllamaAPI; import io.github.ollama4j.OllamaAPI;
import io.github.ollama4j.types.OllamaModelType; import io.github.ollama4j.types.OllamaModelType;
import java.util.List; import java.util.List;
public class Main { public class Main {
@ -40,11 +113,6 @@ You will get a response similar to:
0.009260174818336964, 0.009260174818336964,
0.23178744316101074, 0.23178744316101074,
-0.2916173040866852, -0.2916173040866852,
-0.8924556970596313, -0.8924556970596313
0.8785552978515625,
-0.34576427936553955,
0.5742510557174683,
-0.04222835972905159,
-0.137906014919281
] ]
``` ```

View File

@ -7,8 +7,10 @@ import io.github.ollama4j.models.chat.OllamaChatMessage;
import io.github.ollama4j.models.chat.OllamaChatRequest; import io.github.ollama4j.models.chat.OllamaChatRequest;
import io.github.ollama4j.models.chat.OllamaChatRequestBuilder; import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
import io.github.ollama4j.models.chat.OllamaChatResult; import io.github.ollama4j.models.chat.OllamaChatResult;
import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel;
import io.github.ollama4j.models.embeddings.OllamaEmbeddingResponseModel; import io.github.ollama4j.models.embeddings.OllamaEmbeddingResponseModel;
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel; import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel;
import io.github.ollama4j.models.generate.OllamaGenerateRequest; import io.github.ollama4j.models.generate.OllamaGenerateRequest;
import io.github.ollama4j.models.generate.OllamaStreamHandler; import io.github.ollama4j.models.generate.OllamaStreamHandler;
import io.github.ollama4j.models.ps.ModelsProcessResponse; import io.github.ollama4j.models.ps.ModelsProcessResponse;
@ -342,7 +344,9 @@ public class OllamaAPI {
* @param model name of model to generate embeddings from * @param model name of model to generate embeddings from
* @param prompt text to generate embeddings for * @param prompt text to generate embeddings for
* @return embeddings * @return embeddings
* @deprecated Use {@link #embed(String, List<String>)} instead.
*/ */
@Deprecated
public List<Double> generateEmbeddings(String model, String prompt) public List<Double> generateEmbeddings(String model, String prompt)
throws IOException, InterruptedException, OllamaBaseException { throws IOException, InterruptedException, OllamaBaseException {
return generateEmbeddings(new OllamaEmbeddingsRequestModel(model, prompt)); return generateEmbeddings(new OllamaEmbeddingsRequestModel(model, prompt));
@ -353,7 +357,9 @@ public class OllamaAPI {
* *
* @param modelRequest request for '/api/embeddings' endpoint * @param modelRequest request for '/api/embeddings' endpoint
* @return embeddings * @return embeddings
* @deprecated Use {@link #embed(OllamaEmbedRequestModel)} instead.
*/ */
@Deprecated
public List<Double> generateEmbeddings(OllamaEmbeddingsRequestModel modelRequest) throws IOException, InterruptedException, OllamaBaseException { public List<Double> generateEmbeddings(OllamaEmbeddingsRequestModel modelRequest) throws IOException, InterruptedException, OllamaBaseException {
URI uri = URI.create(this.host + "/api/embeddings"); URI uri = URI.create(this.host + "/api/embeddings");
String jsonData = modelRequest.toString(); String jsonData = modelRequest.toString();
@ -375,6 +381,47 @@ public class OllamaAPI {
} }
} }
/**
* Generate embeddings for a given text from a model
*
* @param model name of model to generate embeddings from
* @param inputs text/s to generate embeddings for
* @return embeddings
*/
public OllamaEmbedResponseModel embed(String model, List<String> inputs)
throws IOException, InterruptedException, OllamaBaseException {
return embed(new OllamaEmbedRequestModel(model, inputs));
}
/**
* Generate embeddings using a {@link OllamaEmbedRequestModel}.
*
* @param modelRequest request for '/api/embed' endpoint
* @return embeddings
*/
public OllamaEmbedResponseModel embed(OllamaEmbedRequestModel modelRequest)
throws IOException, InterruptedException, OllamaBaseException {
URI uri = URI.create(this.host + "/api/embed");
String jsonData = Utils.getObjectMapper().writeValueAsString(modelRequest);
HttpClient httpClient = HttpClient.newHttpClient();
HttpRequest request = HttpRequest.newBuilder(uri)
.header("Accept", "application/json")
.POST(HttpRequest.BodyPublishers.ofString(jsonData))
.build();
HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode();
String responseBody = response.body();
if (statusCode == 200) {
OllamaEmbedResponseModel embeddingResponse =
Utils.getObjectMapper().readValue(responseBody, OllamaEmbedResponseModel.class);
return embeddingResponse;
} else {
throw new OllamaBaseException(statusCode + " - " + responseBody);
}
}
/** /**
* Generate response for a question to a model running on Ollama server. This is a sync/blocking * Generate response for a question to a model running on Ollama server. This is a sync/blocking

View File

@ -0,0 +1,41 @@
package io.github.ollama4j.models.embeddings;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import java.util.List;
import java.util.Map;
import static io.github.ollama4j.utils.Utils.getObjectMapper;
@Data
@RequiredArgsConstructor
@NoArgsConstructor
public class OllamaEmbedRequestModel {
@NonNull
private String model;
@NonNull
private List<String> input;
private Map<String, Object> options;
@JsonProperty(value = "keep_alive")
private String keepAlive;
@JsonProperty(value = "truncate")
private Boolean truncate = true;
@Override
public String toString() {
try {
return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
}

View File

@ -0,0 +1,25 @@
package io.github.ollama4j.models.embeddings;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Data;
import java.util.List;
@SuppressWarnings("unused")
@Data
public class OllamaEmbedResponseModel {
@JsonProperty("model")
private String model;
@JsonProperty("embeddings")
private List<List<Double>> embeddings;
@JsonProperty("total_duration")
private long totalDuration;
@JsonProperty("load_duration")
private long loadDuration;
@JsonProperty("prompt_eval_count")
private int promptEvalCount;
}

View File

@ -10,12 +10,9 @@ package io.github.ollama4j.types;
public class OllamaModelType { public class OllamaModelType {
public static final String GEMMA = "gemma"; public static final String GEMMA = "gemma";
public static final String GEMMA2 = "gemma2"; public static final String GEMMA2 = "gemma2";
public static final String LLAMA2 = "llama2"; public static final String LLAMA2 = "llama2";
public static final String LLAMA3 = "llama3"; public static final String LLAMA3 = "llama3";
public static final String LLAMA3_1 = "llama3.1"; public static final String LLAMA3_1 = "llama3.1";
public static final String MISTRAL = "mistral"; public static final String MISTRAL = "mistral";
public static final String MIXTRAL = "mixtral"; public static final String MIXTRAL = "mixtral";
public static final String LLAVA = "llava"; public static final String LLAVA = "llava";
@ -35,7 +32,6 @@ public class OllamaModelType {
public static final String ZEPHYR = "zephyr"; public static final String ZEPHYR = "zephyr";
public static final String OPENHERMES = "openhermes"; public static final String OPENHERMES = "openhermes";
public static final String QWEN = "qwen"; public static final String QWEN = "qwen";
public static final String QWEN2 = "qwen2"; public static final String QWEN2 = "qwen2";
public static final String WIZARDCODER = "wizardcoder"; public static final String WIZARDCODER = "wizardcoder";
public static final String LLAMA2_CHINESE = "llama2-chinese"; public static final String LLAMA2_CHINESE = "llama2-chinese";