Updated default request timeout to 10 seconds

This commit is contained in:
Amith Koujalgi 2024-05-14 10:27:56 +05:30
parent e4e717b747
commit 04124cf978
3 changed files with 628 additions and 638 deletions

View File

@ -112,7 +112,7 @@ You will get a response similar to:
## Use a simple Console Output Stream Handler ## Use a simple Console Output Stream Handler
``` ```java
import io.github.amithkoujalgi.ollama4j.core.impl.ConsoleOutputStreamHandler; import io.github.amithkoujalgi.ollama4j.core.impl.ConsoleOutputStreamHandler;
public class Main { public class Main {

View File

@ -1,5 +1,15 @@
package io.github.amithkoujalgi.ollama4j.core.models.request; package io.github.amithkoujalgi.ollama4j.core.models.request;
import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaErrorResponseModel;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.BufferedReader; import java.io.BufferedReader;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
@ -12,17 +22,6 @@ import java.nio.charset.StandardCharsets;
import java.time.Duration; import java.time.Duration;
import java.util.Base64; import java.util.Base64;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaErrorResponseModel;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
/** /**
* Abstract helperclass to call the ollama api server. * Abstract helperclass to call the ollama api server.
*/ */
@ -52,104 +51,102 @@ public abstract class OllamaEndpointCaller {
* *
* @param body POST body payload * @param body POST body payload
* @return result answer given by the assistant * @return result answer given by the assistant
* @throws OllamaBaseException any response code than 200 has been returned * @throws OllamaBaseException any response code than 200 has been returned
* @throws IOException in case the responseStream can not be read * @throws IOException in case the responseStream can not be read
* @throws InterruptedException in case the server is not reachable or network issues happen * @throws InterruptedException in case the server is not reachable or network issues happen
*/ */
public OllamaResult callSync(OllamaRequestBody body) throws OllamaBaseException, IOException, InterruptedException{ public OllamaResult callSync(OllamaRequestBody body) throws OllamaBaseException, IOException, InterruptedException {
// Create Request // Create Request
long startTime = System.currentTimeMillis(); long startTime = System.currentTimeMillis();
HttpClient httpClient = HttpClient.newHttpClient(); HttpClient httpClient = HttpClient.newHttpClient();
URI uri = URI.create(this.host + getEndpointSuffix()); URI uri = URI.create(this.host + getEndpointSuffix());
HttpRequest.Builder requestBuilder = HttpRequest.Builder requestBuilder =
getRequestBuilderDefault(uri) getRequestBuilderDefault(uri)
.POST( .POST(
body.getBodyPublisher()); body.getBodyPublisher());
HttpRequest request = requestBuilder.build(); HttpRequest request = requestBuilder.build();
if (this.verbose) LOG.info("Asking model: " + body.toString()); if (this.verbose) LOG.info("Asking model: " + body.toString());
HttpResponse<InputStream> response = HttpResponse<InputStream> response =
httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream()); httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
int statusCode = response.statusCode(); int statusCode = response.statusCode();
InputStream responseBodyStream = response.body(); InputStream responseBodyStream = response.body();
StringBuilder responseBuffer = new StringBuilder(); StringBuilder responseBuffer = new StringBuilder();
try (BufferedReader reader = try (BufferedReader reader =
new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
String line; String line;
while ((line = reader.readLine()) != null) { while ((line = reader.readLine()) != null) {
if (statusCode == 404) { if (statusCode == 404) {
LOG.warn("Status code: 404 (Not Found)"); LOG.warn("Status code: 404 (Not Found)");
OllamaErrorResponseModel ollamaResponseModel = OllamaErrorResponseModel ollamaResponseModel =
Utils.getObjectMapper().readValue(line, OllamaErrorResponseModel.class); Utils.getObjectMapper().readValue(line, OllamaErrorResponseModel.class);
responseBuffer.append(ollamaResponseModel.getError()); responseBuffer.append(ollamaResponseModel.getError());
} else if (statusCode == 401) { } else if (statusCode == 401) {
LOG.warn("Status code: 401 (Unauthorized)"); LOG.warn("Status code: 401 (Unauthorized)");
OllamaErrorResponseModel ollamaResponseModel = OllamaErrorResponseModel ollamaResponseModel =
Utils.getObjectMapper() Utils.getObjectMapper()
.readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponseModel.class); .readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponseModel.class);
responseBuffer.append(ollamaResponseModel.getError()); responseBuffer.append(ollamaResponseModel.getError());
} else if (statusCode == 400) { } else if (statusCode == 400) {
LOG.warn("Status code: 400 (Bad Request)"); LOG.warn("Status code: 400 (Bad Request)");
OllamaErrorResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaErrorResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line,
OllamaErrorResponseModel.class); OllamaErrorResponseModel.class);
responseBuffer.append(ollamaResponseModel.getError()); responseBuffer.append(ollamaResponseModel.getError());
} else { } else {
boolean finished = parseResponseAndAddToBuffer(line,responseBuffer); boolean finished = parseResponseAndAddToBuffer(line, responseBuffer);
if (finished) { if (finished) {
break; break;
}
}
} }
} }
}
}
if (statusCode != 200) { if (statusCode != 200) {
LOG.error("Status code " + statusCode); LOG.error("Status code " + statusCode);
throw new OllamaBaseException(responseBuffer.toString()); throw new OllamaBaseException(responseBuffer.toString());
} else { } else {
long endTime = System.currentTimeMillis(); long endTime = System.currentTimeMillis();
OllamaResult ollamaResult = OllamaResult ollamaResult =
new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode); new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode);
if (verbose) LOG.info("Model response: " + ollamaResult); if (verbose) LOG.info("Model response: " + ollamaResult);
return ollamaResult; return ollamaResult;
}
} }
}
/** /**
* Get default request builder. * Get default request builder.
* *
* @param uri URI to get a HttpRequest.Builder * @param uri URI to get a HttpRequest.Builder
* @return HttpRequest.Builder * @return HttpRequest.Builder
*/ */
private HttpRequest.Builder getRequestBuilderDefault(URI uri) { private HttpRequest.Builder getRequestBuilderDefault(URI uri) {
HttpRequest.Builder requestBuilder = HttpRequest.Builder requestBuilder =
HttpRequest.newBuilder(uri) HttpRequest.newBuilder(uri)
.header("Content-Type", "application/json") .header("Content-Type", "application/json")
.timeout(Duration.ofSeconds(this.requestTimeoutSeconds)); .timeout(Duration.ofSeconds(this.requestTimeoutSeconds));
if (isBasicAuthCredentialsSet()) { if (isBasicAuthCredentialsSet()) {
requestBuilder.header("Authorization", getBasicAuthHeaderValue()); requestBuilder.header("Authorization", getBasicAuthHeaderValue());
}
return requestBuilder;
} }
return requestBuilder;
}
/** /**
* Get basic authentication header value. * Get basic authentication header value.
* *
* @return basic authentication header value (encoded credentials) * @return basic authentication header value (encoded credentials)
*/ */
private String getBasicAuthHeaderValue() { private String getBasicAuthHeaderValue() {
String credentialsToEncode = this.basicAuth.getUsername() + ":" + this.basicAuth.getPassword(); String credentialsToEncode = this.basicAuth.getUsername() + ":" + this.basicAuth.getPassword();
return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes()); return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes());
} }
/** /**
* Check if Basic Auth credentials set. * Check if Basic Auth credentials set.
* *
* @return true when Basic Auth credentials set * @return true when Basic Auth credentials set
*/ */
private boolean isBasicAuthCredentialsSet() { private boolean isBasicAuthCredentialsSet() {
return this.basicAuth != null; return this.basicAuth != null;
} }
} }