diff --git a/README.md b/README.md
index a561b55..3ac83c6 100644
--- a/README.md
+++ b/README.md
@@ -44,13 +44,13 @@ for [Ollama](https://github.com/jmorganca/ollama/blob/main/docs/api.md) APIs.
[![][ollama-shield]][ollama] Or [![][ollama-docker-shield]][ollama-docker]
[ollama]: https://ollama.ai/
+
[ollama-shield]: https://img.shields.io/badge/Ollama-Local_Installation-blue.svg?style=for-the-badge&labelColor=gray
[ollama-docker]: https://hub.docker.com/r/ollama/ollama
+
[ollama-docker-shield]: https://img.shields.io/badge/Ollama-Docker-blue.svg?style=for-the-badge&labelColor=gray
-
-
#### Installation
In your Maven project, add this dependency available in
@@ -59,9 +59,9 @@ the [Central Repository](https://s01.oss.sonatype.org/#nexus-search;quick~ollama
```xml
- io.github.amithkoujalgi
- ollama4j
- 1.0-SNAPSHOT
+ io.github.amithkoujalgi
+ ollama4j
+ 1.0-SNAPSHOT
```
@@ -71,10 +71,10 @@ your `pom.xml`:
```xml
-
- ollama4j-from-ossrh
- https://s01.oss.sonatype.org/content/repositories/snapshots
-
+
+ ollama4j-from-ossrh
+ https://s01.oss.sonatype.org/content/repositories/snapshots
+
```
@@ -113,13 +113,13 @@ Instantiate `OllamaAPI`
```java
public class Main {
- public static void main(String[] args) {
- String host = "http://localhost:11434/";
- OllamaAPI ollamaAPI = new OllamaAPI(host);
+ public static void main(String[] args) {
+ String host = "http://localhost:11434/";
+ OllamaAPI ollamaAPI = new OllamaAPI(host);
- // set verbose - true/false
- ollamaAPI.setVerbose(true);
- }
+ // set verbose - true/false
+ ollamaAPI.setVerbose(true);
+ }
}
```
@@ -128,11 +128,11 @@ public class Main {
```java
public class Main {
- public static void main(String[] args) {
- String host = "http://localhost:11434/";
- OllamaAPI ollamaAPI = new OllamaAPI(host);
- ollamaAPI.pullModel(OllamaModelType.LLAMA2);
- }
+ public static void main(String[] args) {
+ String host = "http://localhost:11434/";
+ OllamaAPI ollamaAPI = new OllamaAPI(host);
+ ollamaAPI.pullModel(OllamaModelType.LLAMA2);
+ }
}
```
@@ -143,12 +143,12 @@ _Find the list of available models from Ollama [here](https://ollama.ai/library)
```java
public class Main {
- public static void main(String[] args) {
- String host = "http://localhost:11434/";
- OllamaAPI ollamaAPI = new OllamaAPI(host);
- List models = ollamaAPI.listModels();
- models.forEach(model -> System.out.println(model.getName()));
- }
+ public static void main(String[] args) {
+ String host = "http://localhost:11434/";
+ OllamaAPI ollamaAPI = new OllamaAPI(host);
+ List models = ollamaAPI.listModels();
+ models.forEach(model -> System.out.println(model.getName()));
+ }
}
```
@@ -164,12 +164,12 @@ sqlcoder:latest
```java
public class Main {
- public static void main(String[] args) {
- String host = "http://localhost:11434/";
- OllamaAPI ollamaAPI = new OllamaAPI(host);
- ModelDetail modelDetails = ollamaAPI.getModelDetails(OllamaModelType.LLAMA2);
- System.out.println(modelDetails);
- }
+ public static void main(String[] args) {
+ String host = "http://localhost:11434/";
+ OllamaAPI ollamaAPI = new OllamaAPI(host);
+ ModelDetail modelDetails = ollamaAPI.getModelDetails(OllamaModelType.LLAMA2);
+ System.out.println(modelDetails);
+ }
}
```
@@ -189,11 +189,11 @@ Response:
```java
public class Main {
- public static void main(String[] args) {
- String host = "http://localhost:11434/";
- OllamaAPI ollamaAPI = new OllamaAPI(host);
- ollamaAPI.createModel("mycustommodel", "/path/to/modelfile/on/ollama-server");
- }
+ public static void main(String[] args) {
+ String host = "http://localhost:11434/";
+ OllamaAPI ollamaAPI = new OllamaAPI(host);
+ ollamaAPI.createModel("mycustommodel", "/path/to/modelfile/on/ollama-server");
+ }
}
```
@@ -202,12 +202,12 @@ public class Main {
```java
public class Main {
- public static void main(String[] args) {
- String host = "http://localhost:11434/";
- OllamaAPI ollamaAPI = new OllamaAPI(host);
- ollamaAPI.setVerbose(false);
- ollamaAPI.deleteModel("mycustommodel", true);
- }
+ public static void main(String[] args) {
+ String host = "http://localhost:11434/";
+ OllamaAPI ollamaAPI = new OllamaAPI(host);
+ ollamaAPI.setVerbose(false);
+ ollamaAPI.deleteModel("mycustommodel", true);
+ }
}
```
@@ -216,13 +216,13 @@ public class Main {
```java
public class Main {
- public static void main(String[] args) {
- String host = "http://localhost:11434/";
- OllamaAPI ollamaAPI = new OllamaAPI(host);
- List embeddings = ollamaAPI.generateEmbeddings(OllamaModelType.LLAMA2,
- "Here is an article about llamas...");
- embeddings.forEach(System.out::println);
- }
+ public static void main(String[] args) {
+ String host = "http://localhost:11434/";
+ OllamaAPI ollamaAPI = new OllamaAPI(host);
+ List embeddings = ollamaAPI.generateEmbeddings(OllamaModelType.LLAMA2,
+ "Here is an article about llamas...");
+ embeddings.forEach(System.out::println);
+ }
}
```
@@ -233,12 +233,12 @@ public class Main {
```java
public class Main {
- public static void main(String[] args) {
- String host = "http://localhost:11434/";
- OllamaAPI ollamaAPI = new OllamaAPI(host);
- String response = ollamaAPI.ask(OllamaModelType.LLAMA2, "Who are you?");
- System.out.println(response);
- }
+ public static void main(String[] args) {
+ String host = "http://localhost:11434/";
+ OllamaAPI ollamaAPI = new OllamaAPI(host);
+ String response = ollamaAPI.ask(OllamaModelType.LLAMA2, "Who are you?");
+ System.out.println(response);
+ }
}
```
@@ -247,20 +247,20 @@ public class Main {
```java
public class Main {
- public static void main(String[] args) {
- String host = "http://localhost:11434/";
- OllamaAPI ollamaAPI = new OllamaAPI(host);
- OllamaAsyncResultCallback ollamaAsyncResultCallback = ollamaAPI.askAsync(OllamaModelType.LLAMA2,
- "Who are you?");
- while (true) {
- if (ollamaAsyncResultCallback.isComplete()) {
- System.out.println(ollamaAsyncResultCallback.getResponse());
- break;
- }
- // introduce sleep to check for status with a time interval
- // Thread.sleep(1000);
+ public static void main(String[] args) {
+ String host = "http://localhost:11434/";
+ OllamaAPI ollamaAPI = new OllamaAPI(host);
+ OllamaAsyncResultCallback ollamaAsyncResultCallback = ollamaAPI.askAsync(OllamaModelType.LLAMA2,
+ "Who are you?");
+ while (true) {
+ if (ollamaAsyncResultCallback.isComplete()) {
+ System.out.println(ollamaAsyncResultCallback.getResponse());
+ break;
+ }
+ // introduce sleep to check for status with a time interval
+ // Thread.sleep(1000);
+ }
}
- }
}
```
@@ -280,14 +280,14 @@ You'd then get a response from the model:
```java
public class Main {
- public static void main(String[] args) {
- String host = "http://localhost:11434/";
- OllamaAPI ollamaAPI = new OllamaAPI(host);
+ public static void main(String[] args) {
+ String host = "http://localhost:11434/";
+ OllamaAPI ollamaAPI = new OllamaAPI(host);
- String prompt = "List all cricket world cup teams of 2019.";
- String response = ollamaAPI.ask(OllamaModelType.LLAMA2, prompt);
- System.out.println(response);
- }
+ String prompt = "List all cricket world cup teams of 2019.";
+ String response = ollamaAPI.ask(OllamaModelType.LLAMA2, prompt);
+ System.out.println(response);
+ }
}
```
@@ -316,15 +316,15 @@ You'd then get a response from the model:
```java
public class Main {
- public static void main(String[] args) {
- String host = "http://localhost:11434/";
- OllamaAPI ollamaAPI = new OllamaAPI(host);
+ public static void main(String[] args) {
+ String host = "http://localhost:11434/";
+ OllamaAPI ollamaAPI = new OllamaAPI(host);
- String prompt = SamplePrompts.getSampleDatabasePromptWithQuestion(
- "List all customer names who have bought one or more products");
- String response = ollamaAPI.ask(OllamaModelType.SQLCODER, prompt);
- System.out.println(response);
- }
+ String prompt = SamplePrompts.getSampleDatabasePromptWithQuestion(
+ "List all customer names who have bought one or more products");
+ String response = ollamaAPI.ask(OllamaModelType.SQLCODER, prompt);
+ System.out.println(response);
+ }
}
```
@@ -351,17 +351,17 @@ With Files:
```java
public class Main {
- public static void main(String[] args) {
- String host = "http://localhost:11434/";
- OllamaAPI ollamaAPI = new OllamaAPI(host);
- ollamaAPI.setRequestTimeoutSeconds(10);
+ public static void main(String[] args) {
+ String host = "http://localhost:11434/";
+ OllamaAPI ollamaAPI = new OllamaAPI(host);
+ ollamaAPI.setRequestTimeoutSeconds(10);
- OllamaResult response = ollamaAPI.askWithImageFiles(OllamaModelType.LLAVA,
- "What's in this image?",
- List.of(
- new File("/path/to/image")));
- System.out.println(response);
- }
+ OllamaResult response = ollamaAPI.askWithImageFiles(OllamaModelType.LLAVA,
+ "What's in this image?",
+ List.of(
+ new File("/path/to/image")));
+ System.out.println(response);
+ }
}
```
@@ -370,17 +370,17 @@ With URLs:
```java
public class Main {
- public static void main(String[] args) {
- String host = "http://localhost:11434/";
- OllamaAPI ollamaAPI = new OllamaAPI(host);
- ollamaAPI.setRequestTimeoutSeconds(10);
+ public static void main(String[] args) {
+ String host = "http://localhost:11434/";
+ OllamaAPI ollamaAPI = new OllamaAPI(host);
+ ollamaAPI.setRequestTimeoutSeconds(10);
- OllamaResult response = ollamaAPI.askWithImageURLs(OllamaModelType.LLAVA,
- "What's in this image?",
- List.of(
- "https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg"));
- System.out.println(response);
- }
+ OllamaResult response = ollamaAPI.askWithImageURLs(OllamaModelType.LLAVA,
+ "What's in this image?",
+ List.of(
+ "https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg"));
+ System.out.println(response);
+ }
}
```
@@ -398,21 +398,21 @@ The dog seems to be enjoying its time outdoors, perhaps on a lake.
@SuppressWarnings("ALL")
public class Main {
- public static void main(String[] args) {
- String host = "http://localhost:11434/";
- OllamaAPI ollamaAPI = new OllamaAPI(host);
+ public static void main(String[] args) {
+ String host = "http://localhost:11434/";
+ OllamaAPI ollamaAPI = new OllamaAPI(host);
- String prompt = "List all cricket world cup teams of 2019.";
- OllamaAsyncResultCallback callback = ollamaAPI.askAsync(OllamaModelType.LLAMA2, prompt);
- while (!callback.isComplete() || !callback.getStream().isEmpty()) {
- // poll for data from the response stream
- String response = callback.getStream().poll();
- if (response != null) {
- System.out.print(response);
- }
- Thread.sleep(1000);
+ String prompt = "List all cricket world cup teams of 2019.";
+ OllamaAsyncResultCallback callback = ollamaAPI.askAsync(OllamaModelType.LLAMA2, prompt);
+ while (!callback.isComplete() || !callback.getStream().isEmpty()) {
+ // poll for data from the response stream
+ String response = callback.getStream().poll();
+ if (response != null) {
+ System.out.print(response);
+ }
+ Thread.sleep(1000);
+ }
}
- }
}
```
@@ -452,8 +452,8 @@ make it
- [x] Fix deprecated HTTP client code
- [x] Setup logging
- [x] Use lombok
-- [ ] Update request body creation with Java objects
-- [ ] Async APIs for images
+- [x] Update request body creation with Java objects
+- [ ] Async APIs for images
- [ ] Add additional params for `ask` APIs such as:
- `options`: additional model parameters for the Modelfile such as `temperature`
- `system`: system prompt to (overrides what is defined in the Modelfile)
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java
index 066460f..a64e7ef 100644
--- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java
@@ -2,6 +2,10 @@ package io.github.amithkoujalgi.ollama4j.core;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.*;
+import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFileContentsRequest;
+import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFilePathRequest;
+import io.github.amithkoujalgi.ollama4j.core.models.request.ModelEmbeddingsRequest;
+import io.github.amithkoujalgi.ollama4j.core.models.request.ModelRequest;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
import java.io.BufferedReader;
import java.io.ByteArrayOutputStream;
@@ -17,7 +21,6 @@ import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
-import java.nio.file.Path;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Base64;
@@ -25,9 +28,7 @@ import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-/**
- * The base Ollama API class.
- */
+/** The base Ollama API class. */
@SuppressWarnings("DuplicatedCode")
public class OllamaAPI {
@@ -71,15 +72,21 @@ public class OllamaAPI {
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
String url = this.host + "/api/tags";
HttpClient httpClient = HttpClient.newHttpClient();
- HttpRequest httpRequest = HttpRequest.newBuilder().uri(new URI(url))
- .header("Accept", "application/json").header("Content-type", "application/json")
- .timeout(Duration.ofSeconds(requestTimeoutSeconds)).GET().build();
- HttpResponse response = httpClient.send(httpRequest,
- HttpResponse.BodyHandlers.ofString());
+ HttpRequest httpRequest =
+ HttpRequest.newBuilder()
+ .uri(new URI(url))
+ .header("Accept", "application/json")
+ .header("Content-type", "application/json")
+ .timeout(Duration.ofSeconds(requestTimeoutSeconds))
+ .GET()
+ .build();
+ HttpResponse response =
+ httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode();
String responseString = response.body();
if (statusCode == 200) {
- return Utils.getObjectMapper().readValue(responseString, ListModelsResponse.class)
+ return Utils.getObjectMapper()
+ .readValue(responseString, ListModelsResponse.class)
.getModels();
} else {
throw new OllamaBaseException(statusCode + " - " + responseString);
@@ -90,28 +97,32 @@ public class OllamaAPI {
* Pull a model on the Ollama server from the list of available models.
*
- * @param model the name of the model
+ * @param modelName the name of the model
*/
- public void pullModel(String model)
+ public void pullModel(String modelName)
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
String url = this.host + "/api/pull";
- String jsonData = String.format("{\"name\": \"%s\"}", model);
- HttpRequest request = HttpRequest.newBuilder().uri(new URI(url))
- .POST(HttpRequest.BodyPublishers.ofString(jsonData)).header("Accept", "application/json")
- .header("Content-type", "application/json")
- .timeout(Duration.ofSeconds(requestTimeoutSeconds)).build();
+ String jsonData = new ModelRequest(modelName).toString();
+ HttpRequest request =
+ HttpRequest.newBuilder()
+ .uri(new URI(url))
+ .POST(HttpRequest.BodyPublishers.ofString(jsonData))
+ .header("Accept", "application/json")
+ .header("Content-type", "application/json")
+ .timeout(Duration.ofSeconds(requestTimeoutSeconds))
+ .build();
HttpClient client = HttpClient.newHttpClient();
- HttpResponse response = client.send(request,
- HttpResponse.BodyHandlers.ofInputStream());
+ HttpResponse response =
+ client.send(request, HttpResponse.BodyHandlers.ofInputStream());
int statusCode = response.statusCode();
InputStream responseBodyStream = response.body();
String responseString = "";
- try (BufferedReader reader = new BufferedReader(
- new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
+ try (BufferedReader reader =
+ new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
String line;
while ((line = reader.readLine()) != null) {
- ModelPullResponse modelPullResponse = Utils.getObjectMapper()
- .readValue(line, ModelPullResponse.class);
+ ModelPullResponse modelPullResponse =
+ Utils.getObjectMapper().readValue(line, ModelPullResponse.class);
if (verbose) {
logger.info(modelPullResponse.getStatus());
}
@@ -131,11 +142,15 @@ public class OllamaAPI {
public ModelDetail getModelDetails(String modelName)
throws IOException, OllamaBaseException, InterruptedException {
String url = this.host + "/api/show";
- String jsonData = String.format("{\"name\": \"%s\"}", modelName);
- HttpRequest request = HttpRequest.newBuilder().uri(URI.create(url))
- .header("Accept", "application/json").header("Content-type", "application/json")
- .timeout(Duration.ofSeconds(requestTimeoutSeconds))
- .POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
+ String jsonData = new ModelRequest(modelName).toString();
+ HttpRequest request =
+ HttpRequest.newBuilder()
+ .uri(URI.create(url))
+ .header("Accept", "application/json")
+ .header("Content-type", "application/json")
+ .timeout(Duration.ofSeconds(requestTimeoutSeconds))
+ .POST(HttpRequest.BodyPublishers.ofString(jsonData))
+ .build();
HttpClient client = HttpClient.newHttpClient();
HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode();
@@ -151,18 +166,21 @@ public class OllamaAPI {
* Create a custom model from a model file. Read more about custom model file creation here.
*
- * @param modelName the name of the custom model to be created.
+ * @param modelName the name of the custom model to be created.
* @param modelFilePath the path to model file that exists on the Ollama server.
*/
- public void createModel(String modelName, String modelFilePath)
+ public void createModelWithFilePath(String modelName, String modelFilePath)
throws IOException, InterruptedException, OllamaBaseException {
String url = this.host + "/api/create";
- String jsonData = String.format("{\"name\": \"%s\", \"path\": \"%s\"}", modelName,
- modelFilePath);
- HttpRequest request = HttpRequest.newBuilder().uri(URI.create(url))
- .header("Accept", "application/json").header("Content-Type", "application/json")
- .timeout(Duration.ofSeconds(requestTimeoutSeconds))
- .POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build();
+ String jsonData = new CustomModelFilePathRequest(modelName, modelFilePath).toString();
+ HttpRequest request =
+ HttpRequest.newBuilder()
+ .uri(URI.create(url))
+ .header("Accept", "application/json")
+ .header("Content-Type", "application/json")
+ .timeout(Duration.ofSeconds(requestTimeoutSeconds))
+ .POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8))
+ .build();
HttpClient client = HttpClient.newHttpClient();
HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode();
@@ -180,21 +198,59 @@ public class OllamaAPI {
}
}
+ /**
+   * Create a custom model from the contents of a Modelfile.
+ *
+ * @param modelName the name of the custom model to be created.
+   * @param modelFileContents the contents of the Modelfile to create the model from.
+ */
+ public void createModelWithModelFileContents(String modelName, String modelFileContents)
+ throws IOException, InterruptedException, OllamaBaseException {
+ String url = this.host + "/api/create";
+ String jsonData = new CustomModelFileContentsRequest(modelName, modelFileContents).toString();
+ HttpRequest request =
+ HttpRequest.newBuilder()
+ .uri(URI.create(url))
+ .header("Accept", "application/json")
+ .header("Content-Type", "application/json")
+ .timeout(Duration.ofSeconds(requestTimeoutSeconds))
+ .POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8))
+ .build();
+ HttpClient client = HttpClient.newHttpClient();
+ HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString());
+ int statusCode = response.statusCode();
+ String responseString = response.body();
+ if (statusCode != 200) {
+ throw new OllamaBaseException(statusCode + " - " + responseString);
+ }
+ if (responseString.contains("error")) {
+ throw new OllamaBaseException(responseString);
+ }
+ if (verbose) {
+ logger.info(responseString);
+ }
+ }
+
/**
* Delete a model from Ollama server.
*
- * @param name the name of the model to be deleted.
+ * @param modelName the name of the model to be deleted.
* @param ignoreIfNotPresent - ignore errors if the specified model is not present on Ollama
- * server.
+ * server.
*/
- public void deleteModel(String name, boolean ignoreIfNotPresent)
+ public void deleteModel(String modelName, boolean ignoreIfNotPresent)
throws IOException, InterruptedException, OllamaBaseException {
String url = this.host + "/api/delete";
- String jsonData = String.format("{\"name\": \"%s\"}", name);
- HttpRequest request = HttpRequest.newBuilder().uri(URI.create(url))
- .method("DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8))
- .header("Accept", "application/json").header("Content-type", "application/json")
- .timeout(Duration.ofSeconds(requestTimeoutSeconds)).build();
+ String jsonData = new ModelRequest(modelName).toString();
+ HttpRequest request =
+ HttpRequest.newBuilder()
+ .uri(URI.create(url))
+ .method("DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8))
+ .header("Accept", "application/json")
+ .header("Content-type", "application/json")
+ .timeout(Duration.ofSeconds(requestTimeoutSeconds))
+ .build();
HttpClient client = HttpClient.newHttpClient();
HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode();
@@ -210,25 +266,29 @@ public class OllamaAPI {
/**
* Generate embeddings for a given text from a model
*
- * @param model name of model to generate embeddings from
+ * @param model name of model to generate embeddings from
* @param prompt text to generate embeddings for
* @return embeddings
*/
public List generateEmbeddings(String model, String prompt)
throws IOException, InterruptedException, OllamaBaseException {
String url = this.host + "/api/embeddings";
- String jsonData = String.format("{\"model\": \"%s\", \"prompt\": \"%s\"}", model, prompt);
+ String jsonData = new ModelEmbeddingsRequest(model, prompt).toString();
HttpClient httpClient = HttpClient.newHttpClient();
- HttpRequest request = HttpRequest.newBuilder().uri(URI.create(url))
- .header("Accept", "application/json").header("Content-type", "application/json")
- .timeout(Duration.ofSeconds(requestTimeoutSeconds))
- .POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
+ HttpRequest request =
+ HttpRequest.newBuilder()
+ .uri(URI.create(url))
+ .header("Accept", "application/json")
+ .header("Content-type", "application/json")
+ .timeout(Duration.ofSeconds(requestTimeoutSeconds))
+ .POST(HttpRequest.BodyPublishers.ofString(jsonData))
+ .build();
HttpResponse response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode();
String responseBody = response.body();
if (statusCode == 200) {
- EmbeddingResponse embeddingResponse = Utils.getObjectMapper()
- .readValue(responseBody, EmbeddingResponse.class);
+ EmbeddingResponse embeddingResponse =
+ Utils.getObjectMapper().readValue(responseBody, EmbeddingResponse.class);
return embeddingResponse.getEmbedding();
} else {
throw new OllamaBaseException(statusCode + " - " + responseBody);
@@ -238,7 +298,7 @@ public class OllamaAPI {
/**
* Ask a question to a model running on Ollama server. This is a sync/blocking call.
*
- * @param model the ollama model to ask the question to
+ * @param model the ollama model to ask the question to
* @param promptText the prompt/question text
* @return OllamaResult - that includes response text and time taken for response
*/
@@ -248,11 +308,30 @@ public class OllamaAPI {
return askSync(ollamaRequestModel);
}
+ /**
+ * Ask a question to a model running on Ollama server and get a callback handle that can be used
+ * to check for status and get the response from the model later. This would be an
+ * async/non-blocking call.
+ *
+ * @param model the ollama model to ask the question to
+ * @param promptText the prompt/question text
+ * @return the ollama async result callback handle
+ */
+ public OllamaAsyncResultCallback askAsync(String model, String promptText) {
+ OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, promptText);
+ HttpClient httpClient = HttpClient.newHttpClient();
+ URI uri = URI.create(this.host + "/api/generate");
+ OllamaAsyncResultCallback ollamaAsyncResultCallback =
+ new OllamaAsyncResultCallback(httpClient, uri, ollamaRequestModel, requestTimeoutSeconds);
+ ollamaAsyncResultCallback.start();
+ return ollamaAsyncResultCallback;
+ }
+
/**
* With one or more image files, ask a question to a model running on Ollama server. This is a
* sync/blocking call.
*
- * @param model the ollama model to ask the question to
+ * @param model the ollama model to ask the question to
* @param promptText the prompt/question text
* @param imageFiles the list of image files to use for the question
* @return OllamaResult - that includes response text and time taken for response
@@ -271,9 +350,9 @@ public class OllamaAPI {
* With one or more image URLs, ask a question to a model running on Ollama server. This is a
* sync/blocking call.
*
- * @param model the ollama model to ask the question to
+ * @param model the ollama model to ask the question to
* @param promptText the prompt/question text
- * @param imageURLs the list of image URLs to use for the question
+ * @param imageURLs the list of image URLs to use for the question
* @return OllamaResult - that includes response text and time taken for response
*/
public OllamaResult askWithImageURLs(String model, String promptText, List imageURLs)
@@ -286,18 +365,19 @@ public class OllamaAPI {
return askSync(ollamaRequestModel);
}
- public static String encodeFileToBase64(File file) throws IOException {
+ private static String encodeFileToBase64(File file) throws IOException {
return Base64.getEncoder().encodeToString(Files.readAllBytes(file.toPath()));
}
- public static String encodeByteArrayToBase64(byte[] bytes) {
+ private static String encodeByteArrayToBase64(byte[] bytes) {
return Base64.getEncoder().encodeToString(bytes);
}
- public static byte[] loadImageBytesFromUrl(String imageUrl)
+ private static byte[] loadImageBytesFromUrl(String imageUrl)
throws IOException, URISyntaxException {
URL url = new URI(imageUrl).toURL();
- try (InputStream in = url.openStream(); ByteArrayOutputStream out = new ByteArrayOutputStream()) {
+ try (InputStream in = url.openStream();
+ ByteArrayOutputStream out = new ByteArrayOutputStream()) {
byte[] buffer = new byte[1024];
int bytesRead;
while ((bytesRead = in.read(buffer)) != -1) {
@@ -307,50 +387,35 @@ public class OllamaAPI {
}
}
- /**
- * Ask a question to a model running on Ollama server and get a callback handle that can be used
- * to check for status and get the response from the model later. This would be an
- * async/non-blocking call.
- *
- * @param model the ollama model to ask the question to
- * @param promptText the prompt/question text
- * @return the ollama async result callback handle
- */
- public OllamaAsyncResultCallback askAsync(String model, String promptText) {
- OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, promptText);
- HttpClient httpClient = HttpClient.newHttpClient();
- URI uri = URI.create(this.host + "/api/generate");
- OllamaAsyncResultCallback ollamaAsyncResultCallback = new OllamaAsyncResultCallback(httpClient,
- uri, ollamaRequestModel, requestTimeoutSeconds);
- ollamaAsyncResultCallback.start();
- return ollamaAsyncResultCallback;
- }
-
private OllamaResult askSync(OllamaRequestModel ollamaRequestModel)
throws OllamaBaseException, IOException, InterruptedException {
long startTime = System.currentTimeMillis();
HttpClient httpClient = HttpClient.newHttpClient();
URI uri = URI.create(this.host + "/api/generate");
- HttpRequest request = HttpRequest.newBuilder(uri).POST(HttpRequest.BodyPublishers.ofString(
- Utils.getObjectMapper().writeValueAsString(ollamaRequestModel)))
- .header("Content-Type", "application/json")
- .timeout(Duration.ofSeconds(requestTimeoutSeconds)).build();
- HttpResponse response = httpClient.send(request,
- HttpResponse.BodyHandlers.ofInputStream());
+ HttpRequest request =
+ HttpRequest.newBuilder(uri)
+ .POST(
+ HttpRequest.BodyPublishers.ofString(
+ Utils.getObjectMapper().writeValueAsString(ollamaRequestModel)))
+ .header("Content-Type", "application/json")
+ .timeout(Duration.ofSeconds(requestTimeoutSeconds))
+ .build();
+ HttpResponse response =
+ httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
int statusCode = response.statusCode();
InputStream responseBodyStream = response.body();
StringBuilder responseBuffer = new StringBuilder();
- try (BufferedReader reader = new BufferedReader(
- new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
+ try (BufferedReader reader =
+ new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
String line;
while ((line = reader.readLine()) != null) {
if (statusCode == 404) {
- OllamaErrorResponseModel ollamaResponseModel = Utils.getObjectMapper()
- .readValue(line, OllamaErrorResponseModel.class);
+ OllamaErrorResponseModel ollamaResponseModel =
+ Utils.getObjectMapper().readValue(line, OllamaErrorResponseModel.class);
responseBuffer.append(ollamaResponseModel.getError());
} else {
- OllamaResponseModel ollamaResponseModel = Utils.getObjectMapper()
- .readValue(line, OllamaResponseModel.class);
+ OllamaResponseModel ollamaResponseModel =
+ Utils.getObjectMapper().readValue(line, OllamaResponseModel.class);
if (!ollamaResponseModel.isDone()) {
responseBuffer.append(ollamaResponseModel.getResponse());
}
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/ModelDetail.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/ModelDetail.java
index 20cf85e..fa557c6 100644
--- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/ModelDetail.java
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/ModelDetail.java
@@ -1,14 +1,20 @@
package io.github.amithkoujalgi.ollama4j.core.models;
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
+import java.util.Map;
import lombok.Data;
@Data
+@JsonIgnoreProperties(ignoreUnknown = true)
public class ModelDetail {
private String license;
+
@JsonProperty("modelfile")
private String modelFile;
+
private String parameters;
private String template;
private String system;
+ private Map details;
}
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/CustomModelFileContentsRequest.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/CustomModelFileContentsRequest.java
new file mode 100644
index 0000000..9e606d3
--- /dev/null
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/CustomModelFileContentsRequest.java
@@ -0,0 +1,23 @@
+package io.github.amithkoujalgi.ollama4j.core.models.request;
+
+import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import lombok.AllArgsConstructor;
+import lombok.Data;
+
+@Data
+@AllArgsConstructor
+public class CustomModelFileContentsRequest {
+ private String name;
+ private String modelfile;
+
+ @Override
+ public String toString() {
+ try {
+ return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
+ } catch (JsonProcessingException e) {
+ throw new RuntimeException(e);
+ }
+ }
+}
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/CustomModelFilePathRequest.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/CustomModelFilePathRequest.java
new file mode 100644
index 0000000..ea08dbf
--- /dev/null
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/CustomModelFilePathRequest.java
@@ -0,0 +1,23 @@
+package io.github.amithkoujalgi.ollama4j.core.models.request;
+
+import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import lombok.AllArgsConstructor;
+import lombok.Data;
+
+@Data
+@AllArgsConstructor
+public class CustomModelFilePathRequest {
+ private String name;
+ private String path;
+
+ @Override
+ public String toString() {
+ try {
+ return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
+ } catch (JsonProcessingException e) {
+ throw new RuntimeException(e);
+ }
+ }
+}
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/ModelEmbeddingsRequest.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/ModelEmbeddingsRequest.java
new file mode 100644
index 0000000..1455a94
--- /dev/null
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/ModelEmbeddingsRequest.java
@@ -0,0 +1,23 @@
+package io.github.amithkoujalgi.ollama4j.core.models.request;
+
+import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import lombok.AllArgsConstructor;
+import lombok.Data;
+
+@Data
+@AllArgsConstructor
+public class ModelEmbeddingsRequest {
+ private String model;
+ private String prompt;
+
+ @Override
+ public String toString() {
+ try {
+ return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
+ } catch (JsonProcessingException e) {
+ throw new RuntimeException(e);
+ }
+ }
+}
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/ModelRequest.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/ModelRequest.java
new file mode 100644
index 0000000..d3fdec4
--- /dev/null
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/ModelRequest.java
@@ -0,0 +1,22 @@
+package io.github.amithkoujalgi.ollama4j.core.models.request;
+
+import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import lombok.AllArgsConstructor;
+import lombok.Data;
+
+@Data
+@AllArgsConstructor
+public class ModelRequest {
+ private String name;
+
+ @Override
+ public String toString() {
+ try {
+ return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
+ } catch (JsonProcessingException e) {
+ throw new RuntimeException(e);
+ }
+ }
+}
diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/TestMockedAPIs.java b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/TestMockedAPIs.java
index 3b5fafc..7c46e37 100644
--- a/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/TestMockedAPIs.java
+++ b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/TestMockedAPIs.java
@@ -44,11 +44,11 @@ class TestMockedAPIs {
void testCreateModel() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
- String modelFilePath = "/somemodel";
+ String modelFilePath = "FROM llama2\nSYSTEM You are mario from Super Mario Bros.";
try {
- doNothing().when(ollamaAPI).createModel(model, modelFilePath);
- ollamaAPI.createModel(model, modelFilePath);
- verify(ollamaAPI, times(1)).createModel(model, modelFilePath);
+ doNothing().when(ollamaAPI).createModelWithModelFileContents(model, modelFilePath);
+ ollamaAPI.createModelWithModelFileContents(model, modelFilePath);
+ verify(ollamaAPI, times(1)).createModelWithModelFileContents(model, modelFilePath);
} catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e);
}