Compare commits

..

18 Commits

Author SHA1 Message Date
amithkoujalgi
31f8302849 [maven-release-plugin] prepare release v1.0.54 2024-02-18 05:22:08 +00:00
Amith Koujalgi
6487756764 Merge pull request #28 from AgentSchmecker/feature/advanced_params_for_generate
Adds advanced request parameters to chat/generate requests
2024-02-18 10:51:04 +05:30
Markus Klenke
0f414f71a3 Changes isEmpty method for BooleanToJsonFormatFlagSerializer to override non deprecated supermethod 2024-02-16 16:01:18 +00:00
Markus Klenke
2b700fdad8 Adds missing pom dependency for JSON comparison tests 2024-02-16 15:58:48 +00:00
Markus Klenke
06c5daa253 Adds additional properties to chat and generate requests 2024-02-16 15:57:48 +00:00
Markus Klenke
91aab6cbd1 Fixes recursive call in non streamed chat API 2024-02-16 15:57:14 +00:00
Markus Klenke
f38a00ebdc Fixes BooleanToJsonFormatFlagSerializer 2024-02-16 15:56:32 +00:00
Markus Klenke
0f73ea75ab Removes unnecessary serialize method of Serializer 2024-02-16 15:56:02 +00:00
Markus Klenke
8fe869afdb Adds additional request properties and refactors common request fields to OllamaCommonRequestModel 2024-02-16 13:15:24 +00:00
amithkoujalgi
2d274c4f5b [maven-release-plugin] prepare for next development iteration 2024-02-16 04:42:42 +00:00
amithkoujalgi
713a3239a4 [maven-release-plugin] prepare release v1.0.53 2024-02-16 04:42:40 +00:00
Amith Koujalgi
a9e7958d44 Merge pull request #26 from AgentSchmecker/main
Adds streaming functionality for chat
2024-02-16 10:11:32 +05:30
Markus Klenke
f38e84053f Adds documentation for streamed chat API call 2024-02-14 16:45:46 +00:00
Markus Klenke
7eb16b7ba0 Merge pull request #1 from AgentSchmecker/feature/streaming-chat
Adds streaming functionality for chat
2024-02-14 16:30:51 +01:00
amithkoujalgi
5a3889d8ee [maven-release-plugin] prepare for next development iteration 2024-02-14 13:44:30 +00:00
Markus Klenke
e9621f054d Adds integration test for chat streaming API 2024-02-13 18:11:59 +00:00
Markus Klenke
b41b62220c Adds chat with stream functionality in OllamaAPI 2024-02-13 17:59:27 +00:00
Markus Klenke
c89440cbca Adds OllamaStream handling 2024-02-13 17:56:07 +00:00
21 changed files with 554 additions and 112 deletions

View File

@@ -69,6 +69,41 @@ You will get a response similar to:
} ]
```
## Create a conversation where the answer is streamed
```java
public class Main {
public static void main(String[] args) {
String host = "http://localhost:11434/";
OllamaAPI ollamaAPI = new OllamaAPI(host);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER,
"What is the capital of France? And what's France's connection with Mona Lisa?")
.build();
// define a handler (Consumer<String>)
OllamaStreamHandler streamHandler = (s) -> {
System.out.println(s);
};
OllamaChatResult chatResult = ollamaAPI.chat(requestModel,streamHandler);
}
}
```
You will get a response similar to:
> The
> The capital
> The capital of
> The capital of France
> The capital of France is
> The capital of France is Paris
> The capital of France is Paris.
## Create a new conversation with individual system prompt
```java
public class Main {

10
pom.xml
View File

@@ -4,7 +4,7 @@
<groupId>io.github.amithkoujalgi</groupId>
<artifactId>ollama4j</artifactId>
<version>1.0.52</version>
<version>1.0.54</version>
<name>Ollama4j</name>
<description>Java library for interacting with Ollama API.</description>
@@ -39,7 +39,7 @@
<connection>scm:git:git@github.com:amithkoujalgi/ollama4j.git</connection>
<developerConnection>scm:git:https://github.com/amithkoujalgi/ollama4j.git</developerConnection>
<url>https://github.com/amithkoujalgi/ollama4j</url>
<tag>v1.0.52</tag>
<tag>v1.0.54</tag>
</scm>
<build>
@@ -174,6 +174,12 @@
<version>4.1.0</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.json</groupId>
<artifactId>json</artifactId>
<version>20240205</version>
<scope>test</scope>
</dependency>
</dependencies>
<distributionManagement>

View File

@@ -6,6 +6,7 @@ import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessage;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResult;
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFileContentsRequest;
import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFilePathRequest;
import io.github.amithkoujalgi.ollama4j.core.models.request.ModelEmbeddingsRequest;
@@ -345,7 +346,7 @@ public class OllamaAPI {
*/
public OllamaResult generate(String model, String prompt, Options options)
throws OllamaBaseException, IOException, InterruptedException {
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt);
OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt);
ollamaRequestModel.setOptions(options.getOptionsMap());
return generateSyncForOllamaRequestModel(ollamaRequestModel);
}
@@ -360,7 +361,7 @@ public class OllamaAPI {
* @return the ollama async result callback handle
*/
public OllamaAsyncResultCallback generateAsync(String model, String prompt) {
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt);
OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt);
URI uri = URI.create(this.host + "/api/generate");
OllamaAsyncResultCallback ollamaAsyncResultCallback =
@@ -389,7 +390,7 @@ public class OllamaAPI {
for (File imageFile : imageFiles) {
images.add(encodeFileToBase64(imageFile));
}
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt, images);
OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt, images);
ollamaRequestModel.setOptions(options.getOptionsMap());
return generateSyncForOllamaRequestModel(ollamaRequestModel);
}
@@ -413,7 +414,7 @@ public class OllamaAPI {
for (String imageURL : imageURLs) {
images.add(encodeByteArrayToBase64(Utils.loadImageBytesFromUrl(imageURL)));
}
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt, images);
OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt, images);
ollamaRequestModel.setOptions(options.getOptionsMap());
return generateSyncForOllamaRequestModel(ollamaRequestModel);
}
@@ -448,12 +449,31 @@ public class OllamaAPI {
* @throws InterruptedException in case the server is not reachable or network issues happen
*/
public OllamaChatResult chat(OllamaChatRequestModel request) throws OllamaBaseException, IOException, InterruptedException{
return chat(request,null);
}
/**
* Ask a question to a model using an {@link OllamaChatRequestModel}. This can be constructed using an {@link OllamaChatRequestBuilder}.
*
* Hint: the OllamaChatRequestModel#getStream() property is not implemented.
*
* @param request request object to be sent to the server
* @param streamHandler callback handler to handle the last message from stream (caution: all previous messages from stream will be concatenated)
* @return
* @throws OllamaBaseException any response code than 200 has been returned
* @throws IOException in case the responseStream can not be read
* @throws InterruptedException in case the server is not reachable or network issues happen
*/
public OllamaChatResult chat(OllamaChatRequestModel request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException{
OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
//TODO: implement async way
if(request.isStream()){
throw new UnsupportedOperationException("Streamed chat responses are not implemented yet");
OllamaResult result;
if(streamHandler != null){
request.setStream(true);
result = requestCaller.call(request, streamHandler);
}
else {
result = requestCaller.callSync(request);
}
OllamaResult result = requestCaller.generateSync(request);
return new OllamaChatResult(result.getResponse(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages());
}
@@ -467,10 +487,10 @@ public class OllamaAPI {
return Base64.getEncoder().encodeToString(bytes);
}
private OllamaResult generateSyncForOllamaRequestModel(OllamaRequestModel ollamaRequestModel)
private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequestModel ollamaRequestModel)
throws OllamaBaseException, IOException, InterruptedException {
OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
return requestCaller.generateSync(ollamaRequestModel);
return requestCaller.callSync(ollamaRequestModel);
}
/**

View File

@@ -0,0 +1,7 @@
package io.github.amithkoujalgi.ollama4j.core;
import java.util.function.Consumer;
public interface OllamaStreamHandler extends Consumer<String>{
void accept(String message);
}

View File

@@ -1,6 +1,8 @@
package io.github.amithkoujalgi.ollama4j.core.models;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateResponseModel;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
import java.io.BufferedReader;
import java.io.IOException;
@@ -22,7 +24,7 @@ import lombok.Getter;
@SuppressWarnings("unused")
public class OllamaAsyncResultCallback extends Thread {
private final HttpRequest.Builder requestBuilder;
private final OllamaRequestModel ollamaRequestModel;
private final OllamaGenerateRequestModel ollamaRequestModel;
private final Queue<String> queue = new LinkedList<>();
private String result;
private boolean isDone;
@@ -47,7 +49,7 @@ public class OllamaAsyncResultCallback extends Thread {
public OllamaAsyncResultCallback(
HttpRequest.Builder requestBuilder,
OllamaRequestModel ollamaRequestModel,
OllamaGenerateRequestModel ollamaRequestModel,
long requestTimeoutSeconds) {
this.requestBuilder = requestBuilder;
this.ollamaRequestModel = ollamaRequestModel;
@@ -87,8 +89,8 @@ public class OllamaAsyncResultCallback extends Thread {
queue.add(ollamaResponseModel.getError());
responseBuffer.append(ollamaResponseModel.getError());
} else {
OllamaResponseModel ollamaResponseModel =
Utils.getObjectMapper().readValue(line, OllamaResponseModel.class);
OllamaGenerateResponseModel ollamaResponseModel =
Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class);
queue.add(ollamaResponseModel.getResponse());
if (!ollamaResponseModel.isDone()) {
responseBuffer.append(ollamaResponseModel.getResponse());

View File

@@ -0,0 +1,35 @@
package io.github.amithkoujalgi.ollama4j.core.models;
import java.util.Map;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import io.github.amithkoujalgi.ollama4j.core.utils.BooleanToJsonFormatFlagSerializer;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
import lombok.Data;
@Data
@JsonInclude(JsonInclude.Include.NON_NULL)
public abstract class OllamaCommonRequestModel {
protected String model;
@JsonSerialize(using = BooleanToJsonFormatFlagSerializer.class)
@JsonProperty(value = "format")
protected Boolean returnFormatJson;
protected Map<String, Object> options;
protected String template;
protected boolean stream;
@JsonProperty(value = "keep_alive")
protected String keepAlive;
public String toString() {
try {
return Utils.getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
}

View File

@@ -1,39 +0,0 @@
package io.github.amithkoujalgi.ollama4j.core.models;
import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper;
import com.fasterxml.jackson.core.JsonProcessingException;
import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
import java.util.List;
import java.util.Map;
import lombok.Data;
@Data
public class OllamaRequestModel implements OllamaRequestBody{
private String model;
private String prompt;
private List<String> images;
private Map<String, Object> options;
public OllamaRequestModel(String model, String prompt) {
this.model = model;
this.prompt = prompt;
}
public OllamaRequestModel(String model, String prompt, List<String> images) {
this.model = model;
this.prompt = prompt;
this.images = images;
}
public String toString() {
try {
return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
}

View File

@@ -83,12 +83,12 @@ public class OllamaChatRequestBuilder {
}
public OllamaChatRequestBuilder withOptions(Options options){
this.request.setOptions(options);
this.request.setOptions(options.getOptionsMap());
return this;
}
public OllamaChatRequestBuilder withFormat(String format){
this.request.setFormat(format);
public OllamaChatRequestBuilder withGetJsonResponse(){
this.request.setReturnFormatJson(true);
return this;
}

View File

@@ -1,47 +1,39 @@
package io.github.amithkoujalgi.ollama4j.core.models.chat;
import java.util.List;
import com.fasterxml.jackson.core.JsonProcessingException;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaCommonRequestModel;
import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
import io.github.amithkoujalgi.ollama4j.core.utils.Options;
import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.Getter;
import lombok.Setter;
/**
* Defines a Request to use against the ollama /api/chat endpoint.
*
* @see <a
* href="https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion">Generate
* Chat Completion</a>
* @see <a href=
* "https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion">Generate
* Chat Completion</a>
*/
@Data
@AllArgsConstructor
@RequiredArgsConstructor
public class OllamaChatRequestModel implements OllamaRequestBody {
@Getter
@Setter
public class OllamaChatRequestModel extends OllamaCommonRequestModel implements OllamaRequestBody {
@NonNull private String model;
private List<OllamaChatMessage> messages;
@NonNull private List<OllamaChatMessage> messages;
public OllamaChatRequestModel() {}
private String format;
private Options options;
private String template;
private boolean stream;
private String keepAlive;
public OllamaChatRequestModel(String model, List<OllamaChatMessage> messages) {
this.model = model;
this.messages = messages;
}
@Override
public String toString() {
try {
return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
public boolean equals(Object o) {
if (!(o instanceof OllamaChatRequestModel)) {
return false;
}
return this.toString().equals(o.toString());
}
}

View File

@@ -0,0 +1,34 @@
package io.github.amithkoujalgi.ollama4j.core.models.chat;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler;
import lombok.NonNull;
public class OllamaChatStreamObserver {
private OllamaStreamHandler streamHandler;
private List<OllamaChatResponseModel> responseParts = new ArrayList<>();
private String message;
public OllamaChatStreamObserver(OllamaStreamHandler streamHandler) {
this.streamHandler = streamHandler;
}
public void notify(OllamaChatResponseModel currentResponsePart){
responseParts.add(currentResponsePart);
handleCurrentResponsePart(currentResponsePart);
}
protected void handleCurrentResponsePart(OllamaChatResponseModel currentResponsePart){
List<@NonNull String> allResponsePartsByNow = responseParts.stream().map(r -> r.getMessage().getContent()).collect(Collectors.toList());
message = String.join("", allResponsePartsByNow);
streamHandler.accept(message);
}
}

View File

@@ -0,0 +1,55 @@
package io.github.amithkoujalgi.ollama4j.core.models.generate;
import io.github.amithkoujalgi.ollama4j.core.utils.Options;
/**
* Helper class for creating {@link io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel}
* objects using the builder-pattern.
*/
public class OllamaGenerateRequestBuilder {
private OllamaGenerateRequestBuilder(String model, String prompt){
request = new OllamaGenerateRequestModel(model, prompt);
}
private OllamaGenerateRequestModel request;
public static OllamaGenerateRequestBuilder getInstance(String model){
return new OllamaGenerateRequestBuilder(model,"");
}
public OllamaGenerateRequestModel build(){
return request;
}
public OllamaGenerateRequestBuilder withPrompt(String prompt){
request.setPrompt(prompt);
return this;
}
public OllamaGenerateRequestBuilder withGetJsonResponse(){
this.request.setReturnFormatJson(true);
return this;
}
public OllamaGenerateRequestBuilder withOptions(Options options){
this.request.setOptions(options.getOptionsMap());
return this;
}
public OllamaGenerateRequestBuilder withTemplate(String template){
this.request.setTemplate(template);
return this;
}
public OllamaGenerateRequestBuilder withStreaming(){
this.request.setStream(true);
return this;
}
public OllamaGenerateRequestBuilder withKeepAlive(String keepAlive){
this.request.setKeepAlive(keepAlive);
return this;
}
}

View File

@@ -0,0 +1,46 @@
package io.github.amithkoujalgi.ollama4j.core.models.generate;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaCommonRequestModel;
import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
import java.util.List;
import lombok.Getter;
import lombok.Setter;
@Getter
@Setter
public class OllamaGenerateRequestModel extends OllamaCommonRequestModel implements OllamaRequestBody{
private String prompt;
private List<String> images;
private String system;
private String context;
private boolean raw;
public OllamaGenerateRequestModel() {
}
public OllamaGenerateRequestModel(String model, String prompt) {
this.model = model;
this.prompt = prompt;
}
public OllamaGenerateRequestModel(String model, String prompt, List<String> images) {
this.model = model;
this.prompt = prompt;
this.images = images;
}
@Override
public boolean equals(Object o) {
if (!(o instanceof OllamaGenerateRequestModel)) {
return false;
}
return this.toString().equals(o.toString());
}
}

View File

@@ -1,4 +1,4 @@
package io.github.amithkoujalgi.ollama4j.core.models;
package io.github.amithkoujalgi.ollama4j.core.models.generate;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
@@ -8,7 +8,7 @@ import lombok.Data;
@Data
@JsonIgnoreProperties(ignoreUnknown = true)
public class OllamaResponseModel {
public class OllamaGenerateResponseModel {
private String model;
private @JsonProperty("created_at") String createdAt;
private String response;

View File

@@ -1,12 +1,19 @@
package io.github.amithkoujalgi.ollama4j.core.models.request;
import java.io.IOException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.fasterxml.jackson.core.JsonProcessingException;
import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResponseModel;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatStreamObserver;
import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
/**
@@ -16,6 +23,8 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller{
private static final Logger LOG = LoggerFactory.getLogger(OllamaChatEndpointCaller.class);
private OllamaChatStreamObserver streamObserver;
public OllamaChatEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) {
super(host, basicAuth, requestTimeoutSeconds, verbose);
}
@@ -27,18 +36,25 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller{
@Override
protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) {
try {
OllamaChatResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class);
responseBuffer.append(ollamaResponseModel.getMessage().getContent());
return ollamaResponseModel.isDone();
} catch (JsonProcessingException e) {
LOG.error("Error parsing the Ollama chat response!",e);
return true;
}
try {
OllamaChatResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class);
responseBuffer.append(ollamaResponseModel.getMessage().getContent());
if(streamObserver != null) {
streamObserver.notify(ollamaResponseModel);
}
return ollamaResponseModel.isDone();
} catch (JsonProcessingException e) {
LOG.error("Error parsing the Ollama chat response!",e);
return true;
}
}
public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler)
throws OllamaBaseException, IOException, InterruptedException {
streamObserver = new OllamaChatStreamObserver(streamHandler);
return super.callSync(body);
}
}

View File

@@ -46,7 +46,7 @@ public abstract class OllamaEndpointCaller {
protected abstract boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer);
/**
* Calls the api server on the given host and endpoint suffix asynchronously, aka waiting for the response.
*
@@ -56,7 +56,7 @@ public abstract class OllamaEndpointCaller {
* @throws IOException in case the responseStream can not be read
* @throws InterruptedException in case the server is not reachable or network issues happen
*/
public OllamaResult generateSync(OllamaRequestBody body) throws OllamaBaseException, IOException, InterruptedException{
public OllamaResult callSync(OllamaRequestBody body) throws OllamaBaseException, IOException, InterruptedException{
// Create Request
long startTime = System.currentTimeMillis();

View File

@@ -6,7 +6,7 @@ import org.slf4j.LoggerFactory;
import com.fasterxml.jackson.core.JsonProcessingException;
import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaResponseModel;
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateResponseModel;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{
@@ -25,7 +25,7 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{
@Override
protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) {
try {
OllamaResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaResponseModel.class);
OllamaGenerateResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class);
responseBuffer.append(ollamaResponseModel.getResponse());
return ollamaResponseModel.isDone();
} catch (JsonProcessingException e) {

View File

@@ -0,0 +1,21 @@
package io.github.amithkoujalgi.ollama4j.core.utils;
import java.io.IOException;
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.databind.JsonSerializer;
import com.fasterxml.jackson.databind.SerializerProvider;
public class BooleanToJsonFormatFlagSerializer extends JsonSerializer<Boolean>{
@Override
public void serialize(Boolean value, JsonGenerator gen, SerializerProvider serializers) throws IOException {
gen.writeString("json");
}
@Override
public boolean isEmpty(SerializerProvider provider,Boolean value){
return !value;
}
}

View File

@@ -1,8 +1,6 @@
package io.github.amithkoujalgi.ollama4j.core.utils;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.util.Base64;
import java.util.Collection;
@@ -20,11 +18,4 @@ public class FileToBase64Serializer extends JsonSerializer<Collection<byte[]>> {
}
jsonGenerator.writeEndArray();
}
public static byte[] serialize(Object obj) throws IOException {
ByteArrayOutputStream out = new ByteArrayOutputStream();
ObjectOutputStream os = new ObjectOutputStream(out);
os.writeObject(obj);
return out.toByteArray();
}
}

View File

@@ -23,8 +23,13 @@ import lombok.Data;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Order;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
class TestRealAPIs {
private static final Logger LOG = LoggerFactory.getLogger(TestRealAPIs.class);
OllamaAPI ollamaAPI;
Config config;
@@ -164,6 +169,31 @@ class TestRealAPIs {
}
}
@Test
@Order(3)
void testChatWithStream() {
testEndpointReachability();
try {
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER,
"What is the capital of France? And what's France's connection with Mona Lisa?")
.build();
StringBuffer sb = new StringBuffer("");
OllamaChatResult chatResult = ollamaAPI.chat(requestModel,(s) -> {
LOG.info(s);
String substring = s.substring(sb.toString().length(), s.length()-1);
LOG.info(substring);
sb.append(substring);
});
assertNotNull(chatResult);
assertEquals(sb.toString().trim(), chatResult.getResponse().trim());
} catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e);
}
}
@Test
@Order(3)
void testChatWithImageFromFileWithHistoryRecognition() {

View File

@@ -0,0 +1,106 @@
package io.github.amithkoujalgi.ollama4j.unittests.jackson;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
import java.io.File;
import java.util.List;
import org.json.JSONObject;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessageRole;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel;
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
public class TestChatRequestSerialization {
private OllamaChatRequestBuilder builder;
private ObjectMapper mapper = Utils.getObjectMapper();
@BeforeEach
public void init() {
builder = OllamaChatRequestBuilder.getInstance("DummyModel");
}
@Test
public void testRequestOnlyMandatoryFields() {
OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt",
List.of(new File("src/test/resources/dog-on-a-boat.jpg"))).build();
String jsonRequest = serializeRequest(req);
assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req);
}
@Test
public void testRequestMultipleMessages() {
OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.SYSTEM, "System prompt")
.withMessage(OllamaChatMessageRole.USER, "Some prompt")
.build();
String jsonRequest = serializeRequest(req);
assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req);
}
@Test
public void testRequestWithMessageAndImage() {
OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt",
List.of(new File("src/test/resources/dog-on-a-boat.jpg"))).build();
String jsonRequest = serializeRequest(req);
assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req);
}
@Test
public void testRequestWithOptions() {
OptionsBuilder b = new OptionsBuilder();
OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt")
.withOptions(b.setMirostat(1).build()).build();
String jsonRequest = serializeRequest(req);
OllamaChatRequestModel deserializeRequest = deserializeRequest(jsonRequest);
assertEqualsAfterUnmarshalling(deserializeRequest, req);
assertEquals(1, deserializeRequest.getOptions().get("mirostat"));
}
@Test
public void testWithJsonFormat() {
OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt")
.withGetJsonResponse().build();
String jsonRequest = serializeRequest(req);
// no jackson deserialization as format property is not boolean ==> omit as deserialization
// of request is never used in real code anyways
JSONObject jsonObject = new JSONObject(jsonRequest);
String requestFormatProperty = jsonObject.getString("format");
assertEquals("json", requestFormatProperty);
}
private String serializeRequest(OllamaChatRequestModel req) {
try {
return mapper.writeValueAsString(req);
} catch (JsonProcessingException e) {
fail("Could not serialize request!", e);
return null;
}
}
private OllamaChatRequestModel deserializeRequest(String jsonRequest) {
try {
return mapper.readValue(jsonRequest, OllamaChatRequestModel.class);
} catch (JsonProcessingException e) {
fail("Could not deserialize jsonRequest!", e);
return null;
}
}
private void assertEqualsAfterUnmarshalling(OllamaChatRequestModel unmarshalledRequest,
OllamaChatRequestModel req) {
assertEquals(req, unmarshalledRequest);
}
}

View File

@@ -0,0 +1,85 @@
package io.github.amithkoujalgi.ollama4j.unittests.jackson;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
import org.json.JSONObject;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestBuilder;
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel;
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
public class TestGenerateRequestSerialization {
private OllamaGenerateRequestBuilder builder;
private ObjectMapper mapper = Utils.getObjectMapper();
@BeforeEach
public void init() {
builder = OllamaGenerateRequestBuilder.getInstance("DummyModel");
}
@Test
public void testRequestOnlyMandatoryFields() {
OllamaGenerateRequestModel req = builder.withPrompt("Some prompt").build();
String jsonRequest = serializeRequest(req);
assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req);
}
@Test
public void testRequestWithOptions() {
OptionsBuilder b = new OptionsBuilder();
OllamaGenerateRequestModel req =
builder.withPrompt("Some prompt").withOptions(b.setMirostat(1).build()).build();
String jsonRequest = serializeRequest(req);
OllamaGenerateRequestModel deserializeRequest = deserializeRequest(jsonRequest);
assertEqualsAfterUnmarshalling(deserializeRequest, req);
assertEquals(1, deserializeRequest.getOptions().get("mirostat"));
}
@Test
public void testWithJsonFormat() {
OllamaGenerateRequestModel req =
builder.withPrompt("Some prompt").withGetJsonResponse().build();
String jsonRequest = serializeRequest(req);
// no jackson deserialization as format property is not boolean ==> omit as deserialization
// of request is never used in real code anyways
JSONObject jsonObject = new JSONObject(jsonRequest);
String requestFormatProperty = jsonObject.getString("format");
assertEquals("json", requestFormatProperty);
}
private String serializeRequest(OllamaGenerateRequestModel req) {
try {
return mapper.writeValueAsString(req);
} catch (JsonProcessingException e) {
fail("Could not serialize request!", e);
return null;
}
}
private OllamaGenerateRequestModel deserializeRequest(String jsonRequest) {
try {
return mapper.readValue(jsonRequest, OllamaGenerateRequestModel.class);
} catch (JsonProcessingException e) {
fail("Could not deserialize jsonRequest!", e);
return null;
}
}
private void assertEqualsAfterUnmarshalling(OllamaGenerateRequestModel unmarshalledRequest,
OllamaGenerateRequestModel req) {
assertEquals(req, unmarshalledRequest);
}
}