Fixes to the ask API

This commit is contained in:
Amith Koujalgi 2023-12-14 16:45:12 +05:30
parent f67f3b9eb5
commit 4e4a5d2996
2 changed files with 289 additions and 208 deletions

View File

@ -3,9 +3,6 @@ package io.github.amithkoujalgi.ollama4j.core;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.*; import io.github.amithkoujalgi.ollama4j.core.models.*;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils; import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.BufferedReader; import java.io.BufferedReader;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
@ -17,10 +14,10 @@ import java.net.http.HttpRequest;
import java.net.http.HttpResponse; import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.util.List; import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/** /** The base Ollama API class. */
* The base Ollama API class.
*/
@SuppressWarnings("DuplicatedCode") @SuppressWarnings("DuplicatedCode")
public class OllamaAPI { public class OllamaAPI {
@ -43,6 +40,7 @@ public class OllamaAPI {
/** /**
* Set/unset logging of responses * Set/unset logging of responses
*
* @param verbose true/false * @param verbose true/false
*/ */
public void setVerbose(boolean verbose) { public void setVerbose(boolean verbose) {
@ -54,38 +52,59 @@ public class OllamaAPI {
* *
* @return the list * @return the list
*/ */
public List<Model> listModels() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { public List<Model> listModels()
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
String url = this.host + "/api/tags"; String url = this.host + "/api/tags";
HttpClient httpClient = HttpClient.newHttpClient(); HttpClient httpClient = HttpClient.newHttpClient();
HttpRequest httpRequest = HttpRequest.newBuilder().uri(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build(); HttpRequest httpRequest =
HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString()); HttpRequest.newBuilder()
.uri(new URI(url))
.header("Accept", "application/json")
.header("Content-type", "application/json")
.GET()
.build();
HttpResponse<String> response =
httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode(); int statusCode = response.statusCode();
String responseString = response.body(); String responseString = response.body();
if (statusCode == 200) { if (statusCode == 200) {
return Utils.getObjectMapper().readValue(responseString, ListModelsResponse.class).getModels(); return Utils.getObjectMapper()
.readValue(responseString, ListModelsResponse.class)
.getModels();
} else { } else {
throw new OllamaBaseException(statusCode + " - " + responseString); throw new OllamaBaseException(statusCode + " - " + responseString);
} }
} }
/** /**
* Pull a model on the Ollama server from the list of <a href="https://ollama.ai/library">available models</a>. * Pull a model on the Ollama server from the list of <a
* href="https://ollama.ai/library">available models</a>.
* *
* @param model the name of the model * @param model the name of the model
*/ */
public void pullModel(String model) throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { public void pullModel(String model)
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
String url = this.host + "/api/pull"; String url = this.host + "/api/pull";
String jsonData = String.format("{\"name\": \"%s\"}", model); String jsonData = String.format("{\"name\": \"%s\"}", model);
HttpRequest request = HttpRequest.newBuilder().uri(new URI(url)).POST(HttpRequest.BodyPublishers.ofString(jsonData)).header("Accept", "application/json").header("Content-type", "application/json").build(); HttpRequest request =
HttpRequest.newBuilder()
.uri(new URI(url))
.POST(HttpRequest.BodyPublishers.ofString(jsonData))
.header("Accept", "application/json")
.header("Content-type", "application/json")
.build();
HttpClient client = HttpClient.newHttpClient(); HttpClient client = HttpClient.newHttpClient();
HttpResponse<InputStream> response = client.send(request, HttpResponse.BodyHandlers.ofInputStream()); HttpResponse<InputStream> response =
client.send(request, HttpResponse.BodyHandlers.ofInputStream());
int statusCode = response.statusCode(); int statusCode = response.statusCode();
InputStream responseBodyStream = response.body(); InputStream responseBodyStream = response.body();
String responseString = ""; String responseString = "";
try (BufferedReader reader = new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { try (BufferedReader reader =
new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
String line; String line;
while ((line = reader.readLine()) != null) { while ((line = reader.readLine()) != null) {
ModelPullResponse modelPullResponse = Utils.getObjectMapper().readValue(line, ModelPullResponse.class); ModelPullResponse modelPullResponse =
Utils.getObjectMapper().readValue(line, ModelPullResponse.class);
if (verbose) { if (verbose) {
logger.info(modelPullResponse.getStatus()); logger.info(modelPullResponse.getStatus());
} }
@ -102,10 +121,17 @@ public class OllamaAPI {
* @param modelName the model * @param modelName the model
* @return the model details * @return the model details
*/ */
public ModelDetail getModelDetails(String modelName) throws IOException, OllamaBaseException, InterruptedException { public ModelDetail getModelDetails(String modelName)
throws IOException, OllamaBaseException, InterruptedException {
String url = this.host + "/api/show"; String url = this.host + "/api/show";
String jsonData = String.format("{\"name\": \"%s\"}", modelName); String jsonData = String.format("{\"name\": \"%s\"}", modelName);
HttpRequest request = HttpRequest.newBuilder().uri(URI.create(url)).header("Accept", "application/json").header("Content-type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build(); HttpRequest request =
HttpRequest.newBuilder()
.uri(URI.create(url))
.header("Accept", "application/json")
.header("Content-type", "application/json")
.POST(HttpRequest.BodyPublishers.ofString(jsonData))
.build();
HttpClient client = HttpClient.newHttpClient(); HttpClient client = HttpClient.newHttpClient();
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString()); HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode(); int statusCode = response.statusCode();
@ -118,16 +144,24 @@ public class OllamaAPI {
} }
/** /**
* Create a custom model from a model file. * Create a custom model from a model file. Read more about custom model file creation <a
* Read more about custom model file creation <a href="https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md">here</a>. * href="https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md">here</a>.
* *
* @param modelName the name of the custom model to be created. * @param modelName the name of the custom model to be created.
* @param modelFilePath the path to model file that exists on the Ollama server. * @param modelFilePath the path to model file that exists on the Ollama server.
*/ */
public void createModel(String modelName, String modelFilePath) throws IOException, InterruptedException, OllamaBaseException { public void createModel(String modelName, String modelFilePath)
throws IOException, InterruptedException, OllamaBaseException {
String url = this.host + "/api/create"; String url = this.host + "/api/create";
String jsonData = String.format("{\"name\": \"%s\", \"path\": \"%s\"}", modelName, modelFilePath); String jsonData =
HttpRequest request = HttpRequest.newBuilder().uri(URI.create(url)).header("Accept", "application/json").header("Content-Type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build(); String.format("{\"name\": \"%s\", \"path\": \"%s\"}", modelName, modelFilePath);
HttpRequest request =
HttpRequest.newBuilder()
.uri(URI.create(url))
.header("Accept", "application/json")
.header("Content-Type", "application/json")
.POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8))
.build();
HttpClient client = HttpClient.newHttpClient(); HttpClient client = HttpClient.newHttpClient();
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString()); HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode(); int statusCode = response.statusCode();
@ -135,7 +169,8 @@ public class OllamaAPI {
if (statusCode != 200) { if (statusCode != 200) {
throw new OllamaBaseException(statusCode + " - " + responseString); throw new OllamaBaseException(statusCode + " - " + responseString);
} }
// FIXME: Ollama API returns HTTP status code 200 for model creation failure cases. Correct this if the issue is fixed in the Ollama API server. // FIXME: Ollama API returns HTTP status code 200 for model creation failure cases. Correct this
// if the issue is fixed in the Ollama API server.
if (responseString.contains("error")) { if (responseString.contains("error")) {
throw new OllamaBaseException(responseString); throw new OllamaBaseException(responseString);
} }
@ -148,12 +183,20 @@ public class OllamaAPI {
* Delete a model from Ollama server. * Delete a model from Ollama server.
* *
* @param name the name of the model to be deleted. * @param name the name of the model to be deleted.
* @param ignoreIfNotPresent - ignore errors if the specified model is not present on Ollama server. * @param ignoreIfNotPresent - ignore errors if the specified model is not present on Ollama
* server.
*/ */
public void deleteModel(String name, boolean ignoreIfNotPresent) throws IOException, InterruptedException, OllamaBaseException { public void deleteModel(String name, boolean ignoreIfNotPresent)
throws IOException, InterruptedException, OllamaBaseException {
String url = this.host + "/api/delete"; String url = this.host + "/api/delete";
String jsonData = String.format("{\"name\": \"%s\"}", name); String jsonData = String.format("{\"name\": \"%s\"}", name);
HttpRequest request = HttpRequest.newBuilder().uri(URI.create(url)).method("DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).header("Accept", "application/json").header("Content-type", "application/json").build(); HttpRequest request =
HttpRequest.newBuilder()
.uri(URI.create(url))
.method("DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8))
.header("Accept", "application/json")
.header("Content-type", "application/json")
.build();
HttpClient client = HttpClient.newHttpClient(); HttpClient client = HttpClient.newHttpClient();
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString()); HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode(); int statusCode = response.statusCode();
@ -173,16 +216,24 @@ public class OllamaAPI {
* @param prompt text to generate embeddings for * @param prompt text to generate embeddings for
* @return embeddings * @return embeddings
*/ */
public List<Double> generateEmbeddings(String model, String prompt) throws IOException, InterruptedException, OllamaBaseException { public List<Double> generateEmbeddings(String model, String prompt)
throws IOException, InterruptedException, OllamaBaseException {
String url = this.host + "/api/embeddings"; String url = this.host + "/api/embeddings";
String jsonData = String.format("{\"model\": \"%s\", \"prompt\": \"%s\"}", model, prompt); String jsonData = String.format("{\"model\": \"%s\", \"prompt\": \"%s\"}", model, prompt);
HttpClient httpClient = HttpClient.newHttpClient(); HttpClient httpClient = HttpClient.newHttpClient();
HttpRequest request = HttpRequest.newBuilder().uri(URI.create(url)).header("Accept", "application/json").header("Content-type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build(); HttpRequest request =
HttpRequest.newBuilder()
.uri(URI.create(url))
.header("Accept", "application/json")
.header("Content-type", "application/json")
.POST(HttpRequest.BodyPublishers.ofString(jsonData))
.build();
HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode(); int statusCode = response.statusCode();
String responseBody = response.body(); String responseBody = response.body();
if (statusCode == 200) { if (statusCode == 200) {
EmbeddingResponse embeddingResponse = Utils.getObjectMapper().readValue(responseBody, EmbeddingResponse.class); EmbeddingResponse embeddingResponse =
Utils.getObjectMapper().readValue(responseBody, EmbeddingResponse.class);
return embeddingResponse.getEmbedding(); return embeddingResponse.getEmbedding();
} else { } else {
throw new OllamaBaseException(statusCode + " - " + responseBody); throw new OllamaBaseException(statusCode + " - " + responseBody);
@ -196,36 +247,53 @@ public class OllamaAPI {
* @param promptText the prompt/question text * @param promptText the prompt/question text
* @return OllamaResult - that includes response text and time taken for response * @return OllamaResult - that includes response text and time taken for response
*/ */
public OllamaResult ask(String ollamaModelType, String promptText) throws OllamaBaseException, IOException, InterruptedException { public OllamaResult ask(String ollamaModelType, String promptText)
throws OllamaBaseException, IOException, InterruptedException {
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(ollamaModelType, promptText); OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(ollamaModelType, promptText);
long startTime = System.currentTimeMillis(); long startTime = System.currentTimeMillis();
HttpClient httpClient = HttpClient.newHttpClient(); HttpClient httpClient = HttpClient.newHttpClient();
URI uri = URI.create(this.host + "/api/generate"); URI uri = URI.create(this.host + "/api/generate");
HttpRequest request = HttpRequest.newBuilder(uri).POST(HttpRequest.BodyPublishers.ofString(Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))).header("Content-Type", "application/json").build(); HttpRequest request =
HttpResponse<InputStream> response = httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream()); HttpRequest.newBuilder(uri)
.POST(
HttpRequest.BodyPublishers.ofString(
Utils.getObjectMapper().writeValueAsString(ollamaRequestModel)))
.header("Content-Type", "application/json")
.build();
HttpResponse<InputStream> response =
httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
int statusCode = response.statusCode(); int statusCode = response.statusCode();
InputStream responseBodyStream = response.body(); InputStream responseBodyStream = response.body();
StringBuilder responseBuffer = new StringBuilder(); StringBuilder responseBuffer = new StringBuilder();
try (BufferedReader reader = new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { try (BufferedReader reader =
new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
String line; String line;
while ((line = reader.readLine()) != null) { while ((line = reader.readLine()) != null) {
OllamaResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaResponseModel.class); if (statusCode == 404) {
OllamaErrorResponseModel ollamaResponseModel =
Utils.getObjectMapper().readValue(line, OllamaErrorResponseModel.class);
responseBuffer.append(ollamaResponseModel.getError());
} else {
OllamaResponseModel ollamaResponseModel =
Utils.getObjectMapper().readValue(line, OllamaResponseModel.class);
if (!ollamaResponseModel.getDone()) { if (!ollamaResponseModel.getDone()) {
responseBuffer.append(ollamaResponseModel.getResponse()); responseBuffer.append(ollamaResponseModel.getResponse());
} }
} }
} }
}
if (statusCode != 200) { if (statusCode != 200) {
throw new OllamaBaseException(statusCode + " - " + responseBuffer); throw new OllamaBaseException(responseBuffer.toString());
} else { } else {
long endTime = System.currentTimeMillis(); long endTime = System.currentTimeMillis();
return new OllamaResult(responseBuffer.toString().trim(), endTime - startTime); return new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode);
} }
} }
/** /**
* Ask a question to a model running on Ollama server and get a callback handle that can be used to check for status and get the response from the model later. * Ask a question to a model running on Ollama server and get a callback handle that can be used
* This would be an async/non-blocking call. * to check for status and get the response from the model later. This would be an
* async/non-blocking call.
* *
* @param ollamaModelType the ollama model to ask the question to * @param ollamaModelType the ollama model to ask the question to
* @param promptText the prompt/question text * @param promptText the prompt/question text
@ -235,7 +303,8 @@ public class OllamaAPI {
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(ollamaModelType, promptText); OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(ollamaModelType, promptText);
HttpClient httpClient = HttpClient.newHttpClient(); HttpClient httpClient = HttpClient.newHttpClient();
URI uri = URI.create(this.host + "/api/generate"); URI uri = URI.create(this.host + "/api/generate");
OllamaAsyncResultCallback ollamaAsyncResultCallback = new OllamaAsyncResultCallback(httpClient, uri, ollamaRequestModel); OllamaAsyncResultCallback ollamaAsyncResultCallback =
new OllamaAsyncResultCallback(httpClient, uri, ollamaRequestModel);
ollamaAsyncResultCallback.start(); ollamaAsyncResultCallback.start();
return ollamaAsyncResultCallback; return ollamaAsyncResultCallback;
} }

View File

@ -9,11 +9,14 @@ import com.fasterxml.jackson.core.JsonProcessingException;
public class OllamaResult { public class OllamaResult {
private final String response; private final String response;
private int httpStatusCode;
private long responseTime = 0; private long responseTime = 0;
public OllamaResult(String response, long responseTime) { public OllamaResult(String response, long responseTime, int httpStatusCode) {
this.response = response; this.response = response;
this.responseTime = responseTime; this.responseTime = responseTime;
this.httpStatusCode = httpStatusCode;
} }
/** /**
@ -34,6 +37,15 @@ public class OllamaResult {
return responseTime; return responseTime;
} }
/**
* Get the response status code.
*
* @return int - response status code
*/
public int getHttpStatusCode() {
return httpStatusCode;
}
@Override @Override
public String toString() { public String toString() {
try { try {