diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java index 11906b3..ff01aa8 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java @@ -37,6 +37,8 @@ public class OllamaAPI { private final String host; private long requestTimeoutSeconds = 3; private boolean verbose = true; + private String username; + private String password; /** * Instantiates the Ollama API. @@ -64,6 +66,14 @@ public class OllamaAPI { this.verbose = verbose; } + /** + * + */ + public void setBasicAuth(String username, String password) { + this.username = username; + this.password = password; + } + /** * API to check the reachability of Ollama server. * @@ -426,14 +436,18 @@ public class OllamaAPI { long startTime = System.currentTimeMillis(); HttpClient httpClient = HttpClient.newHttpClient(); URI uri = URI.create(this.host + "/api/generate"); - HttpRequest request = + HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(uri) .POST( HttpRequest.BodyPublishers.ofString( Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))) .header("Content-Type", "application/json") - .timeout(Duration.ofSeconds(requestTimeoutSeconds)) - .build(); + .timeout(Duration.ofSeconds(requestTimeoutSeconds)); + if (basicAuthCredentialsSet()) { + requestBuilder.header("Authorization", getBasicAuthHeaderValue()); + } + HttpRequest request = requestBuilder.build(); + logger.debug("Ask model '" + ollamaRequestModel + "' ..."); HttpResponse response = httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream()); int statusCode = response.statusCode(); @@ -444,10 +458,16 @@ public class OllamaAPI { String line; while ((line = reader.readLine()) != null) { if (statusCode == 404) { + logger.warn("Status code: 404 (Not Found)"); OllamaErrorResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaErrorResponseModel.class); responseBuffer.append(ollamaResponseModel.getError()); - } else { + } else if (statusCode == 401) { + logger.warn("Status code: 401 (Unauthorized)"); + OllamaErrorResponseModel ollamaResponseModel = + Utils.getObjectMapper().readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponseModel.class); + responseBuffer.append(ollamaResponseModel.getError()); + }else { OllamaResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaResponseModel.class); if (!ollamaResponseModel.isDone()) { @@ -457,10 +477,30 @@ public class OllamaAPI { } } if (statusCode != 200) { + logger.error("Status code " + statusCode + " instead 200"); throw new OllamaBaseException(responseBuffer.toString()); } else { long endTime = System.currentTimeMillis(); return new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode); } } + + /** + * @return basic authentication header value (encoded credentials) + */ + private String getBasicAuthHeaderValue() { + String credentialsToEncode = username + ":" + password; + return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes()); + } + + /** + * @return true when Basic Auth credentials set + */ + private boolean basicAuthCredentialsSet() { + if (username != null && password != null) { + return true; + } else { + return false; + } + } }