Extends OllamaAPI with chat methods and refactors the synchronous generate API methods

This commit is contained in:
Markus Klenke 2024-02-09 16:32:52 +00:00
parent bc20468f28
commit 00a3e51a93

View File

@ -2,10 +2,15 @@ package io.github.amithkoujalgi.ollama4j.core;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.*;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessage;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder;
import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFileContentsRequest;
import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFilePathRequest;
import io.github.amithkoujalgi.ollama4j.core.models.request.ModelEmbeddingsRequest;
import io.github.amithkoujalgi.ollama4j.core.models.request.ModelRequest;
import io.github.amithkoujalgi.ollama4j.core.models.request.OllamaChatRequestCaller;
import io.github.amithkoujalgi.ollama4j.core.models.request.OllamaGenerateRequestCaller;
import io.github.amithkoujalgi.ollama4j.core.utils.Options;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
import java.io.BufferedReader;
@ -343,7 +348,7 @@ public class OllamaAPI {
throws OllamaBaseException, IOException, InterruptedException {
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt);
ollamaRequestModel.setOptions(options.getOptionsMap());
return generateSync(ollamaRequestModel);
return generateSyncForOllamaRequestModel(ollamaRequestModel);
}
/**
@ -387,7 +392,7 @@ public class OllamaAPI {
}
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt, images);
ollamaRequestModel.setOptions(options.getOptionsMap());
return generateSync(ollamaRequestModel);
return generateSyncForOllamaRequestModel(ollamaRequestModel);
}
/**
@ -411,9 +416,23 @@ public class OllamaAPI {
}
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt, images);
ollamaRequestModel.setOptions(options.getOptionsMap());
return generateSync(ollamaRequestModel);
return generateSyncForOllamaRequestModel(ollamaRequestModel);
}
/**
 * Convenience overload of {@code chat}: assembles a chat request for {@code model} that carries
 * {@code messages} (via {@link OllamaChatRequestBuilder}) and hands it to
 * {@code chat(OllamaChatRequestModel)}.
 *
 * <p>No exception is caught here; every declared checked exception is passed on to the caller.
 *
 * @param model name of the model the request is built for
 * @param messages chat messages placed into the request
 * @return whatever the request-model overload of {@code chat} returns
 */
public OllamaResult chat(String model, List<OllamaChatMessage> messages)
    throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
  return chat(OllamaChatRequestBuilder.getInstance(model).withMessages(messages).build());
}
/**
 * Executes an already-built chat request.
 *
 * <p>A fresh {@link OllamaChatRequestCaller} is created from this instance's {@code host},
 * {@code basicAuth}, {@code requestTimeoutSeconds} and {@code verbose} settings, and its
 * {@code generateSync} result is returned as-is. No exception is caught here; every declared
 * checked exception is passed on to the caller.
 *
 * @param request the chat request to send
 * @return the result produced by the request caller
 */
public OllamaResult chat(OllamaChatRequestModel request)
    throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
  return new OllamaChatRequestCaller(host, basicAuth, requestTimeoutSeconds, verbose)
      .generateSync(request);
}
// technical private methods //
/**
 * Loads the complete contents of {@code file} into memory and returns them encoded with the
 * basic {@link Base64} encoder.
 *
 * @param file the file whose bytes are encoded
 * @return the Base64 text for the file's bytes
 * @throws IOException if reading the file fails
 */
private static String encodeFileToBase64(File file) throws IOException {
  Base64.Encoder base64Encoder = Base64.getEncoder();
  byte[] fileBytes = Files.readAllBytes(file.toPath());
  return base64Encoder.encodeToString(fileBytes);
}
@ -436,58 +455,10 @@ public class OllamaAPI {
}
}
// NOTE(review): this span looks like a rendered diff hunk with the +/- markers stripped, so it
// appears to contain BOTH sides of the change — confirm against the actual file:
//   * the earlier signature `generateSync(...)` with an inline HTTP implementation, and
//   * the renamed `generateSyncForOllamaRequestModel(...)` whose body delegates to
//     OllamaGenerateRequestCaller (the last two statements before the closing brace).
private OllamaResult generateSync(OllamaRequestModel ollamaRequestModel)
private OllamaResult generateSyncForOllamaRequestModel(OllamaRequestModel ollamaRequestModel)
throws OllamaBaseException, IOException, InterruptedException {
// --- inline implementation: POST the JSON-serialized request to <host>/api/generate ---
long startTime = System.currentTimeMillis();
HttpClient httpClient = HttpClient.newHttpClient();
URI uri = URI.create(this.host + "/api/generate");
HttpRequest.Builder requestBuilder =
getRequestBuilderDefault(uri)
.POST(
HttpRequest.BodyPublishers.ofString(
Utils.getObjectMapper().writeValueAsString(ollamaRequestModel)));
HttpRequest request = requestBuilder.build();
if (verbose) logger.info("Asking model: " + ollamaRequestModel);
HttpResponse<InputStream> response =
httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
int statusCode = response.statusCode();
InputStream responseBodyStream = response.body();
StringBuilder responseBuffer = new StringBuilder();
// The body is consumed line by line as UTF-8; each line is parsed as its own JSON document.
try (BufferedReader reader =
new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
String line;
while ((line = reader.readLine()) != null) {
if (statusCode == 404) {
// 404: the line is parsed as an error payload and its error text is collected.
logger.warn("Status code: 404 (Not Found)");
OllamaErrorResponseModel ollamaResponseModel =
Utils.getObjectMapper().readValue(line, OllamaErrorResponseModel.class);
responseBuffer.append(ollamaResponseModel.getError());
} else if (statusCode == 401) {
// 401: the line itself is not parsed; a fixed "Unauthorized" error payload is used instead.
logger.warn("Status code: 401 (Unauthorized)");
OllamaErrorResponseModel ollamaResponseModel =
Utils.getObjectMapper()
.readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponseModel.class);
responseBuffer.append(ollamaResponseModel.getError());
} else {
// Any other status: the line is parsed as a response chunk; its text is appended unless
// the chunk is flagged as done.
OllamaResponseModel ollamaResponseModel =
Utils.getObjectMapper().readValue(line, OllamaResponseModel.class);
if (!ollamaResponseModel.isDone()) {
responseBuffer.append(ollamaResponseModel.getResponse());
}
}
}
}
// Non-200 statuses surface the collected text as an OllamaBaseException; 200 returns the
// trimmed text together with the elapsed milliseconds and the status code.
if (statusCode != 200) {
logger.error("Status code " + statusCode);
throw new OllamaBaseException(responseBuffer.toString());
} else {
long endTime = System.currentTimeMillis();
OllamaResult ollamaResult =
new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode);
if (verbose) logger.info("Model response: " + ollamaResult);
return ollamaResult;
}
// --- delegating implementation: a new OllamaGenerateRequestCaller is built from host,
// basicAuth, requestTimeoutSeconds and verbose, and its generateSync result is returned ---
OllamaGenerateRequestCaller requestCaller = new OllamaGenerateRequestCaller(host, basicAuth, requestTimeoutSeconds, verbose);
return requestCaller.generateSync(ollamaRequestModel);
}
/**