forked from Mirror/ollama4j
Merge pull request #1 from AgentSchmecker/feature/streaming-chat
Adds streaming functionality for chat
This commit is contained in:
commit
7eb16b7ba0
@ -448,12 +448,31 @@ public class OllamaAPI {
|
||||
* @throws InterruptedException in case the server is not reachable or network issues happen
|
||||
*/
|
||||
public OllamaChatResult chat(OllamaChatRequestModel request) throws OllamaBaseException, IOException, InterruptedException{
|
||||
OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
|
||||
//TODO: implement async way
|
||||
if(request.isStream()){
|
||||
throw new UnsupportedOperationException("Streamed chat responses are not implemented yet");
|
||||
return chat(request);
|
||||
}
|
||||
|
||||
/**
|
||||
* Ask a question to a model using an {@link OllamaChatRequestModel}. This can be constructed using an {@link OllamaChatRequestBuilder}.
|
||||
*
|
||||
* Hint: the OllamaChatRequestModel#getStream() property is not implemented.
|
||||
*
|
||||
* @param request request object to be sent to the server
|
||||
* @param streamHandler callback handler to handle the last message from stream (caution: all previous messages from stream will be concatenated)
|
||||
* @return
|
||||
* @throws OllamaBaseException any response code than 200 has been returned
|
||||
* @throws IOException in case the responseStream can not be read
|
||||
* @throws InterruptedException in case the server is not reachable or network issues happen
|
||||
*/
|
||||
public OllamaChatResult chat(OllamaChatRequestModel request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException{
|
||||
OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
|
||||
OllamaResult result;
|
||||
if(streamHandler != null){
|
||||
request.setStream(true);
|
||||
result = requestCaller.call(request, streamHandler);
|
||||
}
|
||||
else {
|
||||
result = requestCaller.callSync(request);
|
||||
}
|
||||
OllamaResult result = requestCaller.generateSync(request);
|
||||
return new OllamaChatResult(result.getResponse(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages());
|
||||
}
|
||||
|
||||
@ -470,7 +489,7 @@ public class OllamaAPI {
|
||||
private OllamaResult generateSyncForOllamaRequestModel(OllamaRequestModel ollamaRequestModel)
|
||||
throws OllamaBaseException, IOException, InterruptedException {
|
||||
OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
|
||||
return requestCaller.generateSync(ollamaRequestModel);
|
||||
return requestCaller.callSync(ollamaRequestModel);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -0,0 +1,7 @@
|
||||
package io.github.amithkoujalgi.ollama4j.core;
|
||||
|
||||
import java.util.function.Consumer;
|
||||
|
||||
public interface OllamaStreamHandler extends Consumer<String>{
|
||||
void accept(String message);
|
||||
}
|
@ -0,0 +1,34 @@
|
||||
package io.github.amithkoujalgi.ollama4j.core.models.chat;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler;
|
||||
import lombok.NonNull;
|
||||
|
||||
public class OllamaChatStreamObserver {
|
||||
|
||||
private OllamaStreamHandler streamHandler;
|
||||
|
||||
private List<OllamaChatResponseModel> responseParts = new ArrayList<>();
|
||||
|
||||
private String message;
|
||||
|
||||
public OllamaChatStreamObserver(OllamaStreamHandler streamHandler) {
|
||||
this.streamHandler = streamHandler;
|
||||
}
|
||||
|
||||
public void notify(OllamaChatResponseModel currentResponsePart){
|
||||
responseParts.add(currentResponsePart);
|
||||
handleCurrentResponsePart(currentResponsePart);
|
||||
}
|
||||
|
||||
protected void handleCurrentResponsePart(OllamaChatResponseModel currentResponsePart){
|
||||
List<@NonNull String> allResponsePartsByNow = responseParts.stream().map(r -> r.getMessage().getContent()).collect(Collectors.toList());
|
||||
message = String.join("", allResponsePartsByNow);
|
||||
streamHandler.accept(message);
|
||||
}
|
||||
|
||||
|
||||
}
|
@ -1,12 +1,19 @@
|
||||
package io.github.amithkoujalgi.ollama4j.core.models.request;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
|
||||
import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler;
|
||||
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
|
||||
import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth;
|
||||
import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
|
||||
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResponseModel;
|
||||
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatStreamObserver;
|
||||
import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
|
||||
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
|
||||
|
||||
/**
|
||||
@ -16,6 +23,8 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller{
|
||||
|
||||
private static final Logger LOG = LoggerFactory.getLogger(OllamaChatEndpointCaller.class);
|
||||
|
||||
private OllamaChatStreamObserver streamObserver;
|
||||
|
||||
public OllamaChatEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) {
|
||||
super(host, basicAuth, requestTimeoutSeconds, verbose);
|
||||
}
|
||||
@ -30,6 +39,9 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller{
|
||||
try {
|
||||
OllamaChatResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class);
|
||||
responseBuffer.append(ollamaResponseModel.getMessage().getContent());
|
||||
if(streamObserver != null) {
|
||||
streamObserver.notify(ollamaResponseModel);
|
||||
}
|
||||
return ollamaResponseModel.isDone();
|
||||
} catch (JsonProcessingException e) {
|
||||
LOG.error("Error parsing the Ollama chat response!",e);
|
||||
@ -37,7 +49,11 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller{
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler)
|
||||
throws OllamaBaseException, IOException, InterruptedException {
|
||||
streamObserver = new OllamaChatStreamObserver(streamHandler);
|
||||
return super.callSync(body);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
@ -56,7 +56,7 @@ public abstract class OllamaEndpointCaller {
|
||||
* @throws IOException in case the responseStream can not be read
|
||||
* @throws InterruptedException in case the server is not reachable or network issues happen
|
||||
*/
|
||||
public OllamaResult generateSync(OllamaRequestBody body) throws OllamaBaseException, IOException, InterruptedException{
|
||||
public OllamaResult callSync(OllamaRequestBody body) throws OllamaBaseException, IOException, InterruptedException{
|
||||
|
||||
// Create Request
|
||||
long startTime = System.currentTimeMillis();
|
||||
|
@ -23,8 +23,13 @@ import lombok.Data;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Order;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
class TestRealAPIs {
|
||||
|
||||
private static final Logger LOG = LoggerFactory.getLogger(TestRealAPIs.class);
|
||||
|
||||
OllamaAPI ollamaAPI;
|
||||
Config config;
|
||||
|
||||
@ -164,6 +169,31 @@ class TestRealAPIs {
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(3)
|
||||
void testChatWithStream() {
|
||||
testEndpointReachability();
|
||||
try {
|
||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
|
||||
OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER,
|
||||
"What is the capital of France? And what's France's connection with Mona Lisa?")
|
||||
.build();
|
||||
|
||||
StringBuffer sb = new StringBuffer("");
|
||||
|
||||
OllamaChatResult chatResult = ollamaAPI.chat(requestModel,(s) -> {
|
||||
LOG.info(s);
|
||||
String substring = s.substring(sb.toString().length(), s.length()-1);
|
||||
LOG.info(substring);
|
||||
sb.append(substring);
|
||||
});
|
||||
assertNotNull(chatResult);
|
||||
assertEquals(sb.toString().trim(), chatResult.getResponse().trim());
|
||||
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(3)
|
||||
void testChatWithImageFromFileWithHistoryRecognition() {
|
||||
|
Loading…
x
Reference in New Issue
Block a user