Updated all APIs to use getRequestBuilderDefault() method

This commit is contained in:
Amith Koujalgi 2023-12-29 10:25:18 +05:30
parent 0afba7e3e3
commit 657593be09
6 changed files with 152 additions and 100 deletions

View File

@ -0,0 +1,24 @@
---
sidebar_position: 2
---
# Set Basic Authentication
This API lets you set the basic authentication for the Ollama client. This would help in scenarios where
Ollama server would be setup behind a gateway/reverse proxy with basic auth.
After configuring basic authentication, all subsequent requests will include the Basic Auth header.
```java
public class Main {
public static void main(String[] args) {
String host = "http://localhost:11434/";
OllamaAPI ollamaAPI = new OllamaAPI(host);
ollamaAPI.setBasicAuth("username", "password");
}
}
```

View File

@ -37,8 +37,7 @@ public class OllamaAPI {
private final String host;
private long requestTimeoutSeconds = 3;
private boolean verbose = true;
private String username;
private String password;
private BasicAuth basicAuth;
/**
* Instantiates the Ollama API.
@ -53,6 +52,11 @@ public class OllamaAPI {
}
}
/**
* Set request timeout in seconds. Default is 3 seconds.
*
* @param requestTimeoutSeconds the request timeout in seconds
*/
public void setRequestTimeoutSeconds(long requestTimeoutSeconds) {
this.requestTimeoutSeconds = requestTimeoutSeconds;
}
@ -67,11 +71,13 @@ public class OllamaAPI {
}
/**
* Set basic authentication for accessing Ollama server that's behind a reverse-proxy/gateway.
*
* @param username the username
* @param password the password
*/
public void setBasicAuth(String username, String password) {
this.username = username;
this.password = password;
this.basicAuth = new BasicAuth(username, password);
}
/**
@ -85,11 +91,9 @@ public class OllamaAPI {
HttpRequest httpRequest = null;
try {
httpRequest =
HttpRequest.newBuilder()
.uri(new URI(url))
getRequestBuilderDefault(new URI(url))
.header("Accept", "application/json")
.header("Content-type", "application/json")
.timeout(Duration.ofSeconds(requestTimeoutSeconds))
.GET()
.build();
} catch (URISyntaxException e) {
@ -117,11 +121,9 @@ public class OllamaAPI {
String url = this.host + "/api/tags";
HttpClient httpClient = HttpClient.newHttpClient();
HttpRequest httpRequest =
HttpRequest.newBuilder()
.uri(new URI(url))
getRequestBuilderDefault(new URI(url))
.header("Accept", "application/json")
.header("Content-type", "application/json")
.timeout(Duration.ofSeconds(requestTimeoutSeconds))
.GET()
.build();
HttpResponse<String> response =
@ -148,12 +150,10 @@ public class OllamaAPI {
String url = this.host + "/api/pull";
String jsonData = new ModelRequest(modelName).toString();
HttpRequest request =
HttpRequest.newBuilder()
.uri(new URI(url))
getRequestBuilderDefault(new URI(url))
.POST(HttpRequest.BodyPublishers.ofString(jsonData))
.header("Accept", "application/json")
.header("Content-type", "application/json")
.timeout(Duration.ofSeconds(requestTimeoutSeconds))
.build();
HttpClient client = HttpClient.newHttpClient();
HttpResponse<InputStream> response =
@ -184,15 +184,13 @@ public class OllamaAPI {
* @return the model details
*/
public ModelDetail getModelDetails(String modelName)
throws IOException, OllamaBaseException, InterruptedException {
throws IOException, OllamaBaseException, InterruptedException, URISyntaxException {
String url = this.host + "/api/show";
String jsonData = new ModelRequest(modelName).toString();
HttpRequest request =
HttpRequest.newBuilder()
.uri(URI.create(url))
getRequestBuilderDefault(new URI(url))
.header("Accept", "application/json")
.header("Content-type", "application/json")
.timeout(Duration.ofSeconds(requestTimeoutSeconds))
.POST(HttpRequest.BodyPublishers.ofString(jsonData))
.build();
HttpClient client = HttpClient.newHttpClient();
@ -214,15 +212,13 @@ public class OllamaAPI {
* @param modelFilePath the path to model file that exists on the Ollama server.
*/
public void createModelWithFilePath(String modelName, String modelFilePath)
throws IOException, InterruptedException, OllamaBaseException {
throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
String url = this.host + "/api/create";
String jsonData = new CustomModelFilePathRequest(modelName, modelFilePath).toString();
HttpRequest request =
HttpRequest.newBuilder()
.uri(URI.create(url))
getRequestBuilderDefault(new URI(url))
.header("Accept", "application/json")
.header("Content-Type", "application/json")
.timeout(Duration.ofSeconds(requestTimeoutSeconds))
.POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8))
.build();
HttpClient client = HttpClient.newHttpClient();
@ -250,15 +246,13 @@ public class OllamaAPI {
* @param modelFileContents the path to model file that exists on the Ollama server.
*/
public void createModelWithModelFileContents(String modelName, String modelFileContents)
throws IOException, InterruptedException, OllamaBaseException {
throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
String url = this.host + "/api/create";
String jsonData = new CustomModelFileContentsRequest(modelName, modelFileContents).toString();
HttpRequest request =
HttpRequest.newBuilder()
.uri(URI.create(url))
getRequestBuilderDefault(new URI(url))
.header("Accept", "application/json")
.header("Content-Type", "application/json")
.timeout(Duration.ofSeconds(requestTimeoutSeconds))
.POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8))
.build();
HttpClient client = HttpClient.newHttpClient();
@ -280,20 +274,17 @@ public class OllamaAPI {
* Delete a model from Ollama server.
*
* @param modelName the name of the model to be deleted.
* @param ignoreIfNotPresent - ignore errors if the specified model is not present on Ollama
* server.
* @param ignoreIfNotPresent ignore errors if the specified model is not present on Ollama server.
*/
public void deleteModel(String modelName, boolean ignoreIfNotPresent)
throws IOException, InterruptedException, OllamaBaseException {
throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
String url = this.host + "/api/delete";
String jsonData = new ModelRequest(modelName).toString();
HttpRequest request =
HttpRequest.newBuilder()
.uri(URI.create(url))
getRequestBuilderDefault(new URI(url))
.method("DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8))
.header("Accept", "application/json")
.header("Content-type", "application/json")
.timeout(Duration.ofSeconds(requestTimeoutSeconds))
.build();
HttpClient client = HttpClient.newHttpClient();
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
@ -319,7 +310,8 @@ public class OllamaAPI {
URI uri = URI.create(this.host + "/api/embeddings");
String jsonData = new ModelEmbeddingsRequest(model, prompt).toString();
HttpClient httpClient = HttpClient.newHttpClient();
HttpRequest.Builder requestBuilder = getRequestBuilderDefault(uri)
HttpRequest.Builder requestBuilder =
getRequestBuilderDefault(uri)
.header("Accept", "application/json")
.POST(HttpRequest.BodyPublishers.ofString(jsonData));
HttpRequest request = requestBuilder.build();
@ -340,7 +332,7 @@ public class OllamaAPI {
*
* @param model the ollama model to ask the question to
* @param promptText the prompt/question text
* @return OllamaResult - that includes response text and time taken for response
* @return OllamaResult that includes response text and time taken for response
*/
public OllamaResult ask(String model, String promptText)
throws OllamaBaseException, IOException, InterruptedException {
@ -359,10 +351,11 @@ public class OllamaAPI {
*/
public OllamaAsyncResultCallback askAsync(String model, String promptText) {
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, promptText);
HttpClient httpClient = HttpClient.newHttpClient();
URI uri = URI.create(this.host + "/api/generate");
OllamaAsyncResultCallback ollamaAsyncResultCallback =
new OllamaAsyncResultCallback(httpClient, uri, ollamaRequestModel, requestTimeoutSeconds);
new OllamaAsyncResultCallback(
getRequestBuilderDefault(uri), ollamaRequestModel, requestTimeoutSeconds);
ollamaAsyncResultCallback.start();
return ollamaAsyncResultCallback;
}
@ -374,7 +367,7 @@ public class OllamaAPI {
* @param model the ollama model to ask the question to
* @param promptText the prompt/question text
* @param imageFiles the list of image files to use for the question
* @return OllamaResult - that includes response text and time taken for response
* @return OllamaResult that includes response text and time taken for response
*/
public OllamaResult askWithImageFiles(String model, String promptText, List<File> imageFiles)
throws OllamaBaseException, IOException, InterruptedException {
@ -393,7 +386,7 @@ public class OllamaAPI {
* @param model the ollama model to ask the question to
* @param promptText the prompt/question text
* @param imageURLs the list of image URLs to use for the question
* @return OllamaResult - that includes response text and time taken for response
* @return OllamaResult that includes response text and time taken for response
*/
public OllamaResult askWithImageURLs(String model, String promptText, List<String> imageURLs)
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
@ -432,7 +425,8 @@ public class OllamaAPI {
long startTime = System.currentTimeMillis();
HttpClient httpClient = HttpClient.newHttpClient();
URI uri = URI.create(this.host + "/api/generate");
HttpRequest.Builder requestBuilder = getRequestBuilderDefault(uri)
HttpRequest.Builder requestBuilder =
getRequestBuilderDefault(uri)
.POST(
HttpRequest.BodyPublishers.ofString(
Utils.getObjectMapper().writeValueAsString(ollamaRequestModel)));
@ -455,9 +449,10 @@ public class OllamaAPI {
} else if (statusCode == 401) {
logger.warn("Status code: 401 (Unauthorized)");
OllamaErrorResponseModel ollamaResponseModel =
Utils.getObjectMapper().readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponseModel.class);
Utils.getObjectMapper()
.readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponseModel.class);
responseBuffer.append(ollamaResponseModel.getError());
}else {
} else {
OllamaResponseModel ollamaResponseModel =
Utils.getObjectMapper().readValue(line, OllamaResponseModel.class);
if (!ollamaResponseModel.isDone()) {
@ -467,7 +462,7 @@ public class OllamaAPI {
}
}
if (statusCode != 200) {
logger.error("Status code " + statusCode + " instead 200");
logger.error("Status code " + statusCode);
throw new OllamaBaseException(responseBuffer.toString());
} else {
long endTime = System.currentTimeMillis();
@ -476,35 +471,38 @@ public class OllamaAPI {
}
/**
* Get default request builder.
*
* @param uri URI to get a HttpRequest.Builder
* @return HttpRequest.Builder
*/
private HttpRequest.Builder getRequestBuilderDefault(URI uri) {
HttpRequest.Builder requestBuilder =
HttpRequest.newBuilder(uri)
.header("Content-Type", "application/json")
.timeout(Duration.ofSeconds(requestTimeoutSeconds));
if (basicAuthCredentialsSet()) {
if (isBasicAuthCredentialsSet()) {
requestBuilder.header("Authorization", getBasicAuthHeaderValue());
}
return requestBuilder;
}
/**
* Get basic authentication header value.
*
* @return basic authentication header value (encoded credentials)
*/
private String getBasicAuthHeaderValue() {
String credentialsToEncode = username + ":" + password;
String credentialsToEncode = basicAuth.getUsername() + ":" + basicAuth.getPassword();
return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes());
}
/**
* Check if Basic Auth credentials set.
*
* @return true when Basic Auth credentials set
*/
private boolean basicAuthCredentialsSet() {
if (username != null && password != null) {
return true;
} else {
return false;
}
private boolean isBasicAuthCredentialsSet() {
return basicAuth != null;
}
}

View File

@ -0,0 +1,13 @@
package io.github.amithkoujalgi.ollama4j.core.models;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@NoArgsConstructor
@AllArgsConstructor
public class BasicAuth {
private String username;
private String password;
}

View File

@ -6,7 +6,6 @@ import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
@ -14,30 +13,44 @@ import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.LinkedList;
import java.util.Queue;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.Getter;
@Data
@EqualsAndHashCode(callSuper = true)
@SuppressWarnings("unused")
public class OllamaAsyncResultCallback extends Thread {
private final HttpClient client;
private final URI uri;
private final HttpRequest.Builder requestBuilder;
private final OllamaRequestModel ollamaRequestModel;
private final Queue<String> queue = new LinkedList<>();
private String result;
private boolean isDone;
private boolean succeeded;
/**
* -- GETTER -- Returns the status of the request. Indicates if the request was successful or a
* failure. If the request was a failure, the `getResponse()` method will return the error
* message.
*/
@Getter private boolean succeeded;
private long requestTimeoutSeconds;
private int httpStatusCode;
private long responseTime = 0;
/**
* -- GETTER -- Returns the HTTP response status code for the request that was made to Ollama
* server.
*/
@Getter private int httpStatusCode;
/** -- GETTER -- Returns the response time in milliseconds. */
@Getter private long responseTime = 0;
public OllamaAsyncResultCallback(
HttpClient client,
URI uri,
HttpRequest.Builder requestBuilder,
OllamaRequestModel ollamaRequestModel,
long requestTimeoutSeconds) {
this.client = client;
this.requestBuilder = requestBuilder;
this.ollamaRequestModel = ollamaRequestModel;
this.uri = uri;
this.isDone = false;
this.result = "";
this.queue.add("");
@ -46,10 +59,11 @@ public class OllamaAsyncResultCallback extends Thread {
@Override
public void run() {
HttpClient httpClient = HttpClient.newHttpClient();
try {
long startTime = System.currentTimeMillis();
HttpRequest request =
HttpRequest.newBuilder(uri)
requestBuilder
.POST(
HttpRequest.BodyPublishers.ofString(
Utils.getObjectMapper().writeValueAsString(ollamaRequestModel)))
@ -57,7 +71,7 @@ public class OllamaAsyncResultCallback extends Thread {
.timeout(Duration.ofSeconds(requestTimeoutSeconds))
.build();
HttpResponse<InputStream> response =
client.send(request, HttpResponse.BodyHandlers.ofInputStream());
httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
int statusCode = response.statusCode();
this.httpStatusCode = statusCode;
@ -108,25 +122,6 @@ public class OllamaAsyncResultCallback extends Thread {
return isDone;
}
/**
* Returns the HTTP response status code for the request that was made to Ollama server.
*
* @return int - the status code for the request
*/
public int getHttpStatusCode() {
return httpStatusCode;
}
/**
* Returns the status of the request. Indicates if the request was successful or a failure. If the
* request was a failure, the `getResponse()` method will return the error message.
*
* @return boolean - status
*/
public boolean isSucceeded() {
return succeeded;
}
/**
* Returns the final response when the execution completes. Does not return intermediate results.
*
@ -140,15 +135,6 @@ public class OllamaAsyncResultCallback extends Thread {
return queue;
}
/**
* Returns the response time in milliseconds.
*
* @return long - response time in milliseconds.
*/
public long getResponseTime() {
return responseTime;
}
public void setRequestTimeoutSeconds(long requestTimeoutSeconds) {
this.requestTimeoutSeconds = requestTimeoutSeconds;
}

View File

@ -11,12 +11,13 @@ import io.github.amithkoujalgi.ollama4j.core.types.OllamaModelType;
import java.io.IOException;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.Collections;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
class TestMockedAPIs {
@Test
void testMockPullModel() {
void testPullModel() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
try {
@ -49,7 +50,7 @@ class TestMockedAPIs {
doNothing().when(ollamaAPI).createModelWithModelFileContents(model, modelFilePath);
ollamaAPI.createModelWithModelFileContents(model, modelFilePath);
verify(ollamaAPI, times(1)).createModelWithModelFileContents(model, modelFilePath);
} catch (IOException | OllamaBaseException | InterruptedException e) {
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e);
}
}
@ -62,7 +63,7 @@ class TestMockedAPIs {
doNothing().when(ollamaAPI).deleteModel(model, true);
ollamaAPI.deleteModel(model, true);
verify(ollamaAPI, times(1)).deleteModel(model, true);
} catch (IOException | OllamaBaseException | InterruptedException e) {
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e);
}
}
@ -75,7 +76,7 @@ class TestMockedAPIs {
when(ollamaAPI.getModelDetails(model)).thenReturn(new ModelDetail());
ollamaAPI.getModelDetails(model);
verify(ollamaAPI, times(1)).getModelDetails(model);
} catch (IOException | OllamaBaseException | InterruptedException e) {
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e);
}
}
@ -108,13 +109,43 @@ class TestMockedAPIs {
}
}
@Test
void testAskWithImageFiles() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text";
try {
when(ollamaAPI.askWithImageFiles(model, prompt, Collections.emptyList()))
.thenReturn(new OllamaResult("", 0, 200));
ollamaAPI.askWithImageFiles(model, prompt, Collections.emptyList());
verify(ollamaAPI, times(1)).askWithImageFiles(model, prompt, Collections.emptyList());
} catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e);
}
}
@Test
void testAskWithImageURLs() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text";
try {
when(ollamaAPI.askWithImageURLs(model, prompt, Collections.emptyList()))
.thenReturn(new OllamaResult("", 0, 200));
ollamaAPI.askWithImageURLs(model, prompt, Collections.emptyList());
verify(ollamaAPI, times(1)).askWithImageURLs(model, prompt, Collections.emptyList());
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e);
}
}
@Test
void testAskAsync() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text";
when(ollamaAPI.askAsync(model, prompt))
.thenReturn(new OllamaAsyncResultCallback(null, null, null, 3));
.thenReturn(new OllamaAsyncResultCallback(null, null, 3));
ollamaAPI.askAsync(model, prompt);
verify(ollamaAPI, times(1)).askAsync(model, prompt);
}