Merge pull request #23 from AgentSchmecker/feature/chat-request-model

Adds feature for /api/chat access via OllamaAPI
This commit is contained in:
Amith Koujalgi 2024-02-12 21:36:33 +05:30 committed by GitHub
commit d716b81342
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 691 additions and 57 deletions

View File

@ -0,0 +1,98 @@
---
sidebar_position: 7
---
# Chat
This API lets you hold a conversation with an LLM. Each request can carry the history of previously asked
questions and their answers, so the model can use that context when answering follow-up questions.
## Create a new conversation and use chat history to augment follow up questions
```java
public class Main {
public static void main(String[] args) {
String host = "http://localhost:11434/";
OllamaAPI ollamaAPI = new OllamaAPI(host);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(OllamaModelType.LLAMA2);
// create first user question
OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER,"What is the capital of France?")
.build();
// start conversation with model
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
System.out.println("First answer: " + chatResult.getResponse());
// create next userQuestion
requestModel = builder.withMessages(chatResult.getChatHistory()).withMessage(OllamaChatMessageRole.USER,"And what is the second largest city?").build();
// "continue" conversation with model
chatResult = ollamaAPI.chat(requestModel);
System.out.println("Second answer: " + chatResult.getResponse());
System.out.println("Chat History: " + chatResult.getChatHistory());
}
}
```
You will get a response similar to:
> First answer: Should be Paris!
>
> Second answer: Marseille.
>
> Chat History:
```json
[ {
"role" : "user",
"content" : "What is the capital of France?",
"images" : [ ]
}, {
"role" : "assistant",
"content" : "Should be Paris!",
"images" : [ ]
}, {
"role" : "user",
"content" : "And what is the second largest city?",
"images" : [ ]
}, {
"role" : "assistant",
"content" : "Marseille.",
"images" : [ ]
} ]
```
## Create a new conversation with individual system prompt
```java
public class Main {
public static void main(String[] args) {
String host = "http://localhost:11434/";
OllamaAPI ollamaAPI = new OllamaAPI(host);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(OllamaModelType.LLAMA2);
// create request with system-prompt (overriding the model defaults) and user question
OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, "You are a silent bot that only says 'NI'. Do not say anything else under any circumstances!")
.withMessage(OllamaChatMessageRole.USER,"What is the capital of France? And what's France's connection with Mona Lisa?")
.build();
// start conversation with model
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
System.out.println(chatResult.getResponse());
}
}
```
You will get a response similar to:
> NI.

View File

@ -2,10 +2,16 @@ package io.github.amithkoujalgi.ollama4j.core;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.*; import io.github.amithkoujalgi.ollama4j.core.models.*;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessage;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResult;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder;
import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFileContentsRequest; import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFileContentsRequest;
import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFilePathRequest; import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFilePathRequest;
import io.github.amithkoujalgi.ollama4j.core.models.request.ModelEmbeddingsRequest; import io.github.amithkoujalgi.ollama4j.core.models.request.ModelEmbeddingsRequest;
import io.github.amithkoujalgi.ollama4j.core.models.request.ModelRequest; import io.github.amithkoujalgi.ollama4j.core.models.request.ModelRequest;
import io.github.amithkoujalgi.ollama4j.core.models.request.OllamaChatEndpointCaller;
import io.github.amithkoujalgi.ollama4j.core.models.request.OllamaGenerateEndpointCaller;
import io.github.amithkoujalgi.ollama4j.core.utils.Options; import io.github.amithkoujalgi.ollama4j.core.utils.Options;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils; import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
import java.io.BufferedReader; import java.io.BufferedReader;
@ -343,7 +349,7 @@ public class OllamaAPI {
throws OllamaBaseException, IOException, InterruptedException { throws OllamaBaseException, IOException, InterruptedException {
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt); OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt);
ollamaRequestModel.setOptions(options.getOptionsMap()); ollamaRequestModel.setOptions(options.getOptionsMap());
return generateSync(ollamaRequestModel); return generateSyncForOllamaRequestModel(ollamaRequestModel);
} }
/** /**
@ -387,7 +393,7 @@ public class OllamaAPI {
} }
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt, images); OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt, images);
ollamaRequestModel.setOptions(options.getOptionsMap()); ollamaRequestModel.setOptions(options.getOptionsMap());
return generateSync(ollamaRequestModel); return generateSyncForOllamaRequestModel(ollamaRequestModel);
} }
/** /**
@ -411,9 +417,50 @@ public class OllamaAPI {
} }
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt, images); OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt, images);
ollamaRequestModel.setOptions(options.getOptionsMap()); ollamaRequestModel.setOptions(options.getOptionsMap());
return generateSync(ollamaRequestModel); return generateSyncForOllamaRequestModel(ollamaRequestModel);
} }
/**
 * Asks a model a question based on a given message stack (i.e. a chat history) by issuing a
 * synchronous call to the api 'api/chat'.
 *
 * @param model the ollama model to ask the question to
 * @param messages chat history / message stack to send to the model
 * @return {@link OllamaChatResult} containing the api response and the message history including the newly acquired assistant response.
 * @throws OllamaBaseException if a response code other than 200 has been returned
 * @throws IOException in case the responseStream can not be read
 * @throws InterruptedException if the operation is interrupted while waiting for the response
 */
public OllamaChatResult chat(String model, List<OllamaChatMessage> messages) throws OllamaBaseException, IOException, InterruptedException{
  OllamaChatRequestModel chatRequest =
      OllamaChatRequestBuilder.getInstance(model).withMessages(messages).build();
  return chat(chatRequest);
}
/**
 * Asks a model a question using an {@link OllamaChatRequestModel}, which can be constructed using an
 * {@link OllamaChatRequestBuilder}.
 *
 * Hint: streamed responses are not implemented; a request whose {@code stream} flag is set is rejected.
 *
 * @param request request object to be sent to the server
 * @return {@link OllamaChatResult} containing the api response and the request's messages extended by the assistant response
 * @throws OllamaBaseException if a response code other than 200 has been returned
 * @throws IOException in case the responseStream can not be read
 * @throws InterruptedException if the operation is interrupted while waiting for the response
 * @throws UnsupportedOperationException if the request asks for a streamed response
 */
public OllamaChatResult chat(OllamaChatRequestModel request) throws OllamaBaseException, IOException, InterruptedException{
  //TODO: implement async way
  if (request.isStream()) {
    throw new UnsupportedOperationException("Streamed chat responses are not implemented yet");
  }
  OllamaChatEndpointCaller chatCaller =
      new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
  OllamaResult syncResult = chatCaller.generateSync(request);
  return new OllamaChatResult(
      syncResult.getResponse(),
      syncResult.getResponseTime(),
      syncResult.getHttpStatusCode(),
      request.getMessages());
}
// technical private methods //
private static String encodeFileToBase64(File file) throws IOException { private static String encodeFileToBase64(File file) throws IOException {
return Base64.getEncoder().encodeToString(Files.readAllBytes(file.toPath())); return Base64.getEncoder().encodeToString(Files.readAllBytes(file.toPath()));
} }
@ -436,58 +483,10 @@ public class OllamaAPI {
} }
} }
private OllamaResult generateSync(OllamaRequestModel ollamaRequestModel) private OllamaResult generateSyncForOllamaRequestModel(OllamaRequestModel ollamaRequestModel)
throws OllamaBaseException, IOException, InterruptedException { throws OllamaBaseException, IOException, InterruptedException {
long startTime = System.currentTimeMillis(); OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
HttpClient httpClient = HttpClient.newHttpClient(); return requestCaller.generateSync(ollamaRequestModel);
URI uri = URI.create(this.host + "/api/generate");
HttpRequest.Builder requestBuilder =
getRequestBuilderDefault(uri)
.POST(
HttpRequest.BodyPublishers.ofString(
Utils.getObjectMapper().writeValueAsString(ollamaRequestModel)));
HttpRequest request = requestBuilder.build();
if (verbose) logger.info("Asking model: " + ollamaRequestModel);
HttpResponse<InputStream> response =
httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
int statusCode = response.statusCode();
InputStream responseBodyStream = response.body();
StringBuilder responseBuffer = new StringBuilder();
try (BufferedReader reader =
new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
String line;
while ((line = reader.readLine()) != null) {
if (statusCode == 404) {
logger.warn("Status code: 404 (Not Found)");
OllamaErrorResponseModel ollamaResponseModel =
Utils.getObjectMapper().readValue(line, OllamaErrorResponseModel.class);
responseBuffer.append(ollamaResponseModel.getError());
} else if (statusCode == 401) {
logger.warn("Status code: 401 (Unauthorized)");
OllamaErrorResponseModel ollamaResponseModel =
Utils.getObjectMapper()
.readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponseModel.class);
responseBuffer.append(ollamaResponseModel.getError());
} else {
OllamaResponseModel ollamaResponseModel =
Utils.getObjectMapper().readValue(line, OllamaResponseModel.class);
if (!ollamaResponseModel.isDone()) {
responseBuffer.append(ollamaResponseModel.getResponse());
}
}
}
}
if (statusCode != 200) {
logger.error("Status code " + statusCode);
throw new OllamaBaseException(responseBuffer.toString());
} else {
long endTime = System.currentTimeMillis();
OllamaResult ollamaResult =
new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode);
if (verbose) logger.info("Model response: " + ollamaResult);
return ollamaResult;
}
} }
/** /**

View File

@ -1,8 +1,6 @@
package io.github.amithkoujalgi.ollama4j.core.models; package io.github.amithkoujalgi.ollama4j.core.models;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.List;
import lombok.Data; import lombok.Data;
@Data @Data

View File

@ -3,12 +3,15 @@ package io.github.amithkoujalgi.ollama4j.core.models;
import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper; import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import lombok.Data; import lombok.Data;
@Data @Data
public class OllamaRequestModel { public class OllamaRequestModel implements OllamaRequestBody{
private String model; private String model;
private String prompt; private String prompt;

View File

@ -0,0 +1,42 @@
package io.github.amithkoujalgi.ollama4j.core.models.chat;
import com.fasterxml.jackson.core.JsonProcessingException;
import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper;
import java.io.File;
import java.util.List;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
/**
 * A single message used inside a chat request against the ollama /api/chat endpoint.
 *
 * @see <a href="https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion">Generate a chat completion</a>
 */
@Data
@AllArgsConstructor
@RequiredArgsConstructor
@NoArgsConstructor
public class OllamaChatMessage {

  /** Role of the message's author. */
  @NonNull
  private OllamaChatMessageRole role;

  /** Text of the message. */
  @NonNull
  private String content;

  /** Optional images attached to the message. */
  private List<File> images;

  /**
   * Renders this message as pretty-printed JSON.
   *
   * @return JSON representation of this message
   * @throws RuntimeException wrapping the {@link JsonProcessingException} if serialization fails
   */
  @Override
  public String toString() {
    final String json;
    try {
      json = getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
    } catch (JsonProcessingException e) {
      throw new RuntimeException(e);
    }
    return json;
  }
}

View File

@ -0,0 +1,19 @@
package io.github.amithkoujalgi.ollama4j.core.models.chat;
import com.fasterxml.jackson.annotation.JsonValue;
/**
 * Defines the possible Chat Message roles.
 */
public enum OllamaChatMessageRole {
  SYSTEM("system"),
  USER("user"),
  ASSISTANT("assistant");

  /** Lower-case role name; marked with {@link JsonValue} so it is used as the JSON representation. */
  @JsonValue
  private final String roleName;

  // Enum constructors are implicitly private, so no modifier is needed.
  OllamaChatMessageRole(String roleName) {
    this.roleName = roleName;
  }
}

View File

@ -0,0 +1,68 @@
package io.github.amithkoujalgi.ollama4j.core.models.chat;
import java.io.File;
import java.util.ArrayList;
import java.util.List;
import io.github.amithkoujalgi.ollama4j.core.utils.Options;
/**
 * Helper class for creating {@link OllamaChatRequestModel} objects using the builder-pattern.
 *
 * <p>The builder holds a single request instance: every {@code with...} method modifies it and
 * {@link #build()} returns that same instance. Use {@link #reset()} to start over with a fresh
 * request for the same model.
 */
public class OllamaChatRequestBuilder {

  private OllamaChatRequestModel request;

  private OllamaChatRequestBuilder(String model, List<OllamaChatMessage> messages) {
    request = new OllamaChatRequestModel(model, messages);
  }

  /**
   * Creates a builder for the given model with an empty message list.
   *
   * @param model name of the model the request is addressed to
   * @return new builder instance
   */
  public static OllamaChatRequestBuilder getInstance(String model) {
    return new OllamaChatRequestBuilder(model, new ArrayList<>());
  }

  /**
   * Returns the request assembled so far.
   *
   * @return the request instance held by this builder (not a copy)
   */
  public OllamaChatRequestModel build() {
    return request;
  }

  /** Replaces the held request by a new one that keeps only the model name. */
  public void reset() {
    request = new OllamaChatRequestModel(request.getModel(), new ArrayList<>());
  }

  /**
   * Appends a single message to the request.
   *
   * @param role role of the message author
   * @param content message text
   * @param images optional image files to attach
   * @return this builder
   */
  public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, File... images) {
    List<OllamaChatMessage> messages = this.request.getMessages();
    messages.add(new OllamaChatMessage(role, content, List.of(images)));
    return this;
  }

  /**
   * Appends all given messages to the request.
   *
   * <p>If {@code messages} is the very list already held by the request, nothing is appended.
   *
   * @param messages messages to append
   * @return this builder
   */
  public OllamaChatRequestBuilder withMessages(List<OllamaChatMessage> messages) {
    List<OllamaChatMessage> current = this.request.getMessages();
    // OllamaChatResult#getChatHistory() hands back the request's own message list. Appending a
    // list to itself would duplicate the whole history, so that case is skipped.
    if (messages != current) {
      current.addAll(messages);
    }
    return this;
  }

  /**
   * Sets the model options of the request.
   *
   * @param options options to send along with the request
   * @return this builder
   */
  public OllamaChatRequestBuilder withOptions(Options options) {
    this.request.setOptions(options);
    return this;
  }

  /**
   * Sets the response format of the request.
   *
   * @param format format value to send
   * @return this builder
   */
  public OllamaChatRequestBuilder withFormat(String format) {
    this.request.setFormat(format);
    return this;
  }

  /**
   * Sets the prompt template of the request.
   *
   * @param template template value to send
   * @return this builder
   */
  public OllamaChatRequestBuilder withTemplate(String template) {
    this.request.setTemplate(template);
    return this;
  }

  /**
   * Sets the request's {@code stream} flag to {@code true}.
   *
   * @return this builder
   */
  public OllamaChatRequestBuilder withStreaming() {
    this.request.setStream(true);
    return this;
  }

  /**
   * Sets the keep-alive value of the request.
   *
   * @param keepAlive keep-alive value to send
   * @return this builder
   */
  public OllamaChatRequestBuilder withKeepAlive(String keepAlive) {
    this.request.setKeepAlive(keepAlive);
    return this;
  }
}

View File

@ -0,0 +1,48 @@
package io.github.amithkoujalgi.ollama4j.core.models.chat;
import java.util.List;
import com.fasterxml.jackson.core.JsonProcessingException;
import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
import io.github.amithkoujalgi.ollama4j.core.utils.Options;
import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
/**
* Defines a Request to use against the ollama /api/chat endpoint.
*
* @see https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion
*/
@Data
@AllArgsConstructor
@RequiredArgsConstructor
public class OllamaChatRequestModel implements OllamaRequestBody{
@NonNull
private String model;
@NonNull
private List<OllamaChatMessage> messages;
private String format;
private Options options;
private String template;
private boolean stream;
private String keepAlive;
@Override
public String toString() {
try {
return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
}

View File

@ -0,0 +1,21 @@
package io.github.amithkoujalgi.ollama4j.core.models.chat;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.List;
import lombok.Data;
/**
 * Maps one JSON object of a response of the ollama /api/chat endpoint.
 */
@Data
public class OllamaChatResponseModel {

  private String model;

  @JsonProperty("created_at")
  private String createdAt;

  private OllamaChatMessage message;

  private boolean done;

  private List<Integer> context;

  @JsonProperty("total_duration")
  private Long totalDuration;

  @JsonProperty("load_duration")
  private Long loadDuration;

  @JsonProperty("prompt_eval_duration")
  private Long promptEvalDuration;

  @JsonProperty("eval_duration")
  private Long evalDuration;

  @JsonProperty("prompt_eval_count")
  private Integer promptEvalCount;

  @JsonProperty("eval_count")
  private Integer evalCount;
}

View File

@ -0,0 +1,32 @@
package io.github.amithkoujalgi.ollama4j.core.models.chat;
import java.util.List;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
/**
 * Specific chat-API result that contains the chat history sent to the model, with the answer
 * appended as an {@link OllamaChatMessage} of the {@link OllamaChatMessageRole#ASSISTANT} role.
 */
public class OllamaChatResult extends OllamaResult {

  private List<OllamaChatMessage> chatHistory;

  /**
   * Creates the result and appends the answer to the chat history.
   *
   * @param response answer text given by the model
   * @param responseTime response time as reported by the caller
   * @param httpStatusCode HTTP status code of the response
   * @param chatHistory messages sent to the model; the list is stored without copying and the
   *     assistant answer is added to it
   */
  public OllamaChatResult(String response, long responseTime, int httpStatusCode,
      List<OllamaChatMessage> chatHistory) {
    super(response, responseTime, httpStatusCode);
    this.chatHistory = chatHistory;
    this.chatHistory.add(new OllamaChatMessage(OllamaChatMessageRole.ASSISTANT, response));
  }

  /**
   * Returns the chat history including the assistant answer.
   *
   * @return the stored message list (not a copy)
   */
  public List<OllamaChatMessage> getChatHistory() {
    return chatHistory;
  }
}

View File

@ -0,0 +1,44 @@
package io.github.amithkoujalgi.ollama4j.core.models.request;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.fasterxml.jackson.core.JsonProcessingException;
import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResponseModel;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
/**
 * Endpoint caller for the ollama /api/chat endpoint: parses each response line as an
 * {@link OllamaChatResponseModel} and collects the message content.
 */
public class OllamaChatEndpointCaller extends OllamaEndpointCaller {

  private static final Logger LOG = LoggerFactory.getLogger(OllamaChatEndpointCaller.class);

  public OllamaChatEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) {
    super(host, basicAuth, requestTimeoutSeconds, verbose);
  }

  @Override
  protected String getEndpointSuffix() {
    return "/api/chat";
  }

  /**
   * Parses one response line and appends the contained message content to the buffer.
   *
   * @param line one line of the response body
   * @param responseBuffer buffer collecting the answer text
   * @return {@code true} if the response is marked as done or the line could not be parsed
   */
  @Override
  protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) {
    try {
      OllamaChatResponseModel chunk =
          Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class);
      // A line without a message (or without content) must neither crash the parser nor put the
      // literal text "null" into the answer.
      if (chunk.getMessage() != null && chunk.getMessage().getContent() != null) {
        responseBuffer.append(chunk.getMessage().getContent());
      }
      return chunk.isDone();
    } catch (JsonProcessingException e) {
      LOG.error("Error parsing the Ollama chat response!", e);
      return true;
    }
  }
}

View File

@ -0,0 +1,150 @@
package io.github.amithkoujalgi.ollama4j.core.models.request;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Base64;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaErrorResponseModel;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
/**
 * Abstract helper class to call the ollama api server.
 *
 * <p>Subclasses provide the endpoint suffix and the parsing of a single response line.
 */
public abstract class OllamaEndpointCaller {

  // Bound to this class, like the loggers of the concrete endpoint callers.
  private static final Logger LOG = LoggerFactory.getLogger(OllamaEndpointCaller.class);

  private final String host;
  private final BasicAuth basicAuth;
  private final long requestTimeoutSeconds;
  private final boolean verbose;

  public OllamaEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) {
    this.host = host;
    this.basicAuth = basicAuth;
    this.requestTimeoutSeconds = requestTimeoutSeconds;
    this.verbose = verbose;
  }

  /**
   * @return path appended to the host to form the request URI
   */
  protected abstract String getEndpointSuffix();

  /**
   * Parses one line of the response body and appends its text to the buffer.
   *
   * @param line one line of the response body
   * @param responseBuffer buffer collecting the answer text
   * @return {@code true} if no further lines shall be read
   */
  protected abstract boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer);

  /**
   * Calls the api server on the given host and endpoint suffix synchronously, i.e. blocks until
   * the response has been read.
   *
   * @param body POST body payload
   * @return result answer given by the assistant
   * @throws OllamaBaseException if a response code other than 200 has been returned
   * @throws IOException in case the responseStream can not be read
   * @throws InterruptedException if the operation is interrupted while waiting for the response
   */
  public OllamaResult generateSync(OllamaRequestBody body) throws OllamaBaseException, IOException, InterruptedException {
    // Create Request
    long startTime = System.currentTimeMillis();
    HttpClient httpClient = HttpClient.newHttpClient();
    URI uri = URI.create(this.host + getEndpointSuffix());
    HttpRequest request = getRequestBuilderDefault(uri).POST(body.getBodyPublisher()).build();
    if (this.verbose) LOG.info("Asking model: " + body.toString());
    HttpResponse<InputStream> response =
        httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());

    int statusCode = response.statusCode();
    InputStream responseBodyStream = response.body();
    StringBuilder responseBuffer = new StringBuilder();
    try (BufferedReader reader =
        new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
      String line;
      while ((line = reader.readLine()) != null) {
        if (statusCode == 404) {
          LOG.warn("Status code: 404 (Not Found)");
          OllamaErrorResponseModel ollamaResponseModel =
              Utils.getObjectMapper().readValue(line, OllamaErrorResponseModel.class);
          responseBuffer.append(ollamaResponseModel.getError());
        } else if (statusCode == 401) {
          LOG.warn("Status code: 401 (Unauthorized)");
          OllamaErrorResponseModel ollamaResponseModel =
              Utils.getObjectMapper()
                  .readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponseModel.class);
          responseBuffer.append(ollamaResponseModel.getError());
        } else {
          boolean finished = parseResponseAndAddToBuffer(line, responseBuffer);
          if (finished) {
            break;
          }
        }
      }
    }

    if (statusCode != 200) {
      LOG.error("Status code " + statusCode);
      throw new OllamaBaseException(responseBuffer.toString());
    } else {
      long endTime = System.currentTimeMillis();
      OllamaResult ollamaResult =
          new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode);
      if (verbose) LOG.info("Model response: " + ollamaResult);
      return ollamaResult;
    }
  }

  /**
   * Get default request builder.
   *
   * @param uri URI to get a HttpRequest.Builder
   * @return HttpRequest.Builder
   */
  private HttpRequest.Builder getRequestBuilderDefault(URI uri) {
    HttpRequest.Builder requestBuilder =
        HttpRequest.newBuilder(uri)
            .header("Content-Type", "application/json")
            .timeout(Duration.ofSeconds(this.requestTimeoutSeconds));
    if (isBasicAuthCredentialsSet()) {
      requestBuilder.header("Authorization", getBasicAuthHeaderValue());
    }
    return requestBuilder;
  }

  /**
   * Get basic authentication header value.
   *
   * @return basic authentication header value (encoded credentials)
   */
  private String getBasicAuthHeaderValue() {
    String credentialsToEncode = this.basicAuth.getUsername() + ":" + this.basicAuth.getPassword();
    // Explicit charset: the encoded header must not depend on the platform default encoding.
    return "Basic "
        + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes(StandardCharsets.UTF_8));
  }

  /**
   * Check if Basic Auth credentials set.
   *
   * @return true when Basic Auth credentials set
   */
  private boolean isBasicAuthCredentialsSet() {
    return this.basicAuth != null;
  }
}

View File

@ -0,0 +1,40 @@
package io.github.amithkoujalgi.ollama4j.core.models.request;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.fasterxml.jackson.core.JsonProcessingException;
import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaResponseModel;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
/**
 * Endpoint caller for the ollama /api/generate endpoint: parses each response line as an
 * {@link OllamaResponseModel} and collects the response text.
 */
public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {

  private static final Logger LOG = LoggerFactory.getLogger(OllamaGenerateEndpointCaller.class);

  public OllamaGenerateEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) {
    super(host, basicAuth, requestTimeoutSeconds, verbose);
  }

  @Override
  protected String getEndpointSuffix() {
    return "/api/generate";
  }

  /**
   * Parses one response line and appends the contained response text to the buffer.
   *
   * @param line one line of the response body
   * @param responseBuffer buffer collecting the answer text
   * @return {@code true} if the response is marked as done or the line could not be parsed
   */
  @Override
  protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) {
    try {
      OllamaResponseModel chunk = Utils.getObjectMapper().readValue(line, OllamaResponseModel.class);
      // StringBuilder#append turns a null reference into the text "null"; skip missing text.
      if (chunk.getResponse() != null) {
        responseBuffer.append(chunk.getResponse());
      }
      return chunk.isDone();
    } catch (JsonProcessingException e) {
      LOG.error("Error parsing the Ollama generate response!", e);
      return true;
    }
  }
}

View File

@ -0,0 +1,28 @@
package io.github.amithkoujalgi.ollama4j.core.utils;
import java.net.http.HttpRequest.BodyPublisher;
import java.net.http.HttpRequest.BodyPublishers;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.core.JsonProcessingException;
/**
 * Interface to represent a OllamaRequest as HTTP-Request Body via {@link BodyPublishers}.
 */
public interface OllamaRequestBody {

  /**
   * Serializes this request to JSON via Jackson and wraps the JSON text in a {@link BodyPublisher}.
   *
   * @return publisher providing the JSON representation of this request
   * @throws IllegalArgumentException if this request can not be serialized
   */
  @JsonIgnore
  default BodyPublisher getBodyPublisher() {
    final String json;
    try {
      json = Utils.getObjectMapper().writeValueAsString(this);
    } catch (JsonProcessingException e) {
      throw new IllegalArgumentException("Request not Body convertible.", e);
    }
    return BodyPublishers.ofString(json);
  }
}

View File

@ -5,6 +5,10 @@ import static org.junit.jupiter.api.Assertions.*;
import io.github.amithkoujalgi.ollama4j.core.OllamaAPI; import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult; import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessageRole;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResult;
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder; import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
@ -118,6 +122,46 @@ class TestRealAPIs {
} }
} }
@Test
@Order(3)
void testChat() {
  testEndpointReachability();
  try {
    // A three-message history is sent; the result is expected to add the assistant answer.
    OllamaChatRequestModel chatRequest =
        OllamaChatRequestBuilder.getInstance(config.getModel())
            .withMessage(OllamaChatMessageRole.USER, "What is the capital of France?")
            .withMessage(OllamaChatMessageRole.ASSISTANT, "Should be Paris!")
            .withMessage(OllamaChatMessageRole.USER, "And what is the second larges city?")
            .build();

    OllamaChatResult answer = ollamaAPI.chat(chatRequest);

    assertNotNull(answer);
    assertFalse(answer.getResponse().isBlank());
    assertEquals(4, answer.getChatHistory().size());
  } catch (IOException | OllamaBaseException | InterruptedException e) {
    throw new RuntimeException(e);
  }
}
@Test
@Order(3)
void testChatWithSystemPrompt() {
  testEndpointReachability();
  try {
    // System prompt plus one user question; the result is expected to add the assistant answer.
    OllamaChatRequestModel chatRequest =
        OllamaChatRequestBuilder.getInstance(config.getModel())
            .withMessage(OllamaChatMessageRole.SYSTEM, "You are a silent bot that only says 'NI'. Do not say anything else under any circumstances!")
            .withMessage(OllamaChatMessageRole.USER, "What is the capital of France? And what's France's connection with Mona Lisa?")
            .build();

    OllamaChatResult answer = ollamaAPI.chat(chatRequest);

    assertNotNull(answer);
    assertFalse(answer.getResponse().isBlank());
    assertTrue(answer.getResponse().startsWith("NI"));
    assertEquals(3, answer.getChatHistory().size());
  } catch (IOException | OllamaBaseException | InterruptedException e) {
    throw new RuntimeException(e);
  }
}
@Test @Test
@Order(3) @Order(3)
void testAskModelWithOptionsAndImageFiles() { void testAskModelWithOptionsAndImageFiles() {