diff --git a/docs/docs/apis-ask/chat.md b/docs/docs/apis-ask/chat.md new file mode 100644 index 0000000..00cda08 --- /dev/null +++ b/docs/docs/apis-ask/chat.md @@ -0,0 +1,98 @@ +--- +sidebar_position: 7 +--- + +# Chat + +This API lets you create a conversation with LLMs. Using this API enables you to ask questions to the model including +information using the history of already asked questions and the respective answers. + +## Create a new conversation and use chat history to augment follow up questions + +```java +public class Main { + + public static void main(String[] args) { + + String host = "http://localhost:11434/"; + + OllamaAPI ollamaAPI = new OllamaAPI(host); + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(OllamaModelType.LLAMA2); + + // create first user question + OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER,"What is the capital of France?") + .build(); + + // start conversation with model + OllamaChatResult chatResult = ollamaAPI.chat(requestModel); + + System.out.println("First answer: " + chatResult.getResponse()); + + // create next userQuestion + requestModel = builder.withMessages(chatResult.getChatHistory()).withMessage(OllamaChatMessageRole.USER,"And what is the second largest city?").build(); + + // "continue" conversation with model + chatResult = ollamaAPI.chat(requestModel); + + System.out.println("Second answer: " + chatResult.getResponse()); + + System.out.println("Chat History: " + chatResult.getChatHistory()); + } +} + +``` +You will get a response similar to: + +> First answer: Should be Paris! +> +> Second answer: Marseille. 
+> +> Chat History: + +```json +[ { + "role" : "user", + "content" : "What is the capital of France?", + "images" : [ ] + }, { + "role" : "assistant", + "content" : "Should be Paris!", + "images" : [ ] + }, { + "role" : "user", + "content" : "And what is the second largest city?", + "images" : [ ] + }, { + "role" : "assistant", + "content" : "Marseille.", + "images" : [ ] + } ] +``` + +## Create a new conversation with individual system prompt +```java +public class Main { + + public static void main(String[] args) { + + String host = "http://localhost:11434/"; + + OllamaAPI ollamaAPI = new OllamaAPI(host); + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(OllamaModelType.LLAMA2); + + // create request with system-prompt (overriding the model defaults) and user question + OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, "You are a silent bot that only says 'NI'. Do not say anything else under any circumstances!") + .withMessage(OllamaChatMessageRole.USER,"What is the capital of France? And what's France's connection with Mona Lisa?") + .build(); + + // start conversation with model + OllamaChatResult chatResult = ollamaAPI.chat(requestModel); + + System.out.println(chatResult.getResponse()); + } +} + +``` +You will get a response similar to: + +> NI. 
\ No newline at end of file diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java index 17fddc0..a0212d7 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java @@ -2,10 +2,16 @@ package io.github.amithkoujalgi.ollama4j.core; import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; import io.github.amithkoujalgi.ollama4j.core.models.*; +import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessage; +import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel; +import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResult; +import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder; import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFileContentsRequest; import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFilePathRequest; import io.github.amithkoujalgi.ollama4j.core.models.request.ModelEmbeddingsRequest; import io.github.amithkoujalgi.ollama4j.core.models.request.ModelRequest; +import io.github.amithkoujalgi.ollama4j.core.models.request.OllamaChatEndpointCaller; +import io.github.amithkoujalgi.ollama4j.core.models.request.OllamaGenerateEndpointCaller; import io.github.amithkoujalgi.ollama4j.core.utils.Options; import io.github.amithkoujalgi.ollama4j.core.utils.Utils; import java.io.BufferedReader; @@ -343,7 +349,7 @@ public class OllamaAPI { throws OllamaBaseException, IOException, InterruptedException { OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt); ollamaRequestModel.setOptions(options.getOptionsMap()); - return generateSync(ollamaRequestModel); + return generateSyncForOllamaRequestModel(ollamaRequestModel); } /** @@ -387,7 +393,7 @@ public class OllamaAPI { } OllamaRequestModel ollamaRequestModel = new 
OllamaRequestModel(model, prompt, images); ollamaRequestModel.setOptions(options.getOptionsMap()); - return generateSync(ollamaRequestModel); + return generateSyncForOllamaRequestModel(ollamaRequestModel); } /** @@ -411,9 +417,50 @@ public class OllamaAPI { } OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt, images); ollamaRequestModel.setOptions(options.getOptionsMap()); - return generateSync(ollamaRequestModel); + return generateSyncForOllamaRequestModel(ollamaRequestModel); } + + + /** + * Ask a question to a model based on a given message stack (i.e. a chat history). Creates a synchronous call to the api + * 'api/chat'. + * + * @param model the ollama model to ask the question to + * @param messages chat history / message stack to send to the model + * @return {@link OllamaChatResult} containing the api response and the message history including the newly acquired assistant response. + * @throws OllamaBaseException any response code other than 200 has been returned + * @throws IOException in case the responseStream can not be read + * @throws InterruptedException in case the server is not reachable or network issues happen + */ + public OllamaChatResult chat(String model, List messages) throws OllamaBaseException, IOException, InterruptedException{ + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(model); + return chat(builder.withMessages(messages).build()); + } + + /** + * Ask a question to a model using an {@link OllamaChatRequestModel}. This can be constructed using an {@link OllamaChatRequestBuilder}.
+ * + * Hint: the {@link OllamaChatRequestModel#getStream()} property is not implemented + * + * @param request request object to be sent to the server + * @return {@link OllamaChatResult} containing the api response and the updated message history + * @throws OllamaBaseException any response code other than 200 has been returned + * @throws IOException in case the responseStream can not be read + * @throws InterruptedException in case the server is not reachable or network issues happen + */ + public OllamaChatResult chat(OllamaChatRequestModel request) throws OllamaBaseException, IOException, InterruptedException{ + OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose); + //TODO: implement async way + if(request.isStream()){ + throw new UnsupportedOperationException("Streamed chat responses are not implemented yet"); + } + OllamaResult result = requestCaller.generateSync(request); + return new OllamaChatResult(result.getResponse(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages()); + } + + // technical private methods // + private static String encodeFileToBase64(File file) throws IOException { return Base64.getEncoder().encodeToString(Files.readAllBytes(file.toPath())); } @@ -436,58 +483,10 @@ public class OllamaAPI { } - private OllamaResult generateSync(OllamaRequestModel ollamaRequestModel) + private OllamaResult generateSyncForOllamaRequestModel(OllamaRequestModel ollamaRequestModel) throws OllamaBaseException, IOException, InterruptedException { - long startTime = System.currentTimeMillis(); - HttpClient httpClient = HttpClient.newHttpClient(); - URI uri = URI.create(this.host + "/api/generate"); - HttpRequest.Builder requestBuilder = - getRequestBuilderDefault(uri) - .POST( - HttpRequest.BodyPublishers.ofString( - Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))); - HttpRequest request = requestBuilder.build(); - if (verbose) logger.info("Asking model: " + ollamaRequestModel); - HttpResponse response = - httpClient.send(request,
HttpResponse.BodyHandlers.ofInputStream()); - int statusCode = response.statusCode(); - InputStream responseBodyStream = response.body(); - StringBuilder responseBuffer = new StringBuilder(); - try (BufferedReader reader = - new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { - String line; - while ((line = reader.readLine()) != null) { - if (statusCode == 404) { - logger.warn("Status code: 404 (Not Found)"); - OllamaErrorResponseModel ollamaResponseModel = - Utils.getObjectMapper().readValue(line, OllamaErrorResponseModel.class); - responseBuffer.append(ollamaResponseModel.getError()); - } else if (statusCode == 401) { - logger.warn("Status code: 401 (Unauthorized)"); - OllamaErrorResponseModel ollamaResponseModel = - Utils.getObjectMapper() - .readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponseModel.class); - responseBuffer.append(ollamaResponseModel.getError()); - } else { - OllamaResponseModel ollamaResponseModel = - Utils.getObjectMapper().readValue(line, OllamaResponseModel.class); - if (!ollamaResponseModel.isDone()) { - responseBuffer.append(ollamaResponseModel.getResponse()); - } - } - } - } - - if (statusCode != 200) { - logger.error("Status code " + statusCode); - throw new OllamaBaseException(responseBuffer.toString()); - } else { - long endTime = System.currentTimeMillis(); - OllamaResult ollamaResult = - new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode); - if (verbose) logger.info("Model response: " + ollamaResult); - return ollamaResult; - } + OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose); + return requestCaller.generateSync(ollamaRequestModel); } /** diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaErrorResponseModel.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaErrorResponseModel.java index 26fc82b..be3d8e4 100644 --- 
a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaErrorResponseModel.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaErrorResponseModel.java @@ -1,8 +1,6 @@ package io.github.amithkoujalgi.ollama4j.core.models; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; -import com.fasterxml.jackson.annotation.JsonProperty; -import java.util.List; import lombok.Data; @Data diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaRequestModel.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaRequestModel.java index a2507a6..9c88698 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaRequestModel.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaRequestModel.java @@ -3,12 +3,15 @@ package io.github.amithkoujalgi.ollama4j.core.models; import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper; import com.fasterxml.jackson.core.JsonProcessingException; + +import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody; + import java.util.List; import java.util.Map; import lombok.Data; @Data -public class OllamaRequestModel { +public class OllamaRequestModel implements OllamaRequestBody{ private String model; private String prompt; diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatMessage.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatMessage.java new file mode 100644 index 0000000..b877ddf --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatMessage.java @@ -0,0 +1,42 @@ +package io.github.amithkoujalgi.ollama4j.core.models.chat; + +import com.fasterxml.jackson.core.JsonProcessingException; +import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper; + +import java.io.File; +import java.util.List; + +import lombok.AllArgsConstructor; +import lombok.Data; +import 
lombok.NoArgsConstructor; +import lombok.NonNull; +import lombok.RequiredArgsConstructor; + +/** + * Defines a single Message to be used inside a chat request against the ollama /api/chat endpoint. + * + * @see https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion + */ +@Data +@AllArgsConstructor +@RequiredArgsConstructor +@NoArgsConstructor +public class OllamaChatMessage { + + @NonNull + private OllamaChatMessageRole role; + + @NonNull + private String content; + + private List images; + + @Override + public String toString() { + try { + return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } +} diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatMessageRole.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatMessageRole.java new file mode 100644 index 0000000..cbecb00 --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatMessageRole.java @@ -0,0 +1,19 @@ +package io.github.amithkoujalgi.ollama4j.core.models.chat; + +import com.fasterxml.jackson.annotation.JsonValue; + +/** + * Defines the possible Chat Message roles. 
+ */ +public enum OllamaChatMessageRole { + SYSTEM("system"), + USER("user"), + ASSISTANT("assistant"); + + @JsonValue + private String roleName; + + private OllamaChatMessageRole(String roleName){ + this.roleName = roleName; + } +} diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestBuilder.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestBuilder.java new file mode 100644 index 0000000..af28b7b --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestBuilder.java @@ -0,0 +1,68 @@ +package io.github.amithkoujalgi.ollama4j.core.models.chat; + +import java.io.File; +import java.util.ArrayList; +import java.util.List; + +import io.github.amithkoujalgi.ollama4j.core.utils.Options; + +/** + * Helper class for creating {@link OllamaChatRequestModel} objects using the builder-pattern. + */ +public class OllamaChatRequestBuilder { + + private OllamaChatRequestBuilder(String model, List messages){ + request = new OllamaChatRequestModel(model, messages); + } + + private OllamaChatRequestModel request; + + public static OllamaChatRequestBuilder getInstance(String model){ + return new OllamaChatRequestBuilder(model, new ArrayList<>()); + } + + public OllamaChatRequestModel build(){ + return request; + } + + public void reset(){ + request = new OllamaChatRequestModel(request.getModel(), new ArrayList<>()); + } + + public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, File... 
images){ + List messages = this.request.getMessages(); + messages.add(new OllamaChatMessage(role,content,List.of(images))); + return this; + } + + public OllamaChatRequestBuilder withMessages(List messages){ + this.request.getMessages().addAll(messages); + return this; + } + + public OllamaChatRequestBuilder withOptions(Options options){ + this.request.setOptions(options); + return this; + } + + public OllamaChatRequestBuilder withFormat(String format){ + this.request.setFormat(format); + return this; + } + + public OllamaChatRequestBuilder withTemplate(String template){ + this.request.setTemplate(template); + return this; + } + + public OllamaChatRequestBuilder withStreaming(){ + this.request.setStream(true); + return this; + } + + public OllamaChatRequestBuilder withKeepAlive(String keepAlive){ + this.request.setKeepAlive(keepAlive); + return this; + } + +} diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestModel.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestModel.java new file mode 100644 index 0000000..516e00b --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestModel.java @@ -0,0 +1,48 @@ +package io.github.amithkoujalgi.ollama4j.core.models.chat; + +import java.util.List; + +import com.fasterxml.jackson.core.JsonProcessingException; + +import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody; +import io.github.amithkoujalgi.ollama4j.core.utils.Options; + +import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NonNull; +import lombok.RequiredArgsConstructor; + +/** + * Defines a Request to use against the ollama /api/chat endpoint. 
+ * + * @see https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion + */ +@Data +@AllArgsConstructor +@RequiredArgsConstructor +public class OllamaChatRequestModel implements OllamaRequestBody{ + + @NonNull + private String model; + + @NonNull + private List messages; + + private String format; + private Options options; + private String template; + private boolean stream; + private String keepAlive; + + @Override + public String toString() { + try { + return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + +} \ No newline at end of file diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatResponseModel.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatResponseModel.java new file mode 100644 index 0000000..da69dfe --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatResponseModel.java @@ -0,0 +1,21 @@ +package io.github.amithkoujalgi.ollama4j.core.models.chat; + +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.List; +import lombok.Data; + +@Data +public class OllamaChatResponseModel { + private String model; + private @JsonProperty("created_at") String createdAt; + private OllamaChatMessage message; + private boolean done; + private List context; + private @JsonProperty("total_duration") Long totalDuration; + private @JsonProperty("load_duration") Long loadDuration; + private @JsonProperty("prompt_eval_duration") Long promptEvalDuration; + private @JsonProperty("eval_duration") Long evalDuration; + private @JsonProperty("prompt_eval_count") Integer promptEvalCount; + private @JsonProperty("eval_count") Integer evalCount; +} diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatResult.java 
b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatResult.java new file mode 100644 index 0000000..6ac6578 --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatResult.java @@ -0,0 +1,32 @@ +package io.github.amithkoujalgi.ollama4j.core.models.chat; + +import java.util.List; + +import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult; + +/** + * Specific chat-API result that contains the chat history sent to the model and appends the answer as {@link OllamaChatResult} given by the + * {@link OllamaChatMessageRole#ASSISTANT} role. + */ +public class OllamaChatResult extends OllamaResult{ + + private List chatHistory; + + public OllamaChatResult(String response, long responseTime, int httpStatusCode, + List chatHistory) { + super(response, responseTime, httpStatusCode); + this.chatHistory = chatHistory; + appendAnswerToChatHistory(response); + } + + public List getChatHistory() { + return chatHistory; + } + + private void appendAnswerToChatHistory(String answer){ + OllamaChatMessage assistantMessage = new OllamaChatMessage(OllamaChatMessageRole.ASSISTANT, answer); + this.chatHistory.add(assistantMessage); + } + + +} diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaChatEndpointCaller.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaChatEndpointCaller.java new file mode 100644 index 0000000..eb06c37 --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaChatEndpointCaller.java @@ -0,0 +1,44 @@ +package io.github.amithkoujalgi.ollama4j.core.models.request; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.core.JsonProcessingException; + +import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth; +import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResponseModel; +import io.github.amithkoujalgi.ollama4j.core.utils.Utils; + +/** + * 
Specialization class for requests + */ +public class OllamaChatEndpointCaller extends OllamaEndpointCaller{ + + private static final Logger LOG = LoggerFactory.getLogger(OllamaChatEndpointCaller.class); + + public OllamaChatEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) { + super(host, basicAuth, requestTimeoutSeconds, verbose); + } + + @Override + protected String getEndpointSuffix() { + return "/api/chat"; + } + + @Override + protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) { + try { + OllamaChatResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class); + responseBuffer.append(ollamaResponseModel.getMessage().getContent()); + return ollamaResponseModel.isDone(); + } catch (JsonProcessingException e) { + LOG.error("Error parsing the Ollama chat response!",e); + return true; + } + } + + + + + +} diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaEndpointCaller.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaEndpointCaller.java new file mode 100644 index 0000000..134056f --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaEndpointCaller.java @@ -0,0 +1,150 @@ +package io.github.amithkoujalgi.ollama4j.core.models.request; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.Base64; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import io.github.amithkoujalgi.ollama4j.core.OllamaAPI; +import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; +import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth; 
+import io.github.amithkoujalgi.ollama4j.core.models.OllamaErrorResponseModel; +import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult; +import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody; +import io.github.amithkoujalgi.ollama4j.core.utils.Utils; + +/** + * Abstract helper class to call the ollama api server. + */ +public abstract class OllamaEndpointCaller { + + private static final Logger LOG = LoggerFactory.getLogger(OllamaAPI.class); + + private String host; + private BasicAuth basicAuth; + private long requestTimeoutSeconds; + private boolean verbose; + + public OllamaEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) { + this.host = host; + this.basicAuth = basicAuth; + this.requestTimeoutSeconds = requestTimeoutSeconds; + this.verbose = verbose; + } + + protected abstract String getEndpointSuffix(); + + protected abstract boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer); + + + /** + * Calls the api server on the given host and endpoint suffix synchronously, i.e. waiting for the response.
+ * + * @param body POST body payload + * @return result answer given by the assistant + * @throws OllamaBaseException any response code than 200 has been returned + * @throws IOException in case the responseStream can not be read + * @throws InterruptedException in case the server is not reachable or network issues happen + */ + public OllamaResult generateSync(OllamaRequestBody body) throws OllamaBaseException, IOException, InterruptedException{ + + // Create Request + long startTime = System.currentTimeMillis(); + HttpClient httpClient = HttpClient.newHttpClient(); + URI uri = URI.create(this.host + getEndpointSuffix()); + HttpRequest.Builder requestBuilder = + getRequestBuilderDefault(uri) + .POST( + body.getBodyPublisher()); + HttpRequest request = requestBuilder.build(); + if (this.verbose) LOG.info("Asking model: " + body.toString()); + HttpResponse response = + httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream()); + + + int statusCode = response.statusCode(); + InputStream responseBodyStream = response.body(); + StringBuilder responseBuffer = new StringBuilder(); + try (BufferedReader reader = + new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { + String line; + while ((line = reader.readLine()) != null) { + if (statusCode == 404) { + LOG.warn("Status code: 404 (Not Found)"); + OllamaErrorResponseModel ollamaResponseModel = + Utils.getObjectMapper().readValue(line, OllamaErrorResponseModel.class); + responseBuffer.append(ollamaResponseModel.getError()); + } else if (statusCode == 401) { + LOG.warn("Status code: 401 (Unauthorized)"); + OllamaErrorResponseModel ollamaResponseModel = + Utils.getObjectMapper() + .readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponseModel.class); + responseBuffer.append(ollamaResponseModel.getError()); + } else { + boolean finished = parseResponseAndAddToBuffer(line,responseBuffer); + if (finished) { + break; + } + } + } + } + + if (statusCode != 200) { + 
LOG.error("Status code " + statusCode); + throw new OllamaBaseException(responseBuffer.toString()); + } else { + long endTime = System.currentTimeMillis(); + OllamaResult ollamaResult = + new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode); + if (verbose) LOG.info("Model response: " + ollamaResult); + return ollamaResult; + } + } + + /** + * Get default request builder. + * + * @param uri URI to get a HttpRequest.Builder + * @return HttpRequest.Builder + */ + private HttpRequest.Builder getRequestBuilderDefault(URI uri) { + HttpRequest.Builder requestBuilder = + HttpRequest.newBuilder(uri) + .header("Content-Type", "application/json") + .timeout(Duration.ofSeconds(this.requestTimeoutSeconds)); + if (isBasicAuthCredentialsSet()) { + requestBuilder.header("Authorization", getBasicAuthHeaderValue()); + } + return requestBuilder; + } + + /** + * Get basic authentication header value. + * + * @return basic authentication header value (encoded credentials) + */ + private String getBasicAuthHeaderValue() { + String credentialsToEncode = this.basicAuth.getUsername() + ":" + this.basicAuth.getPassword(); + return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes()); + } + + /** + * Check if Basic Auth credentials set. 
+ * + * @return true when Basic Auth credentials set + */ + private boolean isBasicAuthCredentialsSet() { + return this.basicAuth != null; + } + +} diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java new file mode 100644 index 0000000..8d54db3 --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java @@ -0,0 +1,40 @@ +package io.github.amithkoujalgi.ollama4j.core.models.request; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.core.JsonProcessingException; + +import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth; +import io.github.amithkoujalgi.ollama4j.core.models.OllamaResponseModel; +import io.github.amithkoujalgi.ollama4j.core.utils.Utils; + +public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{ + + private static final Logger LOG = LoggerFactory.getLogger(OllamaGenerateEndpointCaller.class); + + public OllamaGenerateEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) { + super(host, basicAuth, requestTimeoutSeconds, verbose); + } + + @Override + protected String getEndpointSuffix() { + return "/api/generate"; + } + + @Override + protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) { + try { + OllamaResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaResponseModel.class); + responseBuffer.append(ollamaResponseModel.getResponse()); + return ollamaResponseModel.isDone(); + } catch (JsonProcessingException e) { + LOG.error("Error parsing the Ollama chat response!",e); + return true; + } + } + + + + +} diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/OllamaRequestBody.java 
b/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/OllamaRequestBody.java new file mode 100644 index 0000000..f787cee --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/OllamaRequestBody.java @@ -0,0 +1,28 @@ +package io.github.amithkoujalgi.ollama4j.core.utils; + +import java.net.http.HttpRequest.BodyPublisher; +import java.net.http.HttpRequest.BodyPublishers; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.core.JsonProcessingException; + +/** + * Interface to represent a OllamaRequest as HTTP-Request Body via {@link BodyPublishers}. + */ +public interface OllamaRequestBody { + + /** + * Transforms the OllamaRequest Object to a JSON Object via Jackson. + * + * @return JSON representation of a OllamaRequest + */ + @JsonIgnore + default BodyPublisher getBodyPublisher(){ + try { + return BodyPublishers.ofString( + Utils.getObjectMapper().writeValueAsString(this)); + } catch (JsonProcessingException e) { + throw new IllegalArgumentException("Request not Body convertible.",e); + } + } +} diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java b/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java index ac06204..284a52d 100644 --- a/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java +++ b/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java @@ -5,6 +5,10 @@ import static org.junit.jupiter.api.Assertions.*; import io.github.amithkoujalgi.ollama4j.core.OllamaAPI; import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult; +import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessageRole; +import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder; +import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel; +import 
io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResult; import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder; import java.io.File; import java.io.IOException; @@ -118,6 +122,46 @@ class TestRealAPIs { } } + @Test + @Order(3) + void testChat() { + testEndpointReachability(); + try { + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel()); + OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What is the capital of France?") + .withMessage(OllamaChatMessageRole.ASSISTANT, "Should be Paris!") + .withMessage(OllamaChatMessageRole.USER,"And what is the second larges city?") + .build(); + + OllamaChatResult chatResult = ollamaAPI.chat(requestModel); + assertNotNull(chatResult); + assertFalse(chatResult.getResponse().isBlank()); + assertEquals(4,chatResult.getChatHistory().size()); + } catch (IOException | OllamaBaseException | InterruptedException e) { + throw new RuntimeException(e); + } + } + + @Test + @Order(3) + void testChatWithSystemPrompt() { + testEndpointReachability(); + try { + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel()); + OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, "You are a silent bot that only says 'NI'. Do not say anything else under any circumstances!") + .withMessage(OllamaChatMessageRole.USER,"What is the capital of France? And what's France's connection with Mona Lisa?") + .build(); + + OllamaChatResult chatResult = ollamaAPI.chat(requestModel); + assertNotNull(chatResult); + assertFalse(chatResult.getResponse().isBlank()); + assertTrue(chatResult.getResponse().startsWith("NI")); + assertEquals(3,chatResult.getChatHistory().size()); + } catch (IOException | OllamaBaseException | InterruptedException e) { + throw new RuntimeException(e); + } + } + @Test @Order(3) void testAskModelWithOptionsAndImageFiles() {