Merge pull request #26 from AgentSchmecker/main

Adds streaming functionality for chat
Amith Koujalgi 2024-02-16 10:11:32 +05:30 committed by GitHub
commit a9e7958d44
7 changed files with 158 additions and 17 deletions

View File

@@ -69,6 +69,41 @@ You will get a response similar to:
 } ]
 ```
+
+## Create a conversation where the answer is streamed
+
+```java
+public class Main {
+    public static void main(String[] args) throws Exception {
+        String host = "http://localhost:11434/";
+        OllamaAPI ollamaAPI = new OllamaAPI(host);
+
+        OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance("llama2"); // any model available on your Ollama server
+        OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER,
+                "What is the capital of France? And what's France's connection with Mona Lisa?")
+                .build();
+
+        // define a handler (Consumer<String>)
+        OllamaStreamHandler streamHandler = (s) -> {
+            System.out.println(s);
+        };
+
+        OllamaChatResult chatResult = ollamaAPI.chat(requestModel, streamHandler);
+    }
+}
+```
+
+You will get a response similar to:
+
+> The
+> The capital
+> The capital of
+> The capital of France
+> The capital of France is
+> The capital of France is Paris
+> The capital of France is Paris.
+
 ## Create a new conversation with individual system prompt
 ```java
 public class Main {
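A note on the docs example above: each handler callback receives the entire response concatenated so far (the `streamHandler` javadoc in OllamaAPI.java below cautions about exactly this), which is why the sample output repeats the growing prefix. The sketch below is an editorial addition, not part of this commit; it tracks what has already been seen so only the newly streamed tokens are printed. The model name "llama2" is illustrative.

```java
public class Main {
    public static void main(String[] args) throws Exception {
        OllamaAPI ollamaAPI = new OllamaAPI("http://localhost:11434/");
        OllamaChatRequestModel requestModel = OllamaChatRequestBuilder.getInstance("llama2")
                .withMessage(OllamaChatMessageRole.USER, "What is the capital of France?")
                .build();

        StringBuilder seen = new StringBuilder();
        OllamaStreamHandler streamHandler = (message) -> {
            // each callback carries the full response so far; print only the new suffix
            System.out.print(message.substring(seen.length()));
            seen.setLength(0);
            seen.append(message);
        };

        ollamaAPI.chat(requestModel, streamHandler);
        System.out.println();
    }
}
```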

View File: OllamaAPI.java

@@ -448,12 +448,31 @@ public class OllamaAPI {
      * @throws InterruptedException in case the server is not reachable or network issues happen
      */
     public OllamaChatResult chat(OllamaChatRequestModel request) throws OllamaBaseException, IOException, InterruptedException{
+        return chat(request, null);
+    }
+
+    /**
+     * Ask a question to a model using an {@link OllamaChatRequestModel}. This can be constructed using an {@link OllamaChatRequestBuilder}.
+     *
+     * Hint: the OllamaChatRequestModel#getStream() property is ignored here; streaming is forced on whenever a stream handler is supplied.
+     *
+     * @param request request object to be sent to the server
+     * @param streamHandler callback handler to handle the last message from stream (caution: all previous messages from stream will be concatenated)
+     * @return the chat result containing the full response text and response metadata
+     * @throws OllamaBaseException any response code other than 200 has been returned
+     * @throws IOException in case the responseStream can not be read
+     * @throws InterruptedException in case the server is not reachable or network issues happen
+     */
+    public OllamaChatResult chat(OllamaChatRequestModel request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException{
         OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
-        //TODO: implement async way
-        if(request.isStream()){
-            throw new UnsupportedOperationException("Streamed chat responses are not implemented yet");
+        OllamaResult result;
+        if(streamHandler != null){
+            request.setStream(true);
+            result = requestCaller.call(request, streamHandler);
+        }
+        else {
+            result = requestCaller.callSync(request);
         }
-        OllamaResult result = requestCaller.generateSync(request);
         return new OllamaChatResult(result.getResponse(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages());
     }
@@ -470,7 +489,7 @@ public class OllamaAPI {
     private OllamaResult generateSyncForOllamaRequestModel(OllamaRequestModel ollamaRequestModel)
             throws OllamaBaseException, IOException, InterruptedException {
         OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
-        return requestCaller.generateSync(ollamaRequestModel);
+        return requestCaller.callSync(ollamaRequestModel);
     }

     /**
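With this change the single-argument `chat` delegates to the new overload with a `null` handler, so callers pick blocking or streaming behavior by the overload they use. A hedged usage sketch (editorial, not part of the diff; assumes a local server and the builder classes from the docs example, with "llama2" as a stand-in model name):

```java
public class ChatOverloadsExample {
    public static void main(String[] args) throws Exception {
        OllamaAPI ollamaAPI = new OllamaAPI("http://localhost:11434/");
        OllamaChatRequestModel requestModel = OllamaChatRequestBuilder.getInstance("llama2")
                .withMessage(OllamaChatMessageRole.USER, "What is the capital of France?")
                .build();

        // Blocking: no handler, so the request runs through callSync and
        // returns only once the full response has arrived.
        OllamaChatResult blocking = ollamaAPI.chat(requestModel);
        System.out.println(blocking.getResponse());

        // Streaming: a non-null handler forces request.setStream(true); the
        // handler sees the growing text, and the result still holds it all.
        OllamaChatResult streamed = ollamaAPI.chat(requestModel, System.out::println);
        System.out.println(streamed.getResponse());
    }
}
```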

View File: OllamaStreamHandler.java (new file)

@@ -0,0 +1,7 @@
+package io.github.amithkoujalgi.ollama4j.core;
+
+import java.util.function.Consumer;
+
+public interface OllamaStreamHandler extends Consumer<String>{
+    void accept(String message);
+}
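`OllamaStreamHandler` is a functional interface that merely specializes `Consumer<String>`, so lambdas and method references satisfy it, and handlers compose with the standard functional API. A small editorial sketch (the class name is illustrative):

```java
package io.github.amithkoujalgi.ollama4j.core;

import java.util.function.Consumer;

public class StreamHandlerExamples {
    public static void main(String[] args) {
        // a method reference works as a handler
        OllamaStreamHandler printAll = System.out::println;

        // so does a lambda
        OllamaStreamHandler logLength =
                (message) -> System.out.println("chars so far: " + message.length());

        // and since it extends Consumer<String>, handlers compose via andThen
        Consumer<String> both = printAll.andThen(logLength);
        both.accept("The capital");
    }
}
```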

View File: OllamaChatStreamObserver.java (new file)

@@ -0,0 +1,34 @@
+package io.github.amithkoujalgi.ollama4j.core.models.chat;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.stream.Collectors;
+
+import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler;
+import lombok.NonNull;
+
+public class OllamaChatStreamObserver {
+
+    private OllamaStreamHandler streamHandler;
+
+    private List<OllamaChatResponseModel> responseParts = new ArrayList<>();
+
+    private String message;
+
+    public OllamaChatStreamObserver(OllamaStreamHandler streamHandler) {
+        this.streamHandler = streamHandler;
+    }
+
+    public void notify(OllamaChatResponseModel currentResponsePart){
+        responseParts.add(currentResponsePart);
+        handleCurrentResponsePart(currentResponsePart);
+    }
+
+    protected void handleCurrentResponsePart(OllamaChatResponseModel currentResponsePart){
+        List<@NonNull String> allResponsePartsByNow = responseParts.stream().map(r -> r.getMessage().getContent()).collect(Collectors.toList());
+        message = String.join("", allResponsePartsByNow);
+        streamHandler.accept(message);
+    }
+}
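Because `handleCurrentResponsePart` is `protected`, the concatenate-everything behavior can be replaced without touching the endpoint caller. A hypothetical subclass (editorial; the class name and delta behavior are illustrative, not part of this commit) that forwards only the latest chunk:

```java
package io.github.amithkoujalgi.ollama4j.core.models.chat;

import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler;

public class DeltaChatStreamObserver extends OllamaChatStreamObserver {

    // keep our own reference; the superclass field is private
    private final OllamaStreamHandler handler;

    public DeltaChatStreamObserver(OllamaStreamHandler handler) {
        super(handler);
        this.handler = handler;
    }

    @Override
    protected void handleCurrentResponsePart(OllamaChatResponseModel currentResponsePart) {
        // forward only the current chunk instead of the running concatenation
        handler.accept(currentResponsePart.getMessage().getContent());
    }
}
```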

View File: OllamaChatEndpointCaller.java

@@ -1,12 +1,19 @@
 package io.github.amithkoujalgi.ollama4j.core.models.request;

+import java.io.IOException;
+
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;

 import com.fasterxml.jackson.core.JsonProcessingException;

+import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler;
+import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
 import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth;
+import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
 import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResponseModel;
+import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatStreamObserver;
+import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
 import io.github.amithkoujalgi.ollama4j.core.utils.Utils;

 /**
@@ -16,6 +23,8 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller{

     private static final Logger LOG = LoggerFactory.getLogger(OllamaChatEndpointCaller.class);

+    private OllamaChatStreamObserver streamObserver;
+
     public OllamaChatEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) {
         super(host, basicAuth, requestTimeoutSeconds, verbose);
     }
@@ -27,18 +36,25 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller{
     @Override
     protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) {
         try {
             OllamaChatResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class);
             responseBuffer.append(ollamaResponseModel.getMessage().getContent());
+            if(streamObserver != null) {
+                streamObserver.notify(ollamaResponseModel);
+            }
             return ollamaResponseModel.isDone();
         } catch (JsonProcessingException e) {
             LOG.error("Error parsing the Ollama chat response!",e);
             return true;
         }
     }
+
+    public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler)
+            throws OllamaBaseException, IOException, InterruptedException {
+        streamObserver = new OllamaChatStreamObserver(streamHandler);
+        return super.callSync(body);
+    }
 }
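The streaming design is a small observer pattern: `call` installs an `OllamaChatStreamObserver` and then reuses the inherited `callSync` loop, so each parsed response line both fills the buffer and, when an observer is present, reaches the handler as a growing concatenation. Driving the caller directly is normally `OllamaAPI.chat`'s job, but a sketch helps show the wiring (editorial; the `null` BasicAuth for an unauthenticated local server and the 60-second timeout are assumptions):

```java
public class DirectCallerSketch {
    public static void main(String[] args) throws Exception {
        // normally constructed inside OllamaAPI.chat(...)
        OllamaChatEndpointCaller caller =
                new OllamaChatEndpointCaller("http://localhost:11434/", null, 60, true);

        OllamaChatRequestModel requestModel = OllamaChatRequestBuilder.getInstance("llama2")
                .withMessage(OllamaChatMessageRole.USER, "What is the capital of France?")
                .build();

        // call() installs the stream observer, then delegates to the callSync loop
        OllamaResult result = caller.call(requestModel, System.out::println);
        System.out.println(result.getResponse());
    }
}
```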

View File: OllamaEndpointCaller.java

@@ -46,7 +46,7 @@ public abstract class OllamaEndpointCaller {

     protected abstract boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer);

     /**
      * Calls the api server on the given host and endpoint suffix synchronously, i.e. waiting for the response.
      *
@@ -56,7 +56,7 @@ public abstract class OllamaEndpointCaller {
      * @throws IOException in case the responseStream can not be read
      * @throws InterruptedException in case the server is not reachable or network issues happen
      */
-    public OllamaResult generateSync(OllamaRequestBody body) throws OllamaBaseException, IOException, InterruptedException{
+    public OllamaResult callSync(OllamaRequestBody body) throws OllamaBaseException, IOException, InterruptedException{
         // Create Request
         long startTime = System.currentTimeMillis();

View File: TestRealAPIs.java

@@ -23,8 +23,13 @@ import lombok.Data;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Order;
 import org.junit.jupiter.api.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;

 class TestRealAPIs {

+    private static final Logger LOG = LoggerFactory.getLogger(TestRealAPIs.class);
+
     OllamaAPI ollamaAPI;
     Config config;
@@ -164,6 +169,31 @@ class TestRealAPIs {
         }
     }

+    @Test
+    @Order(3)
+    void testChatWithStream() {
+        testEndpointReachability();
+        try {
+            OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
+            OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER,
+                    "What is the capital of France? And what's France's connection with Mona Lisa?")
+                    .build();
+
+            StringBuffer sb = new StringBuffer("");
+
+            OllamaChatResult chatResult = ollamaAPI.chat(requestModel, (s) -> {
+                LOG.info(s);
+                // each callback carries the full response so far; append only the new suffix
+                String substring = s.substring(sb.toString().length());
+                LOG.info(substring);
+                sb.append(substring);
+            });
+            assertNotNull(chatResult);
+            assertEquals(sb.toString().trim(), chatResult.getResponse().trim());
+        } catch (IOException | OllamaBaseException | InterruptedException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
     @Test
     @Order(3)
     void testChatWithImageFromFileWithHistoryRecognition() {