Merge pull request #98 from csware/bearertoken

Support bearer token
This commit is contained in:
Amith Koujalgi 2025-03-18 20:30:08 +05:30 committed by GitHub
commit ba0444194f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 75 additions and 58 deletions

View File

@ -72,7 +72,7 @@ public class OllamaAPI {
@Setter @Setter
private int maxChatToolCallRetries = 3; private int maxChatToolCallRetries = 3;
private BasicAuth basicAuth; private Auth auth;
private final ToolRegistry toolRegistry = new ToolRegistry(); private final ToolRegistry toolRegistry = new ToolRegistry();
@ -106,7 +106,16 @@ public class OllamaAPI {
* @param password the password * @param password the password
*/ */
public void setBasicAuth(String username, String password) { public void setBasicAuth(String username, String password) {
this.basicAuth = new BasicAuth(username, password); this.auth = new BasicAuth(username, password);
}
/**
* Set Bearer authentication for accessing Ollama server that's behind a reverse-proxy/gateway.
*
* @param bearerToken the Bearer authentication token to provide
*/
public void setBearerAuth(String bearerToken) {
this.auth = new BearerAuth(bearerToken);
} }
/** /**
@ -888,7 +897,7 @@ public class OllamaAPI {
* @throws InterruptedException if the operation is interrupted * @throws InterruptedException if the operation is interrupted
*/ */
public OllamaChatResult chatStreaming(OllamaChatRequest request, OllamaTokenHandler tokenHandler) throws OllamaBaseException, IOException, InterruptedException { public OllamaChatResult chatStreaming(OllamaChatRequest request, OllamaTokenHandler tokenHandler) throws OllamaBaseException, IOException, InterruptedException {
OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose); OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, auth, requestTimeoutSeconds, verbose);
OllamaChatResult result; OllamaChatResult result;
// add all registered tools to Request // add all registered tools to Request
@ -1094,7 +1103,7 @@ public class OllamaAPI {
* @throws InterruptedException if the thread is interrupted during the request. * @throws InterruptedException if the thread is interrupted during the request.
*/ */
private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequest ollamaRequestModel, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequest ollamaRequestModel, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose); OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, auth, requestTimeoutSeconds, verbose);
OllamaResult result; OllamaResult result;
if (streamHandler != null) { if (streamHandler != null) {
ollamaRequestModel.setStream(true); ollamaRequestModel.setStream(true);
@ -1115,28 +1124,18 @@ public class OllamaAPI {
private HttpRequest.Builder getRequestBuilderDefault(URI uri) { private HttpRequest.Builder getRequestBuilderDefault(URI uri) {
HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(uri).header("Content-Type", "application/json").timeout(Duration.ofSeconds(requestTimeoutSeconds)); HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(uri).header("Content-Type", "application/json").timeout(Duration.ofSeconds(requestTimeoutSeconds));
if (isBasicAuthCredentialsSet()) { if (isBasicAuthCredentialsSet()) {
requestBuilder.header("Authorization", getBasicAuthHeaderValue()); requestBuilder.header("Authorization", auth.getAuthHeaderValue());
} }
return requestBuilder; return requestBuilder;
} }
/**
* Get basic authentication header value.
*
* @return basic authentication header value (encoded credentials)
*/
private String getBasicAuthHeaderValue() {
String credentialsToEncode = basicAuth.getUsername() + ":" + basicAuth.getPassword();
return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes());
}
/** /**
* Check if Basic Auth credentials set. * Check if Basic Auth credentials set.
* *
* @return true when Basic Auth credentials set * @return true when Basic Auth credentials set
*/ */
private boolean isBasicAuthCredentialsSet() { private boolean isBasicAuthCredentialsSet() {
return basicAuth != null; return auth != null;
} }
private Object invokeTool(ToolFunctionCallSpec toolFunctionCallSpec) throws ToolInvocationException { private Object invokeTool(ToolFunctionCallSpec toolFunctionCallSpec) throws ToolInvocationException {

View File

@ -0,0 +1,10 @@
package io.github.ollama4j.models.request;
public abstract class Auth {
/**
* Get authentication header value.
*
* @return authentication header value
*/
public abstract String getAuthHeaderValue();
}

View File

@ -1,13 +1,24 @@
package io.github.ollama4j.models.request; package io.github.ollama4j.models.request;
import java.util.Base64;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Data; import lombok.Data;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
@Data @Data
@NoArgsConstructor
@AllArgsConstructor @AllArgsConstructor
public class BasicAuth { public class BasicAuth extends Auth {
private String username; private String username;
private String password; private String password;
/**
* Get basic authentication header value.
*
* @return basic authentication header value (encoded credentials)
*/
public String getAuthHeaderValue() {
final String credentialsToEncode = this.getUsername() + ":" + this.getPassword();
return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes());
}
} }

View File

@ -0,0 +1,19 @@
package io.github.ollama4j.models.request;
import lombok.AllArgsConstructor;
import lombok.Data;
@Data
@AllArgsConstructor
public class BearerAuth extends Auth {
private String bearerToken;
/**
* Get authentication header value.
*
* @return authentication header value with bearer token
*/
public String getAuthHeaderValue() {
return "Bearer "+ bearerToken;
}
}

View File

@ -30,8 +30,8 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
private OllamaTokenHandler tokenHandler; private OllamaTokenHandler tokenHandler;
public OllamaChatEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) { public OllamaChatEndpointCaller(String host, Auth auth, long requestTimeoutSeconds, boolean verbose) {
super(host, basicAuth, requestTimeoutSeconds, verbose); super(host, auth, requestTimeoutSeconds, verbose);
} }
@Override @Override

View File

@ -1,26 +1,14 @@
package io.github.ollama4j.models.request; package io.github.ollama4j.models.request;
import io.github.ollama4j.OllamaAPI; import java.net.URI;
import io.github.ollama4j.exceptions.OllamaBaseException; import java.net.http.HttpRequest;
import io.github.ollama4j.models.response.OllamaErrorResponse; import java.time.Duration;
import io.github.ollama4j.models.response.OllamaResult;
import io.github.ollama4j.utils.OllamaRequestBody;
import io.github.ollama4j.utils.Utils;
import lombok.Getter;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import java.io.BufferedReader; import io.github.ollama4j.OllamaAPI;
import java.io.IOException; import lombok.Getter;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Base64;
/** /**
* Abstract helperclass to call the ollama api server. * Abstract helperclass to call the ollama api server.
@ -31,13 +19,13 @@ public abstract class OllamaEndpointCaller {
private static final Logger LOG = LoggerFactory.getLogger(OllamaAPI.class); private static final Logger LOG = LoggerFactory.getLogger(OllamaAPI.class);
private final String host; private final String host;
private final BasicAuth basicAuth; private final Auth auth;
private final long requestTimeoutSeconds; private final long requestTimeoutSeconds;
private final boolean verbose; private final boolean verbose;
public OllamaEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) { public OllamaEndpointCaller(String host, Auth auth, long requestTimeoutSeconds, boolean verbose) {
this.host = host; this.host = host;
this.basicAuth = basicAuth; this.auth = auth;
this.requestTimeoutSeconds = requestTimeoutSeconds; this.requestTimeoutSeconds = requestTimeoutSeconds;
this.verbose = verbose; this.verbose = verbose;
} }
@ -58,29 +46,19 @@ public abstract class OllamaEndpointCaller {
HttpRequest.newBuilder(uri) HttpRequest.newBuilder(uri)
.header("Content-Type", "application/json") .header("Content-Type", "application/json")
.timeout(Duration.ofSeconds(this.requestTimeoutSeconds)); .timeout(Duration.ofSeconds(this.requestTimeoutSeconds));
if (isBasicAuthCredentialsSet()) { if (isAuthCredentialsSet()) {
requestBuilder.header("Authorization", getBasicAuthHeaderValue()); requestBuilder.header("Authorization", this.auth.getAuthHeaderValue());
} }
return requestBuilder; return requestBuilder;
} }
/** /**
* Get basic authentication header value. * Check if Auth credentials set.
* *
* @return basic authentication header value (encoded credentials) * @return true when Auth credentials set
*/ */
protected String getBasicAuthHeaderValue() { protected boolean isAuthCredentialsSet() {
String credentialsToEncode = this.basicAuth.getUsername() + ":" + this.basicAuth.getPassword(); return this.auth != null;
return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes());
}
/**
* Check if Basic Auth credentials set.
*
* @return true when Basic Auth credentials set
*/
protected boolean isBasicAuthCredentialsSet() {
return this.basicAuth != null;
} }
} }

View File

@ -28,7 +28,7 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
private OllamaGenerateStreamObserver streamObserver; private OllamaGenerateStreamObserver streamObserver;
public OllamaGenerateEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) { public OllamaGenerateEndpointCaller(String host, Auth basicAuth, long requestTimeoutSeconds, boolean verbose) {
super(host, basicAuth, requestTimeoutSeconds, verbose); super(host, basicAuth, requestTimeoutSeconds, verbose);
} }