mirror of
https://github.com/amithkoujalgi/ollama4j.git
synced 2025-05-15 20:07:10 +02:00
Merge remote-tracking branch 'origin/main'
This commit is contained in:
commit
f6af7025a8
@ -91,7 +91,13 @@ For simplest way to get started, I prefer to use the Ollama docker setup.
|
|||||||
Start the Ollama docker container:
|
Start the Ollama docker container:
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
docker run -v ~/ollama:/root/.ollama -p 11434:11434 ollama/ollama
|
docker run -it -v ~/ollama:/root/.ollama -p 11434:11434 ollama/ollama
|
||||||
|
```
|
||||||
|
|
||||||
|
With GPUs
|
||||||
|
|
||||||
|
```shell
|
||||||
|
docker run -it --gpus=all -v ~/ollama:/root/.ollama -p 11434:11434 ollama/ollama
|
||||||
```
|
```
|
||||||
|
|
||||||
Instantiate `OllamaAPI`
|
Instantiate `OllamaAPI`
|
||||||
@ -347,6 +353,7 @@ Find the full `Javadoc` (API specifications) [here](https://amithkoujalgi.github
|
|||||||
conversational memory
|
conversational memory
|
||||||
- `stream`: Add support for streaming responses from the model
|
- `stream`: Add support for streaming responses from the model
|
||||||
- [x] Setup logging
|
- [x] Setup logging
|
||||||
|
- [ ] Use lombok
|
||||||
- [ ] Add test cases
|
- [ ] Add test cases
|
||||||
- [ ] Handle exceptions better (maybe throw more appropriate exceptions)
|
- [ ] Handle exceptions better (maybe throw more appropriate exceptions)
|
||||||
|
|
||||||
|
@ -3,9 +3,6 @@ package io.github.amithkoujalgi.ollama4j.core;
|
|||||||
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
|
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
|
||||||
import io.github.amithkoujalgi.ollama4j.core.models.*;
|
import io.github.amithkoujalgi.ollama4j.core.models.*;
|
||||||
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
|
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
|
||||||
import org.slf4j.Logger;
|
|
||||||
import org.slf4j.LoggerFactory;
|
|
||||||
|
|
||||||
import java.io.BufferedReader;
|
import java.io.BufferedReader;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.io.InputStream;
|
import java.io.InputStream;
|
||||||
@ -17,226 +14,298 @@ import java.net.http.HttpRequest;
|
|||||||
import java.net.http.HttpResponse;
|
import java.net.http.HttpResponse;
|
||||||
import java.nio.charset.StandardCharsets;
|
import java.nio.charset.StandardCharsets;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
/**
|
/** The base Ollama API class. */
|
||||||
* The base Ollama API class.
|
|
||||||
*/
|
|
||||||
@SuppressWarnings("DuplicatedCode")
|
@SuppressWarnings("DuplicatedCode")
|
||||||
public class OllamaAPI {
|
public class OllamaAPI {
|
||||||
|
|
||||||
private static final Logger logger = LoggerFactory.getLogger(OllamaAPI.class);
|
private static final Logger logger = LoggerFactory.getLogger(OllamaAPI.class);
|
||||||
private final String host;
|
private final String host;
|
||||||
private boolean verbose = false;
|
private boolean verbose = false;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Instantiates the Ollama API.
|
* Instantiates the Ollama API.
|
||||||
*
|
*
|
||||||
* @param host the host address of Ollama server
|
* @param host the host address of Ollama server
|
||||||
*/
|
*/
|
||||||
public OllamaAPI(String host) {
|
public OllamaAPI(String host) {
|
||||||
if (host.endsWith("/")) {
|
if (host.endsWith("/")) {
|
||||||
this.host = host.substring(0, host.length() - 1);
|
this.host = host.substring(0, host.length() - 1);
|
||||||
} else {
|
} else {
|
||||||
this.host = host;
|
this.host = host;
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Set/unset logging of responses
|
* Set/unset logging of responses
|
||||||
* @param verbose true/false
|
*
|
||||||
*/
|
* @param verbose true/false
|
||||||
public void setVerbose(boolean verbose) {
|
*/
|
||||||
this.verbose = verbose;
|
public void setVerbose(boolean verbose) {
|
||||||
|
this.verbose = verbose;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* List available models from Ollama server.
|
||||||
|
*
|
||||||
|
* @return the list
|
||||||
|
*/
|
||||||
|
public List<Model> listModels()
|
||||||
|
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
|
||||||
|
String url = this.host + "/api/tags";
|
||||||
|
HttpClient httpClient = HttpClient.newHttpClient();
|
||||||
|
HttpRequest httpRequest =
|
||||||
|
HttpRequest.newBuilder()
|
||||||
|
.uri(new URI(url))
|
||||||
|
.header("Accept", "application/json")
|
||||||
|
.header("Content-type", "application/json")
|
||||||
|
.GET()
|
||||||
|
.build();
|
||||||
|
HttpResponse<String> response =
|
||||||
|
httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
|
||||||
|
int statusCode = response.statusCode();
|
||||||
|
String responseString = response.body();
|
||||||
|
if (statusCode == 200) {
|
||||||
|
return Utils.getObjectMapper()
|
||||||
|
.readValue(responseString, ListModelsResponse.class)
|
||||||
|
.getModels();
|
||||||
|
} else {
|
||||||
|
throw new OllamaBaseException(statusCode + " - " + responseString);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* List available models from Ollama server.
|
* Pull a model on the Ollama server from the list of <a
|
||||||
*
|
* href="https://ollama.ai/library">available models</a>.
|
||||||
* @return the list
|
*
|
||||||
*/
|
* @param model the name of the model
|
||||||
public List<Model> listModels() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
|
*/
|
||||||
String url = this.host + "/api/tags";
|
public void pullModel(String model)
|
||||||
HttpClient httpClient = HttpClient.newHttpClient();
|
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
|
||||||
HttpRequest httpRequest = HttpRequest.newBuilder().uri(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build();
|
String url = this.host + "/api/pull";
|
||||||
HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
|
String jsonData = String.format("{\"name\": \"%s\"}", model);
|
||||||
int statusCode = response.statusCode();
|
HttpRequest request =
|
||||||
String responseString = response.body();
|
HttpRequest.newBuilder()
|
||||||
if (statusCode == 200) {
|
.uri(new URI(url))
|
||||||
return Utils.getObjectMapper().readValue(responseString, ListModelsResponse.class).getModels();
|
.POST(HttpRequest.BodyPublishers.ofString(jsonData))
|
||||||
} else {
|
.header("Accept", "application/json")
|
||||||
throw new OllamaBaseException(statusCode + " - " + responseString);
|
.header("Content-type", "application/json")
|
||||||
}
|
.build();
|
||||||
}
|
HttpClient client = HttpClient.newHttpClient();
|
||||||
|
HttpResponse<InputStream> response =
|
||||||
/**
|
client.send(request, HttpResponse.BodyHandlers.ofInputStream());
|
||||||
* Pull a model on the Ollama server from the list of <a href="https://ollama.ai/library">available models</a>.
|
int statusCode = response.statusCode();
|
||||||
*
|
InputStream responseBodyStream = response.body();
|
||||||
* @param model the name of the model
|
String responseString = "";
|
||||||
*/
|
try (BufferedReader reader =
|
||||||
public void pullModel(String model) throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
|
new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
|
||||||
String url = this.host + "/api/pull";
|
String line;
|
||||||
String jsonData = String.format("{\"name\": \"%s\"}", model);
|
while ((line = reader.readLine()) != null) {
|
||||||
HttpRequest request = HttpRequest.newBuilder().uri(new URI(url)).POST(HttpRequest.BodyPublishers.ofString(jsonData)).header("Accept", "application/json").header("Content-type", "application/json").build();
|
ModelPullResponse modelPullResponse =
|
||||||
HttpClient client = HttpClient.newHttpClient();
|
Utils.getObjectMapper().readValue(line, ModelPullResponse.class);
|
||||||
HttpResponse<InputStream> response = client.send(request, HttpResponse.BodyHandlers.ofInputStream());
|
|
||||||
int statusCode = response.statusCode();
|
|
||||||
InputStream responseBodyStream = response.body();
|
|
||||||
String responseString = "";
|
|
||||||
try (BufferedReader reader = new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
|
|
||||||
String line;
|
|
||||||
while ((line = reader.readLine()) != null) {
|
|
||||||
ModelPullResponse modelPullResponse = Utils.getObjectMapper().readValue(line, ModelPullResponse.class);
|
|
||||||
if (verbose) {
|
|
||||||
logger.info(modelPullResponse.getStatus());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (statusCode != 200) {
|
|
||||||
throw new OllamaBaseException(statusCode + " - " + responseString);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Gets model details from the Ollama server.
|
|
||||||
*
|
|
||||||
* @param modelName the model
|
|
||||||
* @return the model details
|
|
||||||
*/
|
|
||||||
public ModelDetail getModelDetails(String modelName) throws IOException, OllamaBaseException, InterruptedException {
|
|
||||||
String url = this.host + "/api/show";
|
|
||||||
String jsonData = String.format("{\"name\": \"%s\"}", modelName);
|
|
||||||
HttpRequest request = HttpRequest.newBuilder().uri(URI.create(url)).header("Accept", "application/json").header("Content-type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
|
|
||||||
HttpClient client = HttpClient.newHttpClient();
|
|
||||||
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
|
|
||||||
int statusCode = response.statusCode();
|
|
||||||
String responseBody = response.body();
|
|
||||||
if (statusCode == 200) {
|
|
||||||
return Utils.getObjectMapper().readValue(responseBody, ModelDetail.class);
|
|
||||||
} else {
|
|
||||||
throw new OllamaBaseException(statusCode + " - " + responseBody);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Create a custom model from a model file.
|
|
||||||
* Read more about custom model file creation <a href="https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md">here</a>.
|
|
||||||
*
|
|
||||||
* @param modelName the name of the custom model to be created.
|
|
||||||
* @param modelFilePath the path to model file that exists on the Ollama server.
|
|
||||||
*/
|
|
||||||
public void createModel(String modelName, String modelFilePath) throws IOException, InterruptedException, OllamaBaseException {
|
|
||||||
String url = this.host + "/api/create";
|
|
||||||
String jsonData = String.format("{\"name\": \"%s\", \"path\": \"%s\"}", modelName, modelFilePath);
|
|
||||||
HttpRequest request = HttpRequest.newBuilder().uri(URI.create(url)).header("Accept", "application/json").header("Content-Type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build();
|
|
||||||
HttpClient client = HttpClient.newHttpClient();
|
|
||||||
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
|
|
||||||
int statusCode = response.statusCode();
|
|
||||||
String responseString = response.body();
|
|
||||||
if (statusCode != 200) {
|
|
||||||
throw new OllamaBaseException(statusCode + " - " + responseString);
|
|
||||||
}
|
|
||||||
// FIXME: Ollama API returns HTTP status code 200 for model creation failure cases. Correct this if the issue is fixed in the Ollama API server.
|
|
||||||
if (responseString.contains("error")) {
|
|
||||||
throw new OllamaBaseException(responseString);
|
|
||||||
}
|
|
||||||
if (verbose) {
|
if (verbose) {
|
||||||
logger.info(responseString);
|
logger.info(modelPullResponse.getStatus());
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
if (statusCode != 200) {
|
||||||
/**
|
throw new OllamaBaseException(statusCode + " - " + responseString);
|
||||||
* Delete a model from Ollama server.
|
|
||||||
*
|
|
||||||
* @param name the name of the model to be deleted.
|
|
||||||
* @param ignoreIfNotPresent - ignore errors if the specified model is not present on Ollama server.
|
|
||||||
*/
|
|
||||||
public void deleteModel(String name, boolean ignoreIfNotPresent) throws IOException, InterruptedException, OllamaBaseException {
|
|
||||||
String url = this.host + "/api/delete";
|
|
||||||
String jsonData = String.format("{\"name\": \"%s\"}", name);
|
|
||||||
HttpRequest request = HttpRequest.newBuilder().uri(URI.create(url)).method("DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).header("Accept", "application/json").header("Content-type", "application/json").build();
|
|
||||||
HttpClient client = HttpClient.newHttpClient();
|
|
||||||
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
|
|
||||||
int statusCode = response.statusCode();
|
|
||||||
String responseBody = response.body();
|
|
||||||
if (statusCode == 404 && responseBody.contains("model") && responseBody.contains("not found")) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (statusCode != 200) {
|
|
||||||
throw new OllamaBaseException(statusCode + " - " + responseBody);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Generate embeddings for a given text from a model
|
* Gets model details from the Ollama server.
|
||||||
*
|
*
|
||||||
* @param model name of model to generate embeddings from
|
* @param modelName the model
|
||||||
* @param prompt text to generate embeddings for
|
* @return the model details
|
||||||
* @return embeddings
|
*/
|
||||||
*/
|
public ModelDetail getModelDetails(String modelName)
|
||||||
public List<Double> generateEmbeddings(String model, String prompt) throws IOException, InterruptedException, OllamaBaseException {
|
throws IOException, OllamaBaseException, InterruptedException {
|
||||||
String url = this.host + "/api/embeddings";
|
String url = this.host + "/api/show";
|
||||||
String jsonData = String.format("{\"model\": \"%s\", \"prompt\": \"%s\"}", model, prompt);
|
String jsonData = String.format("{\"name\": \"%s\"}", modelName);
|
||||||
HttpClient httpClient = HttpClient.newHttpClient();
|
HttpRequest request =
|
||||||
HttpRequest request = HttpRequest.newBuilder().uri(URI.create(url)).header("Accept", "application/json").header("Content-type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
|
HttpRequest.newBuilder()
|
||||||
HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
|
.uri(URI.create(url))
|
||||||
int statusCode = response.statusCode();
|
.header("Accept", "application/json")
|
||||||
String responseBody = response.body();
|
.header("Content-type", "application/json")
|
||||||
if (statusCode == 200) {
|
.POST(HttpRequest.BodyPublishers.ofString(jsonData))
|
||||||
EmbeddingResponse embeddingResponse = Utils.getObjectMapper().readValue(responseBody, EmbeddingResponse.class);
|
.build();
|
||||||
return embeddingResponse.getEmbedding();
|
HttpClient client = HttpClient.newHttpClient();
|
||||||
|
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
|
||||||
|
int statusCode = response.statusCode();
|
||||||
|
String responseBody = response.body();
|
||||||
|
if (statusCode == 200) {
|
||||||
|
return Utils.getObjectMapper().readValue(responseBody, ModelDetail.class);
|
||||||
|
} else {
|
||||||
|
throw new OllamaBaseException(statusCode + " - " + responseBody);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a custom model from a model file. Read more about custom model file creation <a
|
||||||
|
* href="https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md">here</a>.
|
||||||
|
*
|
||||||
|
* @param modelName the name of the custom model to be created.
|
||||||
|
* @param modelFilePath the path to model file that exists on the Ollama server.
|
||||||
|
*/
|
||||||
|
public void createModel(String modelName, String modelFilePath)
|
||||||
|
throws IOException, InterruptedException, OllamaBaseException {
|
||||||
|
String url = this.host + "/api/create";
|
||||||
|
String jsonData =
|
||||||
|
String.format("{\"name\": \"%s\", \"path\": \"%s\"}", modelName, modelFilePath);
|
||||||
|
HttpRequest request =
|
||||||
|
HttpRequest.newBuilder()
|
||||||
|
.uri(URI.create(url))
|
||||||
|
.header("Accept", "application/json")
|
||||||
|
.header("Content-Type", "application/json")
|
||||||
|
.POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8))
|
||||||
|
.build();
|
||||||
|
HttpClient client = HttpClient.newHttpClient();
|
||||||
|
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
|
||||||
|
int statusCode = response.statusCode();
|
||||||
|
String responseString = response.body();
|
||||||
|
if (statusCode != 200) {
|
||||||
|
throw new OllamaBaseException(statusCode + " - " + responseString);
|
||||||
|
}
|
||||||
|
// FIXME: Ollama API returns HTTP status code 200 for model creation failure cases. Correct this
|
||||||
|
// if the issue is fixed in the Ollama API server.
|
||||||
|
if (responseString.contains("error")) {
|
||||||
|
throw new OllamaBaseException(responseString);
|
||||||
|
}
|
||||||
|
if (verbose) {
|
||||||
|
logger.info(responseString);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Delete a model from Ollama server.
|
||||||
|
*
|
||||||
|
* @param name the name of the model to be deleted.
|
||||||
|
* @param ignoreIfNotPresent - ignore errors if the specified model is not present on Ollama
|
||||||
|
* server.
|
||||||
|
*/
|
||||||
|
public void deleteModel(String name, boolean ignoreIfNotPresent)
|
||||||
|
throws IOException, InterruptedException, OllamaBaseException {
|
||||||
|
String url = this.host + "/api/delete";
|
||||||
|
String jsonData = String.format("{\"name\": \"%s\"}", name);
|
||||||
|
HttpRequest request =
|
||||||
|
HttpRequest.newBuilder()
|
||||||
|
.uri(URI.create(url))
|
||||||
|
.method("DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8))
|
||||||
|
.header("Accept", "application/json")
|
||||||
|
.header("Content-type", "application/json")
|
||||||
|
.build();
|
||||||
|
HttpClient client = HttpClient.newHttpClient();
|
||||||
|
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
|
||||||
|
int statusCode = response.statusCode();
|
||||||
|
String responseBody = response.body();
|
||||||
|
if (statusCode == 404 && responseBody.contains("model") && responseBody.contains("not found")) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (statusCode != 200) {
|
||||||
|
throw new OllamaBaseException(statusCode + " - " + responseBody);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate embeddings for a given text from a model
|
||||||
|
*
|
||||||
|
* @param model name of model to generate embeddings from
|
||||||
|
* @param prompt text to generate embeddings for
|
||||||
|
* @return embeddings
|
||||||
|
*/
|
||||||
|
public List<Double> generateEmbeddings(String model, String prompt)
|
||||||
|
throws IOException, InterruptedException, OllamaBaseException {
|
||||||
|
String url = this.host + "/api/embeddings";
|
||||||
|
String jsonData = String.format("{\"model\": \"%s\", \"prompt\": \"%s\"}", model, prompt);
|
||||||
|
HttpClient httpClient = HttpClient.newHttpClient();
|
||||||
|
HttpRequest request =
|
||||||
|
HttpRequest.newBuilder()
|
||||||
|
.uri(URI.create(url))
|
||||||
|
.header("Accept", "application/json")
|
||||||
|
.header("Content-type", "application/json")
|
||||||
|
.POST(HttpRequest.BodyPublishers.ofString(jsonData))
|
||||||
|
.build();
|
||||||
|
HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
|
||||||
|
int statusCode = response.statusCode();
|
||||||
|
String responseBody = response.body();
|
||||||
|
if (statusCode == 200) {
|
||||||
|
EmbeddingResponse embeddingResponse =
|
||||||
|
Utils.getObjectMapper().readValue(responseBody, EmbeddingResponse.class);
|
||||||
|
return embeddingResponse.getEmbedding();
|
||||||
|
} else {
|
||||||
|
throw new OllamaBaseException(statusCode + " - " + responseBody);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Ask a question to a model running on Ollama server. This is a sync/blocking call.
|
||||||
|
*
|
||||||
|
* @param ollamaModelType the ollama model to ask the question to
|
||||||
|
* @param promptText the prompt/question text
|
||||||
|
* @return OllamaResult - that includes response text and time taken for response
|
||||||
|
*/
|
||||||
|
public OllamaResult ask(String ollamaModelType, String promptText)
|
||||||
|
throws OllamaBaseException, IOException, InterruptedException {
|
||||||
|
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(ollamaModelType, promptText);
|
||||||
|
long startTime = System.currentTimeMillis();
|
||||||
|
HttpClient httpClient = HttpClient.newHttpClient();
|
||||||
|
URI uri = URI.create(this.host + "/api/generate");
|
||||||
|
HttpRequest request =
|
||||||
|
HttpRequest.newBuilder(uri)
|
||||||
|
.POST(
|
||||||
|
HttpRequest.BodyPublishers.ofString(
|
||||||
|
Utils.getObjectMapper().writeValueAsString(ollamaRequestModel)))
|
||||||
|
.header("Content-Type", "application/json")
|
||||||
|
.build();
|
||||||
|
HttpResponse<InputStream> response =
|
||||||
|
httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
|
||||||
|
int statusCode = response.statusCode();
|
||||||
|
InputStream responseBodyStream = response.body();
|
||||||
|
StringBuilder responseBuffer = new StringBuilder();
|
||||||
|
try (BufferedReader reader =
|
||||||
|
new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
|
||||||
|
String line;
|
||||||
|
while ((line = reader.readLine()) != null) {
|
||||||
|
if (statusCode == 404) {
|
||||||
|
OllamaErrorResponseModel ollamaResponseModel =
|
||||||
|
Utils.getObjectMapper().readValue(line, OllamaErrorResponseModel.class);
|
||||||
|
responseBuffer.append(ollamaResponseModel.getError());
|
||||||
} else {
|
} else {
|
||||||
throw new OllamaBaseException(statusCode + " - " + responseBody);
|
OllamaResponseModel ollamaResponseModel =
|
||||||
|
Utils.getObjectMapper().readValue(line, OllamaResponseModel.class);
|
||||||
|
if (!ollamaResponseModel.getDone()) {
|
||||||
|
responseBuffer.append(ollamaResponseModel.getResponse());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
if (statusCode != 200) {
|
||||||
|
throw new OllamaBaseException(responseBuffer.toString());
|
||||||
|
} else {
|
||||||
|
long endTime = System.currentTimeMillis();
|
||||||
|
return new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Ask a question to a model running on Ollama server. This is a sync/blocking call.
|
* Ask a question to a model running on Ollama server and get a callback handle that can be used
|
||||||
*
|
* to check for status and get the response from the model later. This would be an
|
||||||
* @param ollamaModelType the ollama model to ask the question to
|
* async/non-blocking call.
|
||||||
* @param promptText the prompt/question text
|
*
|
||||||
* @return OllamaResult - that includes response text and time taken for response
|
* @param ollamaModelType the ollama model to ask the question to
|
||||||
*/
|
* @param promptText the prompt/question text
|
||||||
public OllamaResult ask(String ollamaModelType, String promptText) throws OllamaBaseException, IOException, InterruptedException {
|
* @return the ollama async result callback handle
|
||||||
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(ollamaModelType, promptText);
|
*/
|
||||||
long startTime = System.currentTimeMillis();
|
public OllamaAsyncResultCallback askAsync(String ollamaModelType, String promptText) {
|
||||||
HttpClient httpClient = HttpClient.newHttpClient();
|
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(ollamaModelType, promptText);
|
||||||
URI uri = URI.create(this.host + "/api/generate");
|
HttpClient httpClient = HttpClient.newHttpClient();
|
||||||
HttpRequest request = HttpRequest.newBuilder(uri).POST(HttpRequest.BodyPublishers.ofString(Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))).header("Content-Type", "application/json").build();
|
URI uri = URI.create(this.host + "/api/generate");
|
||||||
HttpResponse<InputStream> response = httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
|
OllamaAsyncResultCallback ollamaAsyncResultCallback =
|
||||||
int statusCode = response.statusCode();
|
new OllamaAsyncResultCallback(httpClient, uri, ollamaRequestModel);
|
||||||
InputStream responseBodyStream = response.body();
|
ollamaAsyncResultCallback.start();
|
||||||
StringBuilder responseBuffer = new StringBuilder();
|
return ollamaAsyncResultCallback;
|
||||||
try (BufferedReader reader = new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
|
}
|
||||||
String line;
|
|
||||||
while ((line = reader.readLine()) != null) {
|
|
||||||
OllamaResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaResponseModel.class);
|
|
||||||
if (!ollamaResponseModel.getDone()) {
|
|
||||||
responseBuffer.append(ollamaResponseModel.getResponse());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (statusCode != 200) {
|
|
||||||
throw new OllamaBaseException(statusCode + " - " + responseBuffer);
|
|
||||||
} else {
|
|
||||||
long endTime = System.currentTimeMillis();
|
|
||||||
return new OllamaResult(responseBuffer.toString().trim(), endTime - startTime);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Ask a question to a model running on Ollama server and get a callback handle that can be used to check for status and get the response from the model later.
|
|
||||||
* This would be an async/non-blocking call.
|
|
||||||
*
|
|
||||||
* @param ollamaModelType the ollama model to ask the question to
|
|
||||||
* @param promptText the prompt/question text
|
|
||||||
* @return the ollama async result callback handle
|
|
||||||
*/
|
|
||||||
public OllamaAsyncResultCallback askAsync(String ollamaModelType, String promptText) {
|
|
||||||
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(ollamaModelType, promptText);
|
|
||||||
HttpClient httpClient = HttpClient.newHttpClient();
|
|
||||||
URI uri = URI.create(this.host + "/api/generate");
|
|
||||||
OllamaAsyncResultCallback ollamaAsyncResultCallback = new OllamaAsyncResultCallback(httpClient, uri, ollamaRequestModel);
|
|
||||||
ollamaAsyncResultCallback.start();
|
|
||||||
return ollamaAsyncResultCallback;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
@ -2,7 +2,6 @@ package io.github.amithkoujalgi.ollama4j.core.models;
|
|||||||
|
|
||||||
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
|
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
|
||||||
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
|
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
|
||||||
|
|
||||||
import java.io.BufferedReader;
|
import java.io.BufferedReader;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.io.InputStream;
|
import java.io.InputStream;
|
||||||
@ -17,79 +16,129 @@ import java.util.Queue;
|
|||||||
|
|
||||||
@SuppressWarnings("unused")
|
@SuppressWarnings("unused")
|
||||||
public class OllamaAsyncResultCallback extends Thread {
|
public class OllamaAsyncResultCallback extends Thread {
|
||||||
private final HttpClient client;
|
private final HttpClient client;
|
||||||
private final URI uri;
|
private final URI uri;
|
||||||
private final OllamaRequestModel ollamaRequestModel;
|
private final OllamaRequestModel ollamaRequestModel;
|
||||||
private final Queue<String> queue = new LinkedList<>();
|
private final Queue<String> queue = new LinkedList<>();
|
||||||
private String result;
|
private String result;
|
||||||
private boolean isDone;
|
private boolean isDone;
|
||||||
private long responseTime = 0;
|
private boolean succeeded;
|
||||||
|
|
||||||
public OllamaAsyncResultCallback(HttpClient client, URI uri, OllamaRequestModel ollamaRequestModel) {
|
private int httpStatusCode;
|
||||||
this.client = client;
|
private long responseTime = 0;
|
||||||
this.ollamaRequestModel = ollamaRequestModel;
|
|
||||||
this.uri = uri;
|
|
||||||
this.isDone = false;
|
|
||||||
this.result = "";
|
|
||||||
this.queue.add("");
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
public OllamaAsyncResultCallback(
|
||||||
public void run() {
|
HttpClient client, URI uri, OllamaRequestModel ollamaRequestModel) {
|
||||||
try {
|
this.client = client;
|
||||||
long startTime = System.currentTimeMillis();
|
this.ollamaRequestModel = ollamaRequestModel;
|
||||||
HttpRequest request = HttpRequest.newBuilder(uri).POST(HttpRequest.BodyPublishers.ofString(Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))).header("Content-Type", "application/json").build();
|
this.uri = uri;
|
||||||
HttpResponse<InputStream> response = client.send(request, HttpResponse.BodyHandlers.ofInputStream());
|
this.isDone = false;
|
||||||
int statusCode = response.statusCode();
|
this.result = "";
|
||||||
|
this.queue.add("");
|
||||||
|
}
|
||||||
|
|
||||||
InputStream responseBodyStream = response.body();
|
@Override
|
||||||
String responseString = "";
|
public void run() {
|
||||||
try (BufferedReader reader = new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
|
try {
|
||||||
String line;
|
long startTime = System.currentTimeMillis();
|
||||||
StringBuilder responseBuffer = new StringBuilder();
|
HttpRequest request =
|
||||||
while ((line = reader.readLine()) != null) {
|
HttpRequest.newBuilder(uri)
|
||||||
OllamaResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaResponseModel.class);
|
.POST(
|
||||||
queue.add(ollamaResponseModel.getResponse());
|
HttpRequest.BodyPublishers.ofString(
|
||||||
if (!ollamaResponseModel.getDone()) {
|
Utils.getObjectMapper().writeValueAsString(ollamaRequestModel)))
|
||||||
responseBuffer.append(ollamaResponseModel.getResponse());
|
.header("Content-Type", "application/json")
|
||||||
}
|
.build();
|
||||||
}
|
HttpResponse<InputStream> response =
|
||||||
reader.close();
|
client.send(request, HttpResponse.BodyHandlers.ofInputStream());
|
||||||
this.isDone = true;
|
int statusCode = response.statusCode();
|
||||||
this.result = responseBuffer.toString();
|
this.httpStatusCode = statusCode;
|
||||||
long endTime = System.currentTimeMillis();
|
|
||||||
responseTime = endTime - startTime;
|
InputStream responseBodyStream = response.body();
|
||||||
|
try (BufferedReader reader =
|
||||||
|
new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
|
||||||
|
String line;
|
||||||
|
StringBuilder responseBuffer = new StringBuilder();
|
||||||
|
while ((line = reader.readLine()) != null) {
|
||||||
|
if (statusCode == 404) {
|
||||||
|
OllamaErrorResponseModel ollamaResponseModel =
|
||||||
|
Utils.getObjectMapper().readValue(line, OllamaErrorResponseModel.class);
|
||||||
|
queue.add(ollamaResponseModel.getError());
|
||||||
|
responseBuffer.append(ollamaResponseModel.getError());
|
||||||
|
} else {
|
||||||
|
OllamaResponseModel ollamaResponseModel =
|
||||||
|
Utils.getObjectMapper().readValue(line, OllamaResponseModel.class);
|
||||||
|
queue.add(ollamaResponseModel.getResponse());
|
||||||
|
if (!ollamaResponseModel.getDone()) {
|
||||||
|
responseBuffer.append(ollamaResponseModel.getResponse());
|
||||||
}
|
}
|
||||||
if (statusCode != 200) {
|
}
|
||||||
throw new OllamaBaseException(statusCode + " - " + responseString);
|
|
||||||
}
|
|
||||||
} catch (IOException | InterruptedException | OllamaBaseException e) {
|
|
||||||
this.isDone = true;
|
|
||||||
this.result = "FAILED! " + e.getMessage();
|
|
||||||
}
|
}
|
||||||
}
|
reader.close();
|
||||||
|
|
||||||
public boolean isComplete() {
|
this.isDone = true;
|
||||||
return isDone;
|
this.succeeded = true;
|
||||||
|
this.result = responseBuffer.toString();
|
||||||
|
long endTime = System.currentTimeMillis();
|
||||||
|
responseTime = endTime - startTime;
|
||||||
|
}
|
||||||
|
if (statusCode != 200) {
|
||||||
|
throw new OllamaBaseException(this.result);
|
||||||
|
}
|
||||||
|
} catch (IOException | InterruptedException | OllamaBaseException e) {
|
||||||
|
this.isDone = true;
|
||||||
|
this.succeeded = false;
|
||||||
|
this.result = "[FAILED] " + e.getMessage();
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the final response when the execution completes. Does not return intermediate results.
|
* Returns the status of the thread. This does not indicate that the request was successful or a
|
||||||
* @return response text
|
* failure, rather it is just a status flag to indicate if the thread is active or ended.
|
||||||
*/
|
*
|
||||||
public String getResponse() {
|
* @return boolean - status
|
||||||
return result;
|
*/
|
||||||
}
|
public boolean isComplete() {
|
||||||
|
return isDone;
|
||||||
|
}
|
||||||
|
|
||||||
public Queue<String> getStream() {
|
/**
|
||||||
return queue;
|
* Returns the HTTP response status code for the request that was made to Ollama server.
|
||||||
}
|
*
|
||||||
|
* @return int - the status code for the request
|
||||||
|
*/
|
||||||
|
public int getHttpStatusCode() {
|
||||||
|
return httpStatusCode;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the response time in seconds.
|
* Returns the status of the request. Indicates if the request was successful or a failure. If the
|
||||||
* @return response time in seconds
|
* request was a failure, the `getResponse()` method will return the error message.
|
||||||
*/
|
*
|
||||||
public long getResponseTime() {
|
* @return boolean - status
|
||||||
return responseTime;
|
*/
|
||||||
}
|
public boolean isSucceeded() {
|
||||||
|
return succeeded;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the final response when the execution completes. Does not return intermediate results.
|
||||||
|
*
|
||||||
|
* @return String - response text
|
||||||
|
*/
|
||||||
|
public String getResponse() {
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Queue<String> getStream() {
|
||||||
|
return queue;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the response time in milliseconds.
|
||||||
|
*
|
||||||
|
* @return long - response time in milliseconds.
|
||||||
|
*/
|
||||||
|
public long getResponseTime() {
|
||||||
|
return responseTime;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,18 @@
|
|||||||
|
package io.github.amithkoujalgi.ollama4j.core.models;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
@JsonIgnoreProperties(ignoreUnknown = true)
|
||||||
|
public class OllamaErrorResponseModel {
|
||||||
|
private String error;
|
||||||
|
|
||||||
|
public String getError() {
|
||||||
|
return error;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setError(String error) {
|
||||||
|
this.error = error;
|
||||||
|
}
|
||||||
|
}
|
@ -1,20 +1,57 @@
|
|||||||
package io.github.amithkoujalgi.ollama4j.core.models;
|
package io.github.amithkoujalgi.ollama4j.core.models;
|
||||||
|
|
||||||
|
import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
|
|
||||||
|
/** The type Ollama result. */
|
||||||
@SuppressWarnings("unused")
|
@SuppressWarnings("unused")
|
||||||
public class OllamaResult {
|
public class OllamaResult {
|
||||||
private String response;
|
private final String response;
|
||||||
private long responseTime = 0;
|
|
||||||
|
|
||||||
public OllamaResult(String response, long responseTime) {
|
private int httpStatusCode;
|
||||||
this.response = response;
|
|
||||||
this.responseTime = responseTime;
|
|
||||||
}
|
|
||||||
|
|
||||||
public String getResponse() {
|
private long responseTime = 0;
|
||||||
return response;
|
|
||||||
}
|
|
||||||
|
|
||||||
public long getResponseTime() {
|
public OllamaResult(String response, long responseTime, int httpStatusCode) {
|
||||||
return responseTime;
|
this.response = response;
|
||||||
|
this.responseTime = responseTime;
|
||||||
|
this.httpStatusCode = httpStatusCode;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the response text
|
||||||
|
*
|
||||||
|
* @return String - response text
|
||||||
|
*/
|
||||||
|
public String getResponse() {
|
||||||
|
return response;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the response time in milliseconds.
|
||||||
|
*
|
||||||
|
* @return long - response time in milliseconds
|
||||||
|
*/
|
||||||
|
public long getResponseTime() {
|
||||||
|
return responseTime;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the response status code.
|
||||||
|
*
|
||||||
|
* @return int - response status code
|
||||||
|
*/
|
||||||
|
public int getHttpStatusCode() {
|
||||||
|
return httpStatusCode;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
try {
|
||||||
|
return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
|
||||||
|
} catch (JsonProcessingException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,121 +1,121 @@
|
|||||||
package io.github.amithkoujalgi.ollama4j.unittests;
|
package io.github.amithkoujalgi.ollama4j.unittests;
|
||||||
|
|
||||||
|
import static org.mockito.Mockito.*;
|
||||||
|
|
||||||
import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
|
import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
|
||||||
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
|
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
|
||||||
import io.github.amithkoujalgi.ollama4j.core.models.ModelDetail;
|
import io.github.amithkoujalgi.ollama4j.core.models.ModelDetail;
|
||||||
import io.github.amithkoujalgi.ollama4j.core.models.OllamaAsyncResultCallback;
|
import io.github.amithkoujalgi.ollama4j.core.models.OllamaAsyncResultCallback;
|
||||||
import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
|
import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
|
||||||
import io.github.amithkoujalgi.ollama4j.core.types.OllamaModelType;
|
import io.github.amithkoujalgi.ollama4j.core.types.OllamaModelType;
|
||||||
import org.junit.jupiter.api.Test;
|
|
||||||
import org.mockito.Mockito;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.net.URISyntaxException;
|
import java.net.URISyntaxException;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
import static org.mockito.Mockito.*;
|
import org.mockito.Mockito;
|
||||||
|
|
||||||
public class TestMockedAPIs {
|
public class TestMockedAPIs {
|
||||||
@Test
|
@Test
|
||||||
public void testMockPullModel() {
|
public void testMockPullModel() {
|
||||||
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
|
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
|
||||||
String model = OllamaModelType.LLAMA2;
|
String model = OllamaModelType.LLAMA2;
|
||||||
try {
|
try {
|
||||||
doNothing().when(ollamaAPI).pullModel(model);
|
doNothing().when(ollamaAPI).pullModel(model);
|
||||||
ollamaAPI.pullModel(model);
|
ollamaAPI.pullModel(model);
|
||||||
verify(ollamaAPI, times(1)).pullModel(model);
|
verify(ollamaAPI, times(1)).pullModel(model);
|
||||||
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
|
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testListModels() {
|
public void testListModels() {
|
||||||
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
|
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
|
||||||
try {
|
try {
|
||||||
when(ollamaAPI.listModels()).thenReturn(new ArrayList<>());
|
when(ollamaAPI.listModels()).thenReturn(new ArrayList<>());
|
||||||
ollamaAPI.listModels();
|
ollamaAPI.listModels();
|
||||||
verify(ollamaAPI, times(1)).listModels();
|
verify(ollamaAPI, times(1)).listModels();
|
||||||
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
|
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCreateModel() {
|
public void testCreateModel() {
|
||||||
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
|
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
|
||||||
String model = OllamaModelType.LLAMA2;
|
String model = OllamaModelType.LLAMA2;
|
||||||
String modelFilePath = "/somemodel";
|
String modelFilePath = "/somemodel";
|
||||||
try {
|
try {
|
||||||
doNothing().when(ollamaAPI).createModel(model, modelFilePath);
|
doNothing().when(ollamaAPI).createModel(model, modelFilePath);
|
||||||
ollamaAPI.createModel(model, modelFilePath);
|
ollamaAPI.createModel(model, modelFilePath);
|
||||||
verify(ollamaAPI, times(1)).createModel(model, modelFilePath);
|
verify(ollamaAPI, times(1)).createModel(model, modelFilePath);
|
||||||
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testDeleteModel() {
|
public void testDeleteModel() {
|
||||||
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
|
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
|
||||||
String model = OllamaModelType.LLAMA2;
|
String model = OllamaModelType.LLAMA2;
|
||||||
try {
|
try {
|
||||||
doNothing().when(ollamaAPI).deleteModel(model, true);
|
doNothing().when(ollamaAPI).deleteModel(model, true);
|
||||||
ollamaAPI.deleteModel(model, true);
|
ollamaAPI.deleteModel(model, true);
|
||||||
verify(ollamaAPI, times(1)).deleteModel(model, true);
|
verify(ollamaAPI, times(1)).deleteModel(model, true);
|
||||||
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testGetModelDetails() {
|
public void testGetModelDetails() {
|
||||||
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
|
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
|
||||||
String model = OllamaModelType.LLAMA2;
|
String model = OllamaModelType.LLAMA2;
|
||||||
try {
|
try {
|
||||||
when(ollamaAPI.getModelDetails(model)).thenReturn(new ModelDetail());
|
when(ollamaAPI.getModelDetails(model)).thenReturn(new ModelDetail());
|
||||||
ollamaAPI.getModelDetails(model);
|
ollamaAPI.getModelDetails(model);
|
||||||
verify(ollamaAPI, times(1)).getModelDetails(model);
|
verify(ollamaAPI, times(1)).getModelDetails(model);
|
||||||
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testGenerateEmbeddings() {
|
public void testGenerateEmbeddings() {
|
||||||
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
|
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
|
||||||
String model = OllamaModelType.LLAMA2;
|
String model = OllamaModelType.LLAMA2;
|
||||||
String prompt = "some prompt text";
|
String prompt = "some prompt text";
|
||||||
try {
|
try {
|
||||||
when(ollamaAPI.generateEmbeddings(model, prompt)).thenReturn(new ArrayList<>());
|
when(ollamaAPI.generateEmbeddings(model, prompt)).thenReturn(new ArrayList<>());
|
||||||
ollamaAPI.generateEmbeddings(model, prompt);
|
ollamaAPI.generateEmbeddings(model, prompt);
|
||||||
verify(ollamaAPI, times(1)).generateEmbeddings(model, prompt);
|
verify(ollamaAPI, times(1)).generateEmbeddings(model, prompt);
|
||||||
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testAsk() {
|
public void testAsk() {
|
||||||
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
|
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
|
||||||
String model = OllamaModelType.LLAMA2;
|
String model = OllamaModelType.LLAMA2;
|
||||||
String prompt = "some prompt text";
|
String prompt = "some prompt text";
|
||||||
try {
|
try {
|
||||||
when(ollamaAPI.ask(model, prompt)).thenReturn(new OllamaResult("", 0));
|
when(ollamaAPI.ask(model, prompt)).thenReturn(new OllamaResult("", 0, 200));
|
||||||
ollamaAPI.ask(model, prompt);
|
ollamaAPI.ask(model, prompt);
|
||||||
verify(ollamaAPI, times(1)).ask(model, prompt);
|
verify(ollamaAPI, times(1)).ask(model, prompt);
|
||||||
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testAskAsync() {
|
public void testAskAsync() {
|
||||||
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
|
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
|
||||||
String model = OllamaModelType.LLAMA2;
|
String model = OllamaModelType.LLAMA2;
|
||||||
String prompt = "some prompt text";
|
String prompt = "some prompt text";
|
||||||
when(ollamaAPI.askAsync(model, prompt)).thenReturn(new OllamaAsyncResultCallback(null, null, null));
|
when(ollamaAPI.askAsync(model, prompt))
|
||||||
ollamaAPI.askAsync(model, prompt);
|
.thenReturn(new OllamaAsyncResultCallback(null, null, null));
|
||||||
verify(ollamaAPI, times(1)).askAsync(model, prompt);
|
ollamaAPI.askAsync(model, prompt);
|
||||||
}
|
verify(ollamaAPI, times(1)).askAsync(model, prompt);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user