From 657593be096e92f3945eccfb376ac1aafb70f32d Mon Sep 17 00:00:00 2001 From: Amith Koujalgi Date: Fri, 29 Dec 2023 10:25:18 +0530 Subject: [PATCH] Updated all APIs to use `getRequestBuilderDefault()` method --- docs/docs/apis-extras/basic-auth.md | 24 +++++ .../ollama4j/core/OllamaAPI.java | 102 +++++++++--------- .../ollama4j/core/models/BasicAuth.java | 13 +++ .../models/OllamaAsyncResultCallback.java | 66 +++++------- .../ollama4j/core/utils/Utils.java | 6 +- .../ollama4j/unittests/TestMockedAPIs.java | 41 ++++++- 6 files changed, 152 insertions(+), 100 deletions(-) create mode 100644 docs/docs/apis-extras/basic-auth.md create mode 100644 src/main/java/io/github/amithkoujalgi/ollama4j/core/models/BasicAuth.java diff --git a/docs/docs/apis-extras/basic-auth.md b/docs/docs/apis-extras/basic-auth.md new file mode 100644 index 0000000..226a18f --- /dev/null +++ b/docs/docs/apis-extras/basic-auth.md @@ -0,0 +1,24 @@ +--- +sidebar_position: 2 +--- + +# Set Basic Authentication + +This API lets you set the basic authentication for the Ollama client. This would help in scenarios where +Ollama server would be setup behind a gateway/reverse proxy with basic auth. + +After configuring basic authentication, all subsequent requests will include the Basic Auth header. + +```java +public class Main { + + public static void main(String[] args) { + + String host = "http://localhost:11434/"; + + OllamaAPI ollamaAPI = new OllamaAPI(host); + + ollamaAPI.setBasicAuth("username", "password"); + } +} +``` \ No newline at end of file diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java index 1bf4b9c..a801793 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java @@ -37,8 +37,7 @@ public class OllamaAPI { private final String host; private long requestTimeoutSeconds = 3; private boolean verbose = true; - private String username; - private String password; + private BasicAuth basicAuth; /** * Instantiates the Ollama API. @@ -53,6 +52,11 @@ public class OllamaAPI { } } + /** + * Set request timeout in seconds. Default is 3 seconds. + * + * @param requestTimeoutSeconds the request timeout in seconds + */ public void setRequestTimeoutSeconds(long requestTimeoutSeconds) { this.requestTimeoutSeconds = requestTimeoutSeconds; } @@ -67,11 +71,13 @@ public class OllamaAPI { } /** + * Set basic authentication for accessing Ollama server that's behind a reverse-proxy/gateway. * + * @param username the username + * @param password the password */ public void setBasicAuth(String username, String password) { - this.username = username; - this.password = password; + this.basicAuth = new BasicAuth(username, password); } /** @@ -85,11 +91,9 @@ public class OllamaAPI { HttpRequest httpRequest = null; try { httpRequest = - HttpRequest.newBuilder() - .uri(new URI(url)) + getRequestBuilderDefault(new URI(url)) .header("Accept", "application/json") .header("Content-type", "application/json") - .timeout(Duration.ofSeconds(requestTimeoutSeconds)) .GET() .build(); } catch (URISyntaxException e) { @@ -117,11 +121,9 @@ public class OllamaAPI { String url = this.host + "/api/tags"; HttpClient httpClient = HttpClient.newHttpClient(); HttpRequest httpRequest = - HttpRequest.newBuilder() - .uri(new URI(url)) + getRequestBuilderDefault(new URI(url)) .header("Accept", "application/json") .header("Content-type", "application/json") - .timeout(Duration.ofSeconds(requestTimeoutSeconds)) .GET() .build(); HttpResponse response = @@ -148,12 +150,10 @@ public class OllamaAPI { String url = this.host + "/api/pull"; String jsonData = new ModelRequest(modelName).toString(); HttpRequest request = - HttpRequest.newBuilder() - .uri(new URI(url)) + getRequestBuilderDefault(new URI(url)) .POST(HttpRequest.BodyPublishers.ofString(jsonData)) .header("Accept", "application/json") .header("Content-type", "application/json") - .timeout(Duration.ofSeconds(requestTimeoutSeconds)) .build(); HttpClient client = HttpClient.newHttpClient(); HttpResponse response = @@ -184,15 +184,13 @@ public class OllamaAPI { * @return the model details */ public ModelDetail getModelDetails(String modelName) - throws IOException, OllamaBaseException, InterruptedException { + throws IOException, OllamaBaseException, InterruptedException, URISyntaxException { String url = this.host + "/api/show"; String jsonData = new ModelRequest(modelName).toString(); HttpRequest request = - HttpRequest.newBuilder() - .uri(URI.create(url)) + getRequestBuilderDefault(new URI(url)) .header("Accept", "application/json") .header("Content-type", "application/json") - .timeout(Duration.ofSeconds(requestTimeoutSeconds)) .POST(HttpRequest.BodyPublishers.ofString(jsonData)) .build(); HttpClient client = HttpClient.newHttpClient(); @@ -214,15 +212,13 @@ public class OllamaAPI { * @param modelFilePath the path to model file that exists on the Ollama server. */ public void createModelWithFilePath(String modelName, String modelFilePath) - throws IOException, InterruptedException, OllamaBaseException { + throws IOException, InterruptedException, OllamaBaseException, URISyntaxException { String url = this.host + "/api/create"; String jsonData = new CustomModelFilePathRequest(modelName, modelFilePath).toString(); HttpRequest request = - HttpRequest.newBuilder() - .uri(URI.create(url)) + getRequestBuilderDefault(new URI(url)) .header("Accept", "application/json") .header("Content-Type", "application/json") - .timeout(Duration.ofSeconds(requestTimeoutSeconds)) .POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)) .build(); HttpClient client = HttpClient.newHttpClient(); @@ -250,15 +246,13 @@ public class OllamaAPI { * @param modelFileContents the path to model file that exists on the Ollama server. */ public void createModelWithModelFileContents(String modelName, String modelFileContents) - throws IOException, InterruptedException, OllamaBaseException { + throws IOException, InterruptedException, OllamaBaseException, URISyntaxException { String url = this.host + "/api/create"; String jsonData = new CustomModelFileContentsRequest(modelName, modelFileContents).toString(); HttpRequest request = - HttpRequest.newBuilder() - .uri(URI.create(url)) + getRequestBuilderDefault(new URI(url)) .header("Accept", "application/json") .header("Content-Type", "application/json") - .timeout(Duration.ofSeconds(requestTimeoutSeconds)) .POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)) .build(); HttpClient client = HttpClient.newHttpClient(); @@ -280,20 +274,17 @@ public class OllamaAPI { * Delete a model from Ollama server. * * @param modelName the name of the model to be deleted. - * @param ignoreIfNotPresent - ignore errors if the specified model is not present on Ollama - * server. + * @param ignoreIfNotPresent ignore errors if the specified model is not present on Ollama server. */ public void deleteModel(String modelName, boolean ignoreIfNotPresent) - throws IOException, InterruptedException, OllamaBaseException { + throws IOException, InterruptedException, OllamaBaseException, URISyntaxException { String url = this.host + "/api/delete"; String jsonData = new ModelRequest(modelName).toString(); HttpRequest request = - HttpRequest.newBuilder() - .uri(URI.create(url)) + getRequestBuilderDefault(new URI(url)) .method("DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)) .header("Accept", "application/json") .header("Content-type", "application/json") - .timeout(Duration.ofSeconds(requestTimeoutSeconds)) .build(); HttpClient client = HttpClient.newHttpClient(); HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); @@ -319,7 +310,8 @@ public class OllamaAPI { URI uri = URI.create(this.host + "/api/embeddings"); String jsonData = new ModelEmbeddingsRequest(model, prompt).toString(); HttpClient httpClient = HttpClient.newHttpClient(); - HttpRequest.Builder requestBuilder = getRequestBuilderDefault(uri) + HttpRequest.Builder requestBuilder = + getRequestBuilderDefault(uri) .header("Accept", "application/json") .POST(HttpRequest.BodyPublishers.ofString(jsonData)); HttpRequest request = requestBuilder.build(); @@ -340,7 +332,7 @@ public class OllamaAPI { * * @param model the ollama model to ask the question to * @param promptText the prompt/question text - * @return OllamaResult - that includes response text and time taken for response + * @return OllamaResult that includes response text and time taken for response */ public OllamaResult ask(String model, String promptText) throws OllamaBaseException, IOException, InterruptedException { @@ -359,10 +351,11 @@ public class OllamaAPI { */ public OllamaAsyncResultCallback askAsync(String model, String promptText) { OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, promptText); - HttpClient httpClient = HttpClient.newHttpClient(); + URI uri = URI.create(this.host + "/api/generate"); OllamaAsyncResultCallback ollamaAsyncResultCallback = - new OllamaAsyncResultCallback(httpClient, uri, ollamaRequestModel, requestTimeoutSeconds); + new OllamaAsyncResultCallback( + getRequestBuilderDefault(uri), ollamaRequestModel, requestTimeoutSeconds); ollamaAsyncResultCallback.start(); return ollamaAsyncResultCallback; } @@ -374,7 +367,7 @@ public class OllamaAPI { * @param model the ollama model to ask the question to * @param promptText the prompt/question text * @param imageFiles the list of image files to use for the question - * @return OllamaResult - that includes response text and time taken for response + * @return OllamaResult that includes response text and time taken for response */ public OllamaResult askWithImageFiles(String model, String promptText, List imageFiles) throws OllamaBaseException, IOException, InterruptedException { @@ -393,7 +386,7 @@ public class OllamaAPI { * @param model the ollama model to ask the question to * @param promptText the prompt/question text * @param imageURLs the list of image URLs to use for the question - * @return OllamaResult - that includes response text and time taken for response + * @return OllamaResult that includes response text and time taken for response */ public OllamaResult askWithImageURLs(String model, String promptText, List imageURLs) throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { @@ -432,7 +425,8 @@ public class OllamaAPI { long startTime = System.currentTimeMillis(); HttpClient httpClient = HttpClient.newHttpClient(); URI uri = URI.create(this.host + "/api/generate"); - HttpRequest.Builder requestBuilder = getRequestBuilderDefault(uri) + HttpRequest.Builder requestBuilder = + getRequestBuilderDefault(uri) .POST( HttpRequest.BodyPublishers.ofString( Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))); @@ -455,9 +449,10 @@ public class OllamaAPI { } else if (statusCode == 401) { logger.warn("Status code: 401 (Unauthorized)"); OllamaErrorResponseModel ollamaResponseModel = - Utils.getObjectMapper().readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponseModel.class); + Utils.getObjectMapper() + .readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponseModel.class); responseBuffer.append(ollamaResponseModel.getError()); - }else { + } else { OllamaResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaResponseModel.class); if (!ollamaResponseModel.isDone()) { @@ -467,7 +462,7 @@ public class OllamaAPI { } } if (statusCode != 200) { - logger.error("Status code " + statusCode + " instead 200"); + logger.error("Status code " + statusCode); throw new OllamaBaseException(responseBuffer.toString()); } else { long endTime = System.currentTimeMillis(); @@ -476,35 +471,38 @@ public class OllamaAPI { } /** + * Get default request builder. * + * @param uri URI to get a HttpRequest.Builder + * @return HttpRequest.Builder */ private HttpRequest.Builder getRequestBuilderDefault(URI uri) { HttpRequest.Builder requestBuilder = - HttpRequest.newBuilder(uri) - .header("Content-Type", "application/json") - .timeout(Duration.ofSeconds(requestTimeoutSeconds)); - if (basicAuthCredentialsSet()) { + HttpRequest.newBuilder(uri) + .header("Content-Type", "application/json") + .timeout(Duration.ofSeconds(requestTimeoutSeconds)); + if (isBasicAuthCredentialsSet()) { requestBuilder.header("Authorization", getBasicAuthHeaderValue()); } return requestBuilder; } /** + * Get basic authentication header value. + * * @return basic authentication header value (encoded credentials) */ private String getBasicAuthHeaderValue() { - String credentialsToEncode = username + ":" + password; + String credentialsToEncode = basicAuth.getUsername() + ":" + basicAuth.getPassword(); return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes()); } /** + * Check if Basic Auth credentials set. + * * @return true when Basic Auth credentials set */ - private boolean basicAuthCredentialsSet() { - if (username != null && password != null) { - return true; - } else { - return false; - } + private boolean isBasicAuthCredentialsSet() { + return basicAuth != null; } } diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/BasicAuth.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/BasicAuth.java new file mode 100644 index 0000000..dbcf8a7 --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/BasicAuth.java @@ -0,0 +1,13 @@ +package io.github.amithkoujalgi.ollama4j.core.models; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +@Data +@NoArgsConstructor +@AllArgsConstructor +public class BasicAuth { + private String username; + private String password; +} diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java index ebf5323..b412f31 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java @@ -6,7 +6,6 @@ import java.io.BufferedReader; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; -import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpRequest; import java.net.http.HttpResponse; @@ -14,30 +13,44 @@ import java.nio.charset.StandardCharsets; import java.time.Duration; import java.util.LinkedList; import java.util.Queue; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.Getter; +@Data +@EqualsAndHashCode(callSuper = true) @SuppressWarnings("unused") public class OllamaAsyncResultCallback extends Thread { - private final HttpClient client; - private final URI uri; + private final HttpRequest.Builder requestBuilder; private final OllamaRequestModel ollamaRequestModel; private final Queue queue = new LinkedList<>(); private String result; private boolean isDone; - private boolean succeeded; + + /** + * -- GETTER -- Returns the status of the request. Indicates if the request was successful or a + * failure. If the request was a failure, the `getResponse()` method will return the error + * message. + */ + @Getter private boolean succeeded; private long requestTimeoutSeconds; - private int httpStatusCode; - private long responseTime = 0; + /** + * -- GETTER -- Returns the HTTP response status code for the request that was made to Ollama + * server. + */ + @Getter private int httpStatusCode; + + /** -- GETTER -- Returns the response time in milliseconds. */ + @Getter private long responseTime = 0; public OllamaAsyncResultCallback( - HttpClient client, - URI uri, + HttpRequest.Builder requestBuilder, OllamaRequestModel ollamaRequestModel, long requestTimeoutSeconds) { - this.client = client; + this.requestBuilder = requestBuilder; this.ollamaRequestModel = ollamaRequestModel; - this.uri = uri; this.isDone = false; this.result = ""; this.queue.add(""); @@ -46,10 +59,11 @@ public class OllamaAsyncResultCallback extends Thread { @Override public void run() { + HttpClient httpClient = HttpClient.newHttpClient(); try { long startTime = System.currentTimeMillis(); HttpRequest request = - HttpRequest.newBuilder(uri) + requestBuilder .POST( HttpRequest.BodyPublishers.ofString( Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))) @@ -57,7 +71,7 @@ public class OllamaAsyncResultCallback extends Thread { .timeout(Duration.ofSeconds(requestTimeoutSeconds)) .build(); HttpResponse response = - client.send(request, HttpResponse.BodyHandlers.ofInputStream()); + httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream()); int statusCode = response.statusCode(); this.httpStatusCode = statusCode; @@ -108,25 +122,6 @@ public class OllamaAsyncResultCallback extends Thread { return isDone; } - /** - * Returns the HTTP response status code for the request that was made to Ollama server. - * - * @return int - the status code for the request - */ - public int getHttpStatusCode() { - return httpStatusCode; - } - - /** - * Returns the status of the request. Indicates if the request was successful or a failure. If the - * request was a failure, the `getResponse()` method will return the error message. - * - * @return boolean - status - */ - public boolean isSucceeded() { - return succeeded; - } - /** * Returns the final response when the execution completes. Does not return intermediate results. * @@ -140,15 +135,6 @@ public class OllamaAsyncResultCallback extends Thread { return queue; } - /** - * Returns the response time in milliseconds. - * - * @return long - response time in milliseconds. - */ - public long getResponseTime() { - return responseTime; - } - public void setRequestTimeoutSeconds(long requestTimeoutSeconds) { this.requestTimeoutSeconds = requestTimeoutSeconds; } diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/Utils.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/Utils.java index 190683c..9be49e1 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/Utils.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/Utils.java @@ -3,7 +3,7 @@ package io.github.amithkoujalgi.ollama4j.core.utils; import com.fasterxml.jackson.databind.ObjectMapper; public class Utils { - public static ObjectMapper getObjectMapper() { - return new ObjectMapper(); - } + public static ObjectMapper getObjectMapper() { + return new ObjectMapper(); + } } diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/TestMockedAPIs.java b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/TestMockedAPIs.java index 7c46e37..b7c9977 100644 --- a/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/TestMockedAPIs.java +++ b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/TestMockedAPIs.java @@ -11,12 +11,13 @@ import io.github.amithkoujalgi.ollama4j.core.types.OllamaModelType; import java.io.IOException; import java.net.URISyntaxException; import java.util.ArrayList; +import java.util.Collections; import org.junit.jupiter.api.Test; import org.mockito.Mockito; class TestMockedAPIs { @Test - void testMockPullModel() { + void testPullModel() { OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); String model = OllamaModelType.LLAMA2; try { @@ -49,7 +50,7 @@ class TestMockedAPIs { doNothing().when(ollamaAPI).createModelWithModelFileContents(model, modelFilePath); ollamaAPI.createModelWithModelFileContents(model, modelFilePath); verify(ollamaAPI, times(1)).createModelWithModelFileContents(model, modelFilePath); - } catch (IOException | OllamaBaseException | InterruptedException e) { + } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { throw new RuntimeException(e); } } @@ -62,7 +63,7 @@ class TestMockedAPIs { doNothing().when(ollamaAPI).deleteModel(model, true); ollamaAPI.deleteModel(model, true); verify(ollamaAPI, times(1)).deleteModel(model, true); - } catch (IOException | OllamaBaseException | InterruptedException e) { + } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { throw new RuntimeException(e); } } @@ -75,7 +76,7 @@ class TestMockedAPIs { when(ollamaAPI.getModelDetails(model)).thenReturn(new ModelDetail()); ollamaAPI.getModelDetails(model); verify(ollamaAPI, times(1)).getModelDetails(model); - } catch (IOException | OllamaBaseException | InterruptedException e) { + } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { throw new RuntimeException(e); } } @@ -108,13 +109,43 @@ class TestMockedAPIs { } } + @Test + void testAskWithImageFiles() { + OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); + String model = OllamaModelType.LLAMA2; + String prompt = "some prompt text"; + try { + when(ollamaAPI.askWithImageFiles(model, prompt, Collections.emptyList())) + .thenReturn(new OllamaResult("", 0, 200)); + ollamaAPI.askWithImageFiles(model, prompt, Collections.emptyList()); + verify(ollamaAPI, times(1)).askWithImageFiles(model, prompt, Collections.emptyList()); + } catch (IOException | OllamaBaseException | InterruptedException e) { + throw new RuntimeException(e); + } + } + + @Test + void testAskWithImageURLs() { + OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); + String model = OllamaModelType.LLAMA2; + String prompt = "some prompt text"; + try { + when(ollamaAPI.askWithImageURLs(model, prompt, Collections.emptyList())) + .thenReturn(new OllamaResult("", 0, 200)); + ollamaAPI.askWithImageURLs(model, prompt, Collections.emptyList()); + verify(ollamaAPI, times(1)).askWithImageURLs(model, prompt, Collections.emptyList()); + } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { + throw new RuntimeException(e); + } + } + @Test void testAskAsync() { OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); String model = OllamaModelType.LLAMA2; String prompt = "some prompt text"; when(ollamaAPI.askAsync(model, prompt)) - .thenReturn(new OllamaAsyncResultCallback(null, null, null, 3)); + .thenReturn(new OllamaAsyncResultCallback(null, null, 3)); ollamaAPI.askAsync(model, prompt); verify(ollamaAPI, times(1)).askAsync(model, prompt); }