Compare commits

..

8 Commits

Author SHA1 Message Date
amithkoujalgi
713a3239a4 [maven-release-plugin] prepare release v1.0.53 2024-02-16 04:42:40 +00:00
Amith Koujalgi
a9e7958d44 Merge pull request #26 from AgentSchmecker/main
Adds streaming functionality for chat
2024-02-16 10:11:32 +05:30
Markus Klenke
f38e84053f Adds documentation for streamed chat API call 2024-02-14 16:45:46 +00:00
Markus Klenke
7eb16b7ba0 Merge pull request #1 from AgentSchmecker/feature/streaming-chat
Adds streaming functionality for chat
2024-02-14 16:30:51 +01:00
amithkoujalgi
5a3889d8ee [maven-release-plugin] prepare for next development iteration 2024-02-14 13:44:30 +00:00
Markus Klenke
e9621f054d Adds integration test for chat streaming API 2024-02-13 18:11:59 +00:00
Markus Klenke
b41b62220c Adds chat with stream functionality in OllamaAPI 2024-02-13 17:59:27 +00:00
Markus Klenke
c89440cbca Adds OllamaStream handling 2024-02-13 17:56:07 +00:00
8 changed files with 160 additions and 19 deletions

View File

@@ -69,6 +69,41 @@ You will get a response similar to:
} ] } ]
``` ```
## Create a conversation where the answer is streamed
```java
public class Main {
public static void main(String[] args) {
String host = "http://localhost:11434/";
OllamaAPI ollamaAPI = new OllamaAPI(host);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER,
"What is the capital of France? And what's France's connection with Mona Lisa?")
.build();
// define a handler (Consumer<String>)
OllamaStreamHandler streamHandler = (s) -> {
System.out.println(s);
};
OllamaChatResult chatResult = ollamaAPI.chat(requestModel,streamHandler);
}
}
```
You will get a response similar to:
> The
> The capital
> The capital of
> The capital of France
> The capital of France is
> The capital of France is Paris
> The capital of France is Paris.
## Create a new conversation with individual system prompt ## Create a new conversation with individual system prompt
```java ```java
public class Main { public class Main {

View File

@@ -4,7 +4,7 @@
<groupId>io.github.amithkoujalgi</groupId> <groupId>io.github.amithkoujalgi</groupId>
<artifactId>ollama4j</artifactId> <artifactId>ollama4j</artifactId>
<version>1.0.52</version> <version>1.0.53</version>
<name>Ollama4j</name> <name>Ollama4j</name>
<description>Java library for interacting with Ollama API.</description> <description>Java library for interacting with Ollama API.</description>
@@ -39,7 +39,7 @@
<connection>scm:git:git@github.com:amithkoujalgi/ollama4j.git</connection> <connection>scm:git:git@github.com:amithkoujalgi/ollama4j.git</connection>
<developerConnection>scm:git:https://github.com/amithkoujalgi/ollama4j.git</developerConnection> <developerConnection>scm:git:https://github.com/amithkoujalgi/ollama4j.git</developerConnection>
<url>https://github.com/amithkoujalgi/ollama4j</url> <url>https://github.com/amithkoujalgi/ollama4j</url>
<tag>v1.0.52</tag> <tag>v1.0.53</tag>
</scm> </scm>
<build> <build>

View File

@@ -448,12 +448,31 @@ public class OllamaAPI {
* @throws InterruptedException in case the server is not reachable or network issues happen * @throws InterruptedException in case the server is not reachable or network issues happen
*/ */
public OllamaChatResult chat(OllamaChatRequestModel request) throws OllamaBaseException, IOException, InterruptedException{ public OllamaChatResult chat(OllamaChatRequestModel request) throws OllamaBaseException, IOException, InterruptedException{
return chat(request);
}
/**
* Ask a question to a model using an {@link OllamaChatRequestModel}. This can be constructed using an {@link OllamaChatRequestBuilder}.
*
* Hint: the OllamaChatRequestModel#getStream() property is not implemented.
*
* @param request request object to be sent to the server
* @param streamHandler callback handler to handle the last message from stream (caution: all previous messages from stream will be concatenated)
* @return
* @throws OllamaBaseException any response code than 200 has been returned
* @throws IOException in case the responseStream can not be read
* @throws InterruptedException in case the server is not reachable or network issues happen
*/
public OllamaChatResult chat(OllamaChatRequestModel request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException{
OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose); OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
//TODO: implement async way OllamaResult result;
if(request.isStream()){ if(streamHandler != null){
throw new UnsupportedOperationException("Streamed chat responses are not implemented yet"); request.setStream(true);
result = requestCaller.call(request, streamHandler);
}
else {
result = requestCaller.callSync(request);
} }
OllamaResult result = requestCaller.generateSync(request);
return new OllamaChatResult(result.getResponse(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages()); return new OllamaChatResult(result.getResponse(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages());
} }
@@ -470,7 +489,7 @@ public class OllamaAPI {
private OllamaResult generateSyncForOllamaRequestModel(OllamaRequestModel ollamaRequestModel) private OllamaResult generateSyncForOllamaRequestModel(OllamaRequestModel ollamaRequestModel)
throws OllamaBaseException, IOException, InterruptedException { throws OllamaBaseException, IOException, InterruptedException {
OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose); OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
return requestCaller.generateSync(ollamaRequestModel); return requestCaller.callSync(ollamaRequestModel);
} }
/** /**

View File

@@ -0,0 +1,7 @@
package io.github.amithkoujalgi.ollama4j.core;
import java.util.function.Consumer;
public interface OllamaStreamHandler extends Consumer<String>{
void accept(String message);
}

View File

@@ -0,0 +1,34 @@
package io.github.amithkoujalgi.ollama4j.core.models.chat;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler;
import lombok.NonNull;
public class OllamaChatStreamObserver {
private OllamaStreamHandler streamHandler;
private List<OllamaChatResponseModel> responseParts = new ArrayList<>();
private String message;
public OllamaChatStreamObserver(OllamaStreamHandler streamHandler) {
this.streamHandler = streamHandler;
}
public void notify(OllamaChatResponseModel currentResponsePart){
responseParts.add(currentResponsePart);
handleCurrentResponsePart(currentResponsePart);
}
protected void handleCurrentResponsePart(OllamaChatResponseModel currentResponsePart){
List<@NonNull String> allResponsePartsByNow = responseParts.stream().map(r -> r.getMessage().getContent()).collect(Collectors.toList());
message = String.join("", allResponsePartsByNow);
streamHandler.accept(message);
}
}

View File

@@ -1,12 +1,19 @@
package io.github.amithkoujalgi.ollama4j.core.models.request; package io.github.amithkoujalgi.ollama4j.core.models.request;
import java.io.IOException;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth; import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResponseModel; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResponseModel;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatStreamObserver;
import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils; import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
/** /**
@@ -16,6 +23,8 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller{
private static final Logger LOG = LoggerFactory.getLogger(OllamaChatEndpointCaller.class); private static final Logger LOG = LoggerFactory.getLogger(OllamaChatEndpointCaller.class);
private OllamaChatStreamObserver streamObserver;
public OllamaChatEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) { public OllamaChatEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) {
super(host, basicAuth, requestTimeoutSeconds, verbose); super(host, basicAuth, requestTimeoutSeconds, verbose);
} }
@@ -27,18 +36,25 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller{
@Override @Override
protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) { protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) {
try { try {
OllamaChatResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class); OllamaChatResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class);
responseBuffer.append(ollamaResponseModel.getMessage().getContent()); responseBuffer.append(ollamaResponseModel.getMessage().getContent());
return ollamaResponseModel.isDone(); if(streamObserver != null) {
} catch (JsonProcessingException e) { streamObserver.notify(ollamaResponseModel);
LOG.error("Error parsing the Ollama chat response!",e); }
return true; return ollamaResponseModel.isDone();
} } catch (JsonProcessingException e) {
LOG.error("Error parsing the Ollama chat response!",e);
return true;
}
} }
public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler)
throws OllamaBaseException, IOException, InterruptedException {
streamObserver = new OllamaChatStreamObserver(streamHandler);
return super.callSync(body);
}
} }

View File

@@ -46,7 +46,7 @@ public abstract class OllamaEndpointCaller {
protected abstract boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer); protected abstract boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer);
/** /**
* Calls the api server on the given host and endpoint suffix asynchronously, aka waiting for the response. * Calls the api server on the given host and endpoint suffix asynchronously, aka waiting for the response.
* *
@@ -56,7 +56,7 @@ public abstract class OllamaEndpointCaller {
* @throws IOException in case the responseStream can not be read * @throws IOException in case the responseStream can not be read
* @throws InterruptedException in case the server is not reachable or network issues happen * @throws InterruptedException in case the server is not reachable or network issues happen
*/ */
public OllamaResult generateSync(OllamaRequestBody body) throws OllamaBaseException, IOException, InterruptedException{ public OllamaResult callSync(OllamaRequestBody body) throws OllamaBaseException, IOException, InterruptedException{
// Create Request // Create Request
long startTime = System.currentTimeMillis(); long startTime = System.currentTimeMillis();

View File

@@ -23,8 +23,13 @@ import lombok.Data;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Order; import org.junit.jupiter.api.Order;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
class TestRealAPIs { class TestRealAPIs {
private static final Logger LOG = LoggerFactory.getLogger(TestRealAPIs.class);
OllamaAPI ollamaAPI; OllamaAPI ollamaAPI;
Config config; Config config;
@@ -164,6 +169,31 @@ class TestRealAPIs {
} }
} }
@Test
@Order(3)
void testChatWithStream() {
testEndpointReachability();
try {
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER,
"What is the capital of France? And what's France's connection with Mona Lisa?")
.build();
StringBuffer sb = new StringBuffer("");
OllamaChatResult chatResult = ollamaAPI.chat(requestModel,(s) -> {
LOG.info(s);
String substring = s.substring(sb.toString().length(), s.length()-1);
LOG.info(substring);
sb.append(substring);
});
assertNotNull(chatResult);
assertEquals(sb.toString().trim(), chatResult.getResponse().trim());
} catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e);
}
}
@Test @Test
@Order(3) @Order(3)
void testChatWithImageFromFileWithHistoryRecognition() { void testChatWithImageFromFileWithHistoryRecognition() {