From 3a792090e26150496962860413d8849c9152fb76 Mon Sep 17 00:00:00 2001 From: Sven Strickroth Date: Mon, 10 Mar 2025 14:39:54 +0100 Subject: [PATCH 1/2] Support bearer token May be use as follows: ``` ollamaAPI.setBasicAuth(new BasicAuth() { @Override public String getBasicAuthHeaderValue() { return "Bearer [sometext]"; } }); ``` Signed-off-by: Sven Strickroth --- .../java/io/github/ollama4j/OllamaAPI.java | 16 +++------ .../ollama4j/models/request/BasicAuth.java | 12 +++++++ .../models/request/OllamaEndpointCaller.java | 36 ++++--------------- 3 files changed, 24 insertions(+), 40 deletions(-) diff --git a/src/main/java/io/github/ollama4j/OllamaAPI.java b/src/main/java/io/github/ollama4j/OllamaAPI.java index 76af4c8..d16bcad 100644 --- a/src/main/java/io/github/ollama4j/OllamaAPI.java +++ b/src/main/java/io/github/ollama4j/OllamaAPI.java @@ -109,6 +109,10 @@ public class OllamaAPI { this.basicAuth = new BasicAuth(username, password); } + public void setBasicAuth(BasicAuth basicAuth) { + this.basicAuth = basicAuth; + } + /** * API to check the reachability of Ollama server. * @@ -1083,21 +1087,11 @@ public class OllamaAPI { private HttpRequest.Builder getRequestBuilderDefault(URI uri) { HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(uri).header("Content-Type", "application/json").timeout(Duration.ofSeconds(requestTimeoutSeconds)); if (isBasicAuthCredentialsSet()) { - requestBuilder.header("Authorization", getBasicAuthHeaderValue()); + requestBuilder.header("Authorization", basicAuth.getBasicAuthHeaderValue()); } return requestBuilder; } - /** - * Get basic authentication header value. - * - * @return basic authentication header value (encoded credentials) - */ - private String getBasicAuthHeaderValue() { - String credentialsToEncode = basicAuth.getUsername() + ":" + basicAuth.getPassword(); - return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes()); - } - /** * Check if Basic Auth credentials set. * diff --git a/src/main/java/io/github/ollama4j/models/request/BasicAuth.java b/src/main/java/io/github/ollama4j/models/request/BasicAuth.java index f3372a9..683ed38 100644 --- a/src/main/java/io/github/ollama4j/models/request/BasicAuth.java +++ b/src/main/java/io/github/ollama4j/models/request/BasicAuth.java @@ -1,5 +1,7 @@ package io.github.ollama4j.models.request; +import java.util.Base64; + import lombok.AllArgsConstructor; import lombok.Data; import lombok.NoArgsConstructor; @@ -10,4 +12,14 @@ import lombok.NoArgsConstructor; public class BasicAuth { private String username; private String password; + + /** + * Get basic authentication header value. + * + * @return basic authentication header value (encoded credentials) + */ + public String getBasicAuthHeaderValue() { + final String credentialsToEncode = this.getUsername() + ":" + this.getPassword(); + return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes()); + } } diff --git a/src/main/java/io/github/ollama4j/models/request/OllamaEndpointCaller.java b/src/main/java/io/github/ollama4j/models/request/OllamaEndpointCaller.java index e9d0e0d..a9a9a8b 100644 --- a/src/main/java/io/github/ollama4j/models/request/OllamaEndpointCaller.java +++ b/src/main/java/io/github/ollama4j/models/request/OllamaEndpointCaller.java @@ -1,26 +1,14 @@ package io.github.ollama4j.models.request; -import io.github.ollama4j.OllamaAPI; -import io.github.ollama4j.exceptions.OllamaBaseException; -import io.github.ollama4j.models.response.OllamaErrorResponse; -import io.github.ollama4j.models.response.OllamaResult; -import io.github.ollama4j.utils.OllamaRequestBody; -import io.github.ollama4j.utils.Utils; -import lombok.Getter; +import java.net.URI; +import java.net.http.HttpRequest; +import java.time.Duration; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.BufferedReader; -import java.io.IOException; -import java.io.InputStream; -import java.io.InputStreamReader; -import java.net.URI; -import java.net.http.HttpClient; -import java.net.http.HttpRequest; -import java.net.http.HttpResponse; -import java.nio.charset.StandardCharsets; -import java.time.Duration; -import java.util.Base64; +import io.github.ollama4j.OllamaAPI; +import lombok.Getter; /** * Abstract helperclass to call the ollama api server. @@ -59,21 +47,11 @@ public abstract class OllamaEndpointCaller { .header("Content-Type", "application/json") .timeout(Duration.ofSeconds(this.requestTimeoutSeconds)); if (isBasicAuthCredentialsSet()) { - requestBuilder.header("Authorization", getBasicAuthHeaderValue()); + requestBuilder.header("Authorization", this.basicAuth.getBasicAuthHeaderValue()); } return requestBuilder; } - /** - * Get basic authentication header value. - * - * @return basic authentication header value (encoded credentials) - */ - protected String getBasicAuthHeaderValue() { - String credentialsToEncode = this.basicAuth.getUsername() + ":" + this.basicAuth.getPassword(); - return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes()); - } - /** * Check if Basic Auth credentials set. * From 138497b30fec62046fdf7e4747cad219ae012a79 Mon Sep 17 00:00:00 2001 From: Sven Strickroth Date: Mon, 10 Mar 2025 14:51:37 +0100 Subject: [PATCH 2/2] Introduce BearerAuth class Signed-off-by: Sven Strickroth --- .../java/io/github/ollama4j/OllamaAPI.java | 21 ++++++++++++------- .../github/ollama4j/models/request/Auth.java | 10 +++++++++ .../ollama4j/models/request/BasicAuth.java | 5 ++--- .../ollama4j/models/request/BearerAuth.java | 19 +++++++++++++++++ .../request/OllamaChatEndpointCaller.java | 4 ++-- .../models/request/OllamaEndpointCaller.java | 18 ++++++++-------- .../request/OllamaGenerateEndpointCaller.java | 2 +- 7 files changed, 56 insertions(+), 23 deletions(-) create mode 100644 src/main/java/io/github/ollama4j/models/request/Auth.java create mode 100644 src/main/java/io/github/ollama4j/models/request/BearerAuth.java diff --git a/src/main/java/io/github/ollama4j/OllamaAPI.java b/src/main/java/io/github/ollama4j/OllamaAPI.java index d16bcad..4a4872d 100644 --- a/src/main/java/io/github/ollama4j/OllamaAPI.java +++ b/src/main/java/io/github/ollama4j/OllamaAPI.java @@ -72,7 +72,7 @@ public class OllamaAPI { @Setter private int maxChatToolCallRetries = 3; - private BasicAuth basicAuth; + private Auth auth; private final ToolRegistry toolRegistry = new ToolRegistry(); @@ -106,11 +106,16 @@ public class OllamaAPI { * @param password the password */ public void setBasicAuth(String username, String password) { - this.basicAuth = new BasicAuth(username, password); + this.auth = new BasicAuth(username, password); } - public void setBasicAuth(BasicAuth basicAuth) { - this.basicAuth = basicAuth; + /** + * Set Bearer authentication for accessing Ollama server that's behind a reverse-proxy/gateway. + * + * @param bearerToken the Bearer authentication token to provide + */ + public void setBearerAuth(String bearerToken) { + this.auth = new BearerAuth(bearerToken); } /** @@ -860,7 +865,7 @@ public class OllamaAPI { * @throws InterruptedException if the operation is interrupted */ public OllamaChatResult chatStreaming(OllamaChatRequest request, OllamaTokenHandler tokenHandler) throws OllamaBaseException, IOException, InterruptedException { - OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose); + OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, auth, requestTimeoutSeconds, verbose); OllamaChatResult result; // add all registered tools to Request @@ -1066,7 +1071,7 @@ public class OllamaAPI { * @throws InterruptedException if the thread is interrupted during the request. */ private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequest ollamaRequestModel, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { - OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose); + OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, auth, requestTimeoutSeconds, verbose); OllamaResult result; if (streamHandler != null) { ollamaRequestModel.setStream(true); @@ -1087,7 +1092,7 @@ public class OllamaAPI { private HttpRequest.Builder getRequestBuilderDefault(URI uri) { HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(uri).header("Content-Type", "application/json").timeout(Duration.ofSeconds(requestTimeoutSeconds)); if (isBasicAuthCredentialsSet()) { - requestBuilder.header("Authorization", basicAuth.getBasicAuthHeaderValue()); + requestBuilder.header("Authorization", auth.getAuthHeaderValue()); } return requestBuilder; } @@ -1098,7 +1103,7 @@ public class OllamaAPI { * @return true when Basic Auth credentials set */ private boolean isBasicAuthCredentialsSet() { - return basicAuth != null; + return auth != null; } private Object invokeTool(ToolFunctionCallSpec toolFunctionCallSpec) throws ToolInvocationException { diff --git a/src/main/java/io/github/ollama4j/models/request/Auth.java b/src/main/java/io/github/ollama4j/models/request/Auth.java new file mode 100644 index 0000000..70c9c1b --- /dev/null +++ b/src/main/java/io/github/ollama4j/models/request/Auth.java @@ -0,0 +1,10 @@ +package io.github.ollama4j.models.request; + +public abstract class Auth { + /** + * Get authentication header value. + * + * @return authentication header value + */ + public abstract String getAuthHeaderValue(); +} diff --git a/src/main/java/io/github/ollama4j/models/request/BasicAuth.java b/src/main/java/io/github/ollama4j/models/request/BasicAuth.java index 683ed38..c58b240 100644 --- a/src/main/java/io/github/ollama4j/models/request/BasicAuth.java +++ b/src/main/java/io/github/ollama4j/models/request/BasicAuth.java @@ -7,9 +7,8 @@ import lombok.Data; import lombok.NoArgsConstructor; @Data -@NoArgsConstructor @AllArgsConstructor -public class BasicAuth { +public class BasicAuth extends Auth { private String username; private String password; @@ -18,7 +17,7 @@ public class BasicAuth { * * @return basic authentication header value (encoded credentials) */ - public String getBasicAuthHeaderValue() { + public String getAuthHeaderValue() { final String credentialsToEncode = this.getUsername() + ":" + this.getPassword(); return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes()); } diff --git a/src/main/java/io/github/ollama4j/models/request/BearerAuth.java b/src/main/java/io/github/ollama4j/models/request/BearerAuth.java new file mode 100644 index 0000000..8236042 --- /dev/null +++ b/src/main/java/io/github/ollama4j/models/request/BearerAuth.java @@ -0,0 +1,19 @@ +package io.github.ollama4j.models.request; + +import lombok.AllArgsConstructor; +import lombok.Data; + +@Data +@AllArgsConstructor +public class BearerAuth extends Auth { + private String bearerToken; + + /** + * Get authentication header value. + * + * @return authentication header value with bearer token + */ + public String getAuthHeaderValue() { + return "Bearer "+ bearerToken; + } +} diff --git a/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java b/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java index a1a6216..09a3870 100644 --- a/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java +++ b/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java @@ -30,8 +30,8 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller { private OllamaTokenHandler tokenHandler; - public OllamaChatEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) { - super(host, basicAuth, requestTimeoutSeconds, verbose); + public OllamaChatEndpointCaller(String host, Auth auth, long requestTimeoutSeconds, boolean verbose) { + super(host, auth, requestTimeoutSeconds, verbose); } @Override diff --git a/src/main/java/io/github/ollama4j/models/request/OllamaEndpointCaller.java b/src/main/java/io/github/ollama4j/models/request/OllamaEndpointCaller.java index a9a9a8b..1f42ef8 100644 --- a/src/main/java/io/github/ollama4j/models/request/OllamaEndpointCaller.java +++ b/src/main/java/io/github/ollama4j/models/request/OllamaEndpointCaller.java @@ -19,13 +19,13 @@ public abstract class OllamaEndpointCaller { private static final Logger LOG = LoggerFactory.getLogger(OllamaAPI.class); private final String host; - private final BasicAuth basicAuth; + private final Auth auth; private final long requestTimeoutSeconds; private final boolean verbose; - public OllamaEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) { + public OllamaEndpointCaller(String host, Auth auth, long requestTimeoutSeconds, boolean verbose) { this.host = host; - this.basicAuth = basicAuth; + this.auth = auth; this.requestTimeoutSeconds = requestTimeoutSeconds; this.verbose = verbose; } @@ -46,19 +46,19 @@ public abstract class OllamaEndpointCaller { HttpRequest.newBuilder(uri) .header("Content-Type", "application/json") .timeout(Duration.ofSeconds(this.requestTimeoutSeconds)); - if (isBasicAuthCredentialsSet()) { - requestBuilder.header("Authorization", this.basicAuth.getBasicAuthHeaderValue()); + if (isAuthCredentialsSet()) { + requestBuilder.header("Authorization", this.auth.getAuthHeaderValue()); } return requestBuilder; } /** - * Check if Basic Auth credentials set. + * Check if Auth credentials set. * - * @return true when Basic Auth credentials set + * @return true when Auth credentials set */ - protected boolean isBasicAuthCredentialsSet() { - return this.basicAuth != null; + protected boolean isAuthCredentialsSet() { + return this.auth != null; } } diff --git a/src/main/java/io/github/ollama4j/models/request/OllamaGenerateEndpointCaller.java b/src/main/java/io/github/ollama4j/models/request/OllamaGenerateEndpointCaller.java index 00b2b12..461ec75 100644 --- a/src/main/java/io/github/ollama4j/models/request/OllamaGenerateEndpointCaller.java +++ b/src/main/java/io/github/ollama4j/models/request/OllamaGenerateEndpointCaller.java @@ -28,7 +28,7 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller { private OllamaGenerateStreamObserver streamObserver; - public OllamaGenerateEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) { + public OllamaGenerateEndpointCaller(String host, Auth basicAuth, long requestTimeoutSeconds, boolean verbose) { super(host, basicAuth, requestTimeoutSeconds, verbose); }