diff --git a/README.md b/README.md index 1c0fe60..1187962 100644 --- a/README.md +++ b/README.md @@ -60,7 +60,7 @@ In your Maven project, add this dependency: <dependency> <groupId>io.github.amithkoujalgi</groupId> <artifactId>ollama4j</artifactId> - <version>1.0.20</version> + <version>1.0.29</version> </dependency> ``` diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java index 11906b3..1bf4b9c 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java @@ -37,6 +37,8 @@ public class OllamaAPI { private final String host; private long requestTimeoutSeconds = 3; private boolean verbose = true; + private String username; + private String password; /** * Instantiates the Ollama API. @@ -64,6 +66,14 @@ public class OllamaAPI { this.verbose = verbose; } + /** + * + */ + public void setBasicAuth(String username, String password) { + this.username = username; + this.password = password; + } + /** * API to check the reachability of Ollama server. * @@ -306,17 +316,13 @@ public class OllamaAPI { */ public List<Double> generateEmbeddings(String model, String prompt) throws IOException, InterruptedException, OllamaBaseException { - String url = this.host + "/api/embeddings"; + URI uri = URI.create(this.host + "/api/embeddings"); String jsonData = new ModelEmbeddingsRequest(model, prompt).toString(); HttpClient httpClient = HttpClient.newHttpClient(); - HttpRequest request = - HttpRequest.newBuilder() - .uri(URI.create(url)) + HttpRequest.Builder requestBuilder = getRequestBuilderDefault(uri) .header("Accept", "application/json") - .header("Content-type", "application/json") - .timeout(Duration.ofSeconds(requestTimeoutSeconds)) - .POST(HttpRequest.BodyPublishers.ofString(jsonData)) - .build(); + .POST(HttpRequest.BodyPublishers.ofString(jsonData)); + HttpRequest request = requestBuilder.build(); HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); String responseBody = response.body(); @@ -426,14 +432,12 @@ public class OllamaAPI { long startTime = System.currentTimeMillis(); HttpClient httpClient = HttpClient.newHttpClient(); URI uri = URI.create(this.host + "/api/generate"); - HttpRequest request = - HttpRequest.newBuilder(uri) + HttpRequest.Builder requestBuilder = getRequestBuilderDefault(uri) .POST( HttpRequest.BodyPublishers.ofString( - Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))) - .header("Content-Type", "application/json") - .timeout(Duration.ofSeconds(requestTimeoutSeconds)) - .build(); + Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))); + HttpRequest request = requestBuilder.build(); + logger.debug("Ask model '" + ollamaRequestModel + "' ..."); HttpResponse<InputStream> response = httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream()); int statusCode = response.statusCode(); @@ -444,10 +448,16 @@ public class OllamaAPI { String line; while ((line = reader.readLine()) != null) { if (statusCode == 404) { + logger.warn("Status code: 404 (Not Found)"); OllamaErrorResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaErrorResponseModel.class); responseBuffer.append(ollamaResponseModel.getError()); - } else { + } else if (statusCode == 401) { + logger.warn("Status code: 401 (Unauthorized)"); + OllamaErrorResponseModel ollamaResponseModel = + Utils.getObjectMapper().readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponseModel.class); + responseBuffer.append(ollamaResponseModel.getError()); + }else { OllamaResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaResponseModel.class); if (!ollamaResponseModel.isDone()) { @@ -457,10 +467,44 @@ public class OllamaAPI { } } if (statusCode != 200) { + logger.error("Status code " + statusCode + " instead 200"); throw new OllamaBaseException(responseBuffer.toString()); } else { long endTime = System.currentTimeMillis(); return new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode); } } + + /** + * + */ + private HttpRequest.Builder getRequestBuilderDefault(URI uri) { + HttpRequest.Builder requestBuilder = + HttpRequest.newBuilder(uri) + .header("Content-Type", "application/json") + .timeout(Duration.ofSeconds(requestTimeoutSeconds)); + if (basicAuthCredentialsSet()) { + requestBuilder.header("Authorization", getBasicAuthHeaderValue()); + } + return requestBuilder; + } + + /** + * @return basic authentication header value (encoded credentials) + */ + private String getBasicAuthHeaderValue() { + String credentialsToEncode = username + ":" + password; + return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes()); + } + + /** + * @return true when Basic Auth credentials set + */ + private boolean basicAuthCredentialsSet() { + if (username != null && password != null) { + return true; + } else { + return false; + } + } }