diff --git a/src/main/java/io/github/ollama4j/OllamaAPI.java b/src/main/java/io/github/ollama4j/OllamaAPI.java index 5d64044..b0d21c4 100644 --- a/src/main/java/io/github/ollama4j/OllamaAPI.java +++ b/src/main/java/io/github/ollama4j/OllamaAPI.java @@ -72,7 +72,7 @@ public class OllamaAPI { @Setter private int maxChatToolCallRetries = 3; - private BasicAuth basicAuth; + private Auth auth; private final ToolRegistry toolRegistry = new ToolRegistry(); @@ -106,7 +106,16 @@ public class OllamaAPI { * @param password the password */ public void setBasicAuth(String username, String password) { - this.basicAuth = new BasicAuth(username, password); + this.auth = new BasicAuth(username, password); + } + + /** + * Set Bearer authentication for accessing Ollama server that's behind a reverse-proxy/gateway. + * + * @param bearerToken the Bearer authentication token to provide + */ + public void setBearerAuth(String bearerToken) { + this.auth = new BearerAuth(bearerToken); } /** @@ -888,7 +897,7 @@ public class OllamaAPI { * @throws InterruptedException if the operation is interrupted */ public OllamaChatResult chatStreaming(OllamaChatRequest request, OllamaTokenHandler tokenHandler) throws OllamaBaseException, IOException, InterruptedException { - OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose); + OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, auth, requestTimeoutSeconds, verbose); OllamaChatResult result; // add all registered tools to Request @@ -1094,7 +1103,7 @@ public class OllamaAPI { * @throws InterruptedException if the thread is interrupted during the request. */ private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequest ollamaRequestModel, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { - OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose); + OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, auth, requestTimeoutSeconds, verbose); OllamaResult result; if (streamHandler != null) { ollamaRequestModel.setStream(true); @@ -1115,28 +1124,18 @@ public class OllamaAPI { private HttpRequest.Builder getRequestBuilderDefault(URI uri) { HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(uri).header("Content-Type", "application/json").timeout(Duration.ofSeconds(requestTimeoutSeconds)); if (isBasicAuthCredentialsSet()) { - requestBuilder.header("Authorization", getBasicAuthHeaderValue()); + requestBuilder.header("Authorization", auth.getAuthHeaderValue()); } return requestBuilder; } - /** - * Get basic authentication header value. - * - * @return basic authentication header value (encoded credentials) - */ - private String getBasicAuthHeaderValue() { - String credentialsToEncode = basicAuth.getUsername() + ":" + basicAuth.getPassword(); - return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes()); - } - /** * Check if Basic Auth credentials set. * * @return true when Basic Auth credentials set */ private boolean isBasicAuthCredentialsSet() { - return basicAuth != null; + return auth != null; } private Object invokeTool(ToolFunctionCallSpec toolFunctionCallSpec) throws ToolInvocationException { diff --git a/src/main/java/io/github/ollama4j/models/request/Auth.java b/src/main/java/io/github/ollama4j/models/request/Auth.java new file mode 100644 index 0000000..70c9c1b --- /dev/null +++ b/src/main/java/io/github/ollama4j/models/request/Auth.java @@ -0,0 +1,10 @@ +package io.github.ollama4j.models.request; + +public abstract class Auth { + /** + * Get authentication header value. + * + * @return authentication header value + */ + public abstract String getAuthHeaderValue(); +} diff --git a/src/main/java/io/github/ollama4j/models/request/BasicAuth.java b/src/main/java/io/github/ollama4j/models/request/BasicAuth.java index f3372a9..c58b240 100644 --- a/src/main/java/io/github/ollama4j/models/request/BasicAuth.java +++ b/src/main/java/io/github/ollama4j/models/request/BasicAuth.java @@ -1,13 +1,24 @@ package io.github.ollama4j.models.request; +import java.util.Base64; + import lombok.AllArgsConstructor; import lombok.Data; import lombok.NoArgsConstructor; @Data -@NoArgsConstructor @AllArgsConstructor -public class BasicAuth { +public class BasicAuth extends Auth { private String username; private String password; + + /** + * Get basic authentication header value. + * + * @return basic authentication header value (encoded credentials) + */ + public String getAuthHeaderValue() { + final String credentialsToEncode = this.getUsername() + ":" + this.getPassword(); + return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes()); + } } diff --git a/src/main/java/io/github/ollama4j/models/request/BearerAuth.java b/src/main/java/io/github/ollama4j/models/request/BearerAuth.java new file mode 100644 index 0000000..8236042 --- /dev/null +++ b/src/main/java/io/github/ollama4j/models/request/BearerAuth.java @@ -0,0 +1,19 @@ +package io.github.ollama4j.models.request; + +import lombok.AllArgsConstructor; +import lombok.Data; + +@Data +@AllArgsConstructor +public class BearerAuth extends Auth { + private String bearerToken; + + /** + * Get authentication header value. + * + * @return authentication header value with bearer token + */ + public String getAuthHeaderValue() { + return "Bearer "+ bearerToken; + } +} diff --git a/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java b/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java index a1a6216..09a3870 100644 --- a/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java +++ b/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java @@ -30,8 +30,8 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller { private OllamaTokenHandler tokenHandler; - public OllamaChatEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) { - super(host, basicAuth, requestTimeoutSeconds, verbose); + public OllamaChatEndpointCaller(String host, Auth auth, long requestTimeoutSeconds, boolean verbose) { + super(host, auth, requestTimeoutSeconds, verbose); } @Override diff --git a/src/main/java/io/github/ollama4j/models/request/OllamaEndpointCaller.java b/src/main/java/io/github/ollama4j/models/request/OllamaEndpointCaller.java index e9d0e0d..1f42ef8 100644 --- a/src/main/java/io/github/ollama4j/models/request/OllamaEndpointCaller.java +++ b/src/main/java/io/github/ollama4j/models/request/OllamaEndpointCaller.java @@ -1,26 +1,14 @@ package io.github.ollama4j.models.request; -import io.github.ollama4j.OllamaAPI; -import io.github.ollama4j.exceptions.OllamaBaseException; -import io.github.ollama4j.models.response.OllamaErrorResponse; -import io.github.ollama4j.models.response.OllamaResult; -import io.github.ollama4j.utils.OllamaRequestBody; -import io.github.ollama4j.utils.Utils; -import lombok.Getter; +import java.net.URI; +import java.net.http.HttpRequest; +import java.time.Duration; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.BufferedReader; -import java.io.IOException; -import java.io.InputStream; -import java.io.InputStreamReader; -import java.net.URI; -import java.net.http.HttpClient; -import java.net.http.HttpRequest; -import java.net.http.HttpResponse; -import java.nio.charset.StandardCharsets; -import java.time.Duration; -import java.util.Base64; +import io.github.ollama4j.OllamaAPI; +import lombok.Getter; /** * Abstract helperclass to call the ollama api server. @@ -31,13 +19,13 @@ public abstract class OllamaEndpointCaller { private static final Logger LOG = LoggerFactory.getLogger(OllamaAPI.class); private final String host; - private final BasicAuth basicAuth; + private final Auth auth; private final long requestTimeoutSeconds; private final boolean verbose; - public OllamaEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) { + public OllamaEndpointCaller(String host, Auth auth, long requestTimeoutSeconds, boolean verbose) { this.host = host; - this.basicAuth = basicAuth; + this.auth = auth; this.requestTimeoutSeconds = requestTimeoutSeconds; this.verbose = verbose; } @@ -58,29 +46,19 @@ public abstract class OllamaEndpointCaller { HttpRequest.newBuilder(uri) .header("Content-Type", "application/json") .timeout(Duration.ofSeconds(this.requestTimeoutSeconds)); - if (isBasicAuthCredentialsSet()) { - requestBuilder.header("Authorization", getBasicAuthHeaderValue()); + if (isAuthCredentialsSet()) { + requestBuilder.header("Authorization", this.auth.getAuthHeaderValue()); } return requestBuilder; } /** - * Get basic authentication header value. + * Check if Auth credentials set. * - * @return basic authentication header value (encoded credentials) + * @return true when Auth credentials set */ - protected String getBasicAuthHeaderValue() { - String credentialsToEncode = this.basicAuth.getUsername() + ":" + this.basicAuth.getPassword(); - return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes()); - } - - /** - * Check if Basic Auth credentials set. - * - * @return true when Basic Auth credentials set - */ - protected boolean isBasicAuthCredentialsSet() { - return this.basicAuth != null; + protected boolean isAuthCredentialsSet() { + return this.auth != null; } } diff --git a/src/main/java/io/github/ollama4j/models/request/OllamaGenerateEndpointCaller.java b/src/main/java/io/github/ollama4j/models/request/OllamaGenerateEndpointCaller.java index 00b2b12..461ec75 100644 --- a/src/main/java/io/github/ollama4j/models/request/OllamaGenerateEndpointCaller.java +++ b/src/main/java/io/github/ollama4j/models/request/OllamaGenerateEndpointCaller.java @@ -28,7 +28,7 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller { private OllamaGenerateStreamObserver streamObserver; - public OllamaGenerateEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) { + public OllamaGenerateEndpointCaller(String host, Auth basicAuth, long requestTimeoutSeconds, boolean verbose) { super(host, basicAuth, requestTimeoutSeconds, verbose); }