Updated withMessages method of OllamaChatRequestBuilder to reset the messages

This commit is contained in:
Amith Koujalgi 2024-08-09 01:30:47 +05:30
parent 11a98a72a1
commit 3aa0fc77cb
2 changed files with 55 additions and 32 deletions

View File

@ -82,6 +82,33 @@ You will get a response similar to:
] ]
``` ```
## Conversational loop
```java
public class Main {
public static void main(String[] args) {
OllamaAPI ollamaAPI = new OllamaAPI();
ollamaAPI.setRequestTimeoutSeconds(60);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance("<your-model>");
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "<your-first-message>").build();
OllamaChatResult initialChatResult = ollamaAPI.chat(requestModel);
System.out.println(initialChatResult.getResponse());
List<OllamaChatMessage> history = initialChatResult.getChatHistory();
while (true) {
OllamaChatResult chatResult = ollamaAPI.chat(builder.withMessages(history).withMessage(OllamaChatMessageRole.USER, "<your-new-message").build());
System.out.println(chatResult.getResponse());
history = chatResult.getChatHistory();
}
}
}
```
## Create a conversation where the answer is streamed ## Create a conversation where the answer is streamed
```java ```java

View File

@ -1,5 +1,10 @@
package io.github.ollama4j.models.chat; package io.github.ollama4j.models.chat;
import io.github.ollama4j.utils.Options;
import io.github.ollama4j.utils.Utils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.net.URISyntaxException; import java.net.URISyntaxException;
@ -8,12 +13,6 @@ import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import io.github.ollama4j.utils.Options;
import io.github.ollama4j.utils.Utils;
/** /**
* Helper class for creating {@link OllamaChatRequest} objects using the builder-pattern. * Helper class for creating {@link OllamaChatRequest} objects using the builder-pattern.
*/ */
@ -21,88 +20,85 @@ public class OllamaChatRequestBuilder {
private static final Logger LOG = LoggerFactory.getLogger(OllamaChatRequestBuilder.class); private static final Logger LOG = LoggerFactory.getLogger(OllamaChatRequestBuilder.class);
private OllamaChatRequestBuilder(String model, List<OllamaChatMessage> messages){ private OllamaChatRequestBuilder(String model, List<OllamaChatMessage> messages) {
request = new OllamaChatRequest(model, messages); request = new OllamaChatRequest(model, messages);
} }
private OllamaChatRequest request; private OllamaChatRequest request;
public static OllamaChatRequestBuilder getInstance(String model){ public static OllamaChatRequestBuilder getInstance(String model) {
return new OllamaChatRequestBuilder(model, new ArrayList<>()); return new OllamaChatRequestBuilder(model, new ArrayList<>());
} }
public OllamaChatRequest build(){ public OllamaChatRequest build() {
return request; return request;
} }
public void reset(){ public void reset() {
request = new OllamaChatRequest(request.getModel(), new ArrayList<>()); request = new OllamaChatRequest(request.getModel(), new ArrayList<>());
} }
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List<File> images){ public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List<File> images) {
List<OllamaChatMessage> messages = this.request.getMessages(); List<OllamaChatMessage> messages = this.request.getMessages();
List<byte[]> binaryImages = images.stream().map(file -> { List<byte[]> binaryImages = images.stream().map(file -> {
try { try {
return Files.readAllBytes(file.toPath()); return Files.readAllBytes(file.toPath());
} catch (IOException e) { } catch (IOException e) {
LOG.warn(String.format("File '%s' could not be accessed, will not add to message!",file.toPath()), e); LOG.warn(String.format("File '%s' could not be accessed, will not add to message!", file.toPath()), e);
return new byte[0]; return new byte[0];
} }
}).collect(Collectors.toList()); }).collect(Collectors.toList());
messages.add(new OllamaChatMessage(role,content,binaryImages)); messages.add(new OllamaChatMessage(role, content, binaryImages));
return this; return this;
} }
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, String... imageUrls){ public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, String... imageUrls) {
List<OllamaChatMessage> messages = this.request.getMessages(); List<OllamaChatMessage> messages = this.request.getMessages();
List<byte[]> binaryImages = null; List<byte[]> binaryImages = null;
if(imageUrls.length>0){ if (imageUrls.length > 0) {
binaryImages = new ArrayList<>(); binaryImages = new ArrayList<>();
for (String imageUrl : imageUrls) { for (String imageUrl : imageUrls) {
try{ try {
binaryImages.add(Utils.loadImageBytesFromUrl(imageUrl)); binaryImages.add(Utils.loadImageBytesFromUrl(imageUrl));
} } catch (URISyntaxException e) {
catch (URISyntaxException e){ LOG.warn(String.format("URL '%s' could not be accessed, will not add to message!", imageUrl), e);
LOG.warn(String.format("URL '%s' could not be accessed, will not add to message!",imageUrl), e); } catch (IOException e) {
} LOG.warn(String.format("Content of URL '%s' could not be read, will not add to message!", imageUrl), e);
catch (IOException e){
LOG.warn(String.format("Content of URL '%s' could not be read, will not add to message!",imageUrl), e);
} }
} }
} }
messages.add(new OllamaChatMessage(role,content,binaryImages)); messages.add(new OllamaChatMessage(role, content, binaryImages));
return this; return this;
} }
public OllamaChatRequestBuilder withMessages(List<OllamaChatMessage> messages){ public OllamaChatRequestBuilder withMessages(List<OllamaChatMessage> messages) {
this.request.getMessages().addAll(messages); return new OllamaChatRequestBuilder(request.getModel(), messages);
return this;
} }
public OllamaChatRequestBuilder withOptions(Options options){ public OllamaChatRequestBuilder withOptions(Options options) {
this.request.setOptions(options.getOptionsMap()); this.request.setOptions(options.getOptionsMap());
return this; return this;
} }
public OllamaChatRequestBuilder withGetJsonResponse(){ public OllamaChatRequestBuilder withGetJsonResponse() {
this.request.setReturnFormatJson(true); this.request.setReturnFormatJson(true);
return this; return this;
} }
public OllamaChatRequestBuilder withTemplate(String template){ public OllamaChatRequestBuilder withTemplate(String template) {
this.request.setTemplate(template); this.request.setTemplate(template);
return this; return this;
} }
public OllamaChatRequestBuilder withStreaming(){ public OllamaChatRequestBuilder withStreaming() {
this.request.setStream(true); this.request.setStream(true);
return this; return this;
} }
public OllamaChatRequestBuilder withKeepAlive(String keepAlive){ public OllamaChatRequestBuilder withKeepAlive(String keepAlive) {
this.request.setKeepAlive(keepAlive); this.request.setKeepAlive(keepAlive);
return this; return this;
} }