From 138497b30fec62046fdf7e4747cad219ae012a79 Mon Sep 17 00:00:00 2001 From: Sven Strickroth Date: Mon, 10 Mar 2025 14:51:37 +0100 Subject: [PATCH] Introduce BearerAuth class Signed-off-by: Sven Strickroth --- .../java/io/github/ollama4j/OllamaAPI.java | 21 ++++++++++++------- .../github/ollama4j/models/request/Auth.java | 10 +++++++++ .../ollama4j/models/request/BasicAuth.java | 5 ++--- .../ollama4j/models/request/BearerAuth.java | 19 +++++++++++++++++ .../request/OllamaChatEndpointCaller.java | 4 ++-- .../models/request/OllamaEndpointCaller.java | 18 ++++++++-------- .../request/OllamaGenerateEndpointCaller.java | 2 +- 7 files changed, 56 insertions(+), 23 deletions(-) create mode 100644 src/main/java/io/github/ollama4j/models/request/Auth.java create mode 100644 src/main/java/io/github/ollama4j/models/request/BearerAuth.java diff --git a/src/main/java/io/github/ollama4j/OllamaAPI.java b/src/main/java/io/github/ollama4j/OllamaAPI.java index d16bcad..4a4872d 100644 --- a/src/main/java/io/github/ollama4j/OllamaAPI.java +++ b/src/main/java/io/github/ollama4j/OllamaAPI.java @@ -72,7 +72,7 @@ public class OllamaAPI { @Setter private int maxChatToolCallRetries = 3; - private BasicAuth basicAuth; + private Auth auth; private final ToolRegistry toolRegistry = new ToolRegistry(); @@ -106,11 +106,16 @@ public class OllamaAPI { * @param password the password */ public void setBasicAuth(String username, String password) { - this.basicAuth = new BasicAuth(username, password); + this.auth = new BasicAuth(username, password); } - public void setBasicAuth(BasicAuth basicAuth) { - this.basicAuth = basicAuth; + /** + * Set Bearer authentication for accessing Ollama server that's behind a reverse-proxy/gateway. + * + * @param bearerToken the Bearer authentication token to provide + */ + public void setBearerAuth(String bearerToken) { + this.auth = new BearerAuth(bearerToken); } /** @@ -860,7 +865,7 @@ public class OllamaAPI { * @throws InterruptedException if the operation is interrupted */ public OllamaChatResult chatStreaming(OllamaChatRequest request, OllamaTokenHandler tokenHandler) throws OllamaBaseException, IOException, InterruptedException { - OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose); + OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, auth, requestTimeoutSeconds, verbose); OllamaChatResult result; // add all registered tools to Request @@ -1066,7 +1071,7 @@ public class OllamaAPI { * @throws InterruptedException if the thread is interrupted during the request. */ private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequest ollamaRequestModel, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { - OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose); + OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, auth, requestTimeoutSeconds, verbose); OllamaResult result; if (streamHandler != null) { ollamaRequestModel.setStream(true); @@ -1087,7 +1092,7 @@ public class OllamaAPI { private HttpRequest.Builder getRequestBuilderDefault(URI uri) { HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(uri).header("Content-Type", "application/json").timeout(Duration.ofSeconds(requestTimeoutSeconds)); if (isBasicAuthCredentialsSet()) { - requestBuilder.header("Authorization", basicAuth.getBasicAuthHeaderValue()); + requestBuilder.header("Authorization", auth.getAuthHeaderValue()); } return requestBuilder; } @@ -1098,7 +1103,7 @@ public class OllamaAPI { * @return true when Basic Auth credentials set */ private boolean isBasicAuthCredentialsSet() { - return basicAuth != null; + return auth != null; } private Object invokeTool(ToolFunctionCallSpec toolFunctionCallSpec) throws ToolInvocationException { diff --git a/src/main/java/io/github/ollama4j/models/request/Auth.java b/src/main/java/io/github/ollama4j/models/request/Auth.java new file mode 100644 index 0000000..70c9c1b --- /dev/null +++ b/src/main/java/io/github/ollama4j/models/request/Auth.java @@ -0,0 +1,10 @@ +package io.github.ollama4j.models.request; + +public abstract class Auth { + /** + * Get authentication header value. + * + * @return authentication header value + */ + public abstract String getAuthHeaderValue(); +} diff --git a/src/main/java/io/github/ollama4j/models/request/BasicAuth.java b/src/main/java/io/github/ollama4j/models/request/BasicAuth.java index 683ed38..c58b240 100644 --- a/src/main/java/io/github/ollama4j/models/request/BasicAuth.java +++ b/src/main/java/io/github/ollama4j/models/request/BasicAuth.java @@ -7,9 +7,8 @@ import lombok.Data; import lombok.NoArgsConstructor; @Data -@NoArgsConstructor @AllArgsConstructor -public class BasicAuth { +public class BasicAuth extends Auth { private String username; private String password; @@ -18,7 +17,7 @@ public class BasicAuth { * * @return basic authentication header value (encoded credentials) */ - public String getBasicAuthHeaderValue() { + public String getAuthHeaderValue() { final String credentialsToEncode = this.getUsername() + ":" + this.getPassword(); return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes()); } diff --git a/src/main/java/io/github/ollama4j/models/request/BearerAuth.java b/src/main/java/io/github/ollama4j/models/request/BearerAuth.java new file mode 100644 index 0000000..8236042 --- /dev/null +++ b/src/main/java/io/github/ollama4j/models/request/BearerAuth.java @@ -0,0 +1,19 @@ +package io.github.ollama4j.models.request; + +import lombok.AllArgsConstructor; +import lombok.Data; + +@Data +@AllArgsConstructor +public class BearerAuth extends Auth { + private String bearerToken; + + /** + * Get authentication header value. + * + * @return authentication header value with bearer token + */ + public String getAuthHeaderValue() { + return "Bearer "+ bearerToken; + } +} diff --git a/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java b/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java index a1a6216..09a3870 100644 --- a/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java +++ b/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java @@ -30,8 +30,8 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller { private OllamaTokenHandler tokenHandler; - public OllamaChatEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) { - super(host, basicAuth, requestTimeoutSeconds, verbose); + public OllamaChatEndpointCaller(String host, Auth auth, long requestTimeoutSeconds, boolean verbose) { + super(host, auth, requestTimeoutSeconds, verbose); } @Override diff --git a/src/main/java/io/github/ollama4j/models/request/OllamaEndpointCaller.java b/src/main/java/io/github/ollama4j/models/request/OllamaEndpointCaller.java index a9a9a8b..1f42ef8 100644 --- a/src/main/java/io/github/ollama4j/models/request/OllamaEndpointCaller.java +++ b/src/main/java/io/github/ollama4j/models/request/OllamaEndpointCaller.java @@ -19,13 +19,13 @@ public abstract class OllamaEndpointCaller { private static final Logger LOG = LoggerFactory.getLogger(OllamaAPI.class); private final String host; - private final BasicAuth basicAuth; + private final Auth auth; private final long requestTimeoutSeconds; private final boolean verbose; - public OllamaEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) { + public OllamaEndpointCaller(String host, Auth auth, long requestTimeoutSeconds, boolean verbose) { this.host = host; - this.basicAuth = basicAuth; + this.auth = auth; this.requestTimeoutSeconds = requestTimeoutSeconds; this.verbose = verbose; } @@ -46,19 +46,19 @@ public abstract class OllamaEndpointCaller { HttpRequest.newBuilder(uri) .header("Content-Type", "application/json") .timeout(Duration.ofSeconds(this.requestTimeoutSeconds)); - if (isBasicAuthCredentialsSet()) { - requestBuilder.header("Authorization", this.basicAuth.getBasicAuthHeaderValue()); + if (isAuthCredentialsSet()) { + requestBuilder.header("Authorization", this.auth.getAuthHeaderValue()); } return requestBuilder; } /** - * Check if Basic Auth credentials set. + * Check if Auth credentials set. * - * @return true when Basic Auth credentials set + * @return true when Auth credentials set */ - protected boolean isBasicAuthCredentialsSet() { - return this.basicAuth != null; + protected boolean isAuthCredentialsSet() { + return this.auth != null; } } diff --git a/src/main/java/io/github/ollama4j/models/request/OllamaGenerateEndpointCaller.java b/src/main/java/io/github/ollama4j/models/request/OllamaGenerateEndpointCaller.java index 00b2b12..461ec75 100644 --- a/src/main/java/io/github/ollama4j/models/request/OllamaGenerateEndpointCaller.java +++ b/src/main/java/io/github/ollama4j/models/request/OllamaGenerateEndpointCaller.java @@ -28,7 +28,7 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller { private OllamaGenerateStreamObserver streamObserver; - public OllamaGenerateEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) { + public OllamaGenerateEndpointCaller(String host, Auth basicAuth, long requestTimeoutSeconds, boolean verbose) { super(host, basicAuth, requestTimeoutSeconds, verbose); }