Merge pull request #88 from seeseemelk/feature/token-streamer
Add ability to stream tokens in chat
Commit: dda807d818
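For orientation, a minimal sketch of how the token-streaming API introduced by this PR might be used. The host URL, model name, and prompt are placeholder assumptions, not part of the commit; chatStreaming, OllamaTokenHandler, and the accessors on the response model are taken from the diff below.

import io.github.ollama4j.OllamaAPI;
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
import io.github.ollama4j.models.chat.OllamaChatRequest;
import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
import io.github.ollama4j.models.chat.OllamaChatResult;

public class TokenStreamingExample {
    public static void main(String[] args) throws Exception {
        // Placeholder host and model; adjust to your local Ollama setup.
        OllamaAPI ollamaAPI = new OllamaAPI("http://localhost:11434");

        OllamaChatRequest request = OllamaChatRequestBuilder
                .getInstance("llama3")
                .withMessage(OllamaChatMessageRole.USER, "Why is the sky blue?")
                .build();

        // chatStreaming() feeds every raw OllamaChatResponseModel chunk to the
        // handler; OllamaTokenHandler is a Consumer, so a lambda works here.
        OllamaChatResult result = ollamaAPI.chatStreaming(request,
                token -> System.out.print(token.getMessage().getContent()));

        System.out.println();
        System.out.println("Complete response: "
                + result.getResponseModel().getMessage().getContent());
    }
}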
OllamaAPI.java

@@ -11,6 +11,7 @@ import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
 import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel;
 import io.github.ollama4j.models.generate.OllamaGenerateRequest;
 import io.github.ollama4j.models.generate.OllamaStreamHandler;
+import io.github.ollama4j.models.generate.OllamaTokenHandler;
 import io.github.ollama4j.models.ps.ModelsProcessResponse;
 import io.github.ollama4j.models.request.*;
 import io.github.ollama4j.models.response.*;
@@ -785,15 +786,34 @@ public class OllamaAPI {
      * @throws InterruptedException if the operation is interrupted
      */
     public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
+        return chatStreaming(request, new OllamaChatStreamObserver(streamHandler));
+    }
+
+    /**
+     * Ask a question to a model using an {@link OllamaChatRequest}. This can be constructed using an {@link OllamaChatRequestBuilder}.
+     * <p>
+     * Hint: the OllamaChatRequestModel#getStream() property is not implemented.
+     *
+     * @param request      request object to be sent to the server
+     * @param tokenHandler callback handler to handle the last token from stream (caution: all previous messages from stream will be concatenated)
+     * @return {@link OllamaChatResult}
+     * @throws OllamaBaseException  any response code than 200 has been returned
+     * @throws IOException          in case the responseStream can not be read
+     * @throws InterruptedException in case the server is not reachable or network issues happen
+     * @throws OllamaBaseException  if the response indicates an error status
+     * @throws IOException          if an I/O error occurs during the HTTP request
+     * @throws InterruptedException if the operation is interrupted
+     */
+    public OllamaChatResult chatStreaming(OllamaChatRequest request, OllamaTokenHandler tokenHandler) throws OllamaBaseException, IOException, InterruptedException {
         OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
         OllamaChatResult result;
 
         // add all registered tools to Request
         request.setTools(toolRegistry.getRegisteredSpecs().stream().map(Tools.ToolSpecification::getToolPrompt).collect(Collectors.toList()));
 
-        if (streamHandler != null) {
+        if (tokenHandler != null) {
             request.setStream(true);
-            result = requestCaller.call(request, streamHandler);
+            result = requestCaller.call(request, tokenHandler);
         } else {
             result = requestCaller.callSync(request);
         }
@@ -810,8 +830,8 @@ public class OllamaAPI {
             request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL,"[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() +") : " + res + "[/TOOL_RESULTS]"));
         }
 
-        if (streamHandler != null) {
-            result = requestCaller.call(request, streamHandler);
+        if (tokenHandler != null) {
+            result = requestCaller.call(request, tokenHandler);
         } else {
             result = requestCaller.callSync(request);
         }
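Design note: the existing chat(request, streamHandler) overload stays source compatible. It now wraps the supplied OllamaStreamHandler in an OllamaChatStreamObserver and delegates to the new chatStreaming(request, tokenHandler), so current callers keep receiving the concatenated message while new callers can consume raw tokens.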
OllamaChatStreamObserver.java

@@ -1,31 +1,19 @@
 package io.github.ollama4j.models.chat;
 
 import io.github.ollama4j.models.generate.OllamaStreamHandler;
+import io.github.ollama4j.models.generate.OllamaTokenHandler;
+import lombok.RequiredArgsConstructor;
 
-import java.util.ArrayList;
-import java.util.List;
-
-public class OllamaChatStreamObserver {
-
-    private OllamaStreamHandler streamHandler;
-
-    private List<OllamaChatResponseModel> responseParts = new ArrayList<>();
-
+@RequiredArgsConstructor
+public class OllamaChatStreamObserver implements OllamaTokenHandler {
+    private final OllamaStreamHandler streamHandler;
     private String message = "";
 
-    public OllamaChatStreamObserver(OllamaStreamHandler streamHandler) {
-        this.streamHandler = streamHandler;
-    }
-
-    public void notify(OllamaChatResponseModel currentResponsePart) {
-        responseParts.add(currentResponsePart);
-        handleCurrentResponsePart(currentResponsePart);
-    }
-
-    protected void handleCurrentResponsePart(OllamaChatResponseModel currentResponsePart) {
-        message = message + currentResponsePart.getMessage().getContent();
-        streamHandler.accept(message);
-    }
+    @Override
+    public void accept(OllamaChatResponseModel token) {
+        if (streamHandler != null) {
+            message += token.getMessage().getContent();
+            streamHandler.accept(message);
+        }
+    }
 }
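Note that the rewritten observer preserves the old cumulative behavior: accept() appends each token's content to message and passes the entire concatenated string to the wrapped OllamaStreamHandler, which is exactly the caution spelled out in the chatStreaming javadoc. Consumers that want per-token deltas should implement OllamaTokenHandler directly instead of going through an OllamaStreamHandler.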
OllamaTokenHandler.java (new file)

@@ -0,0 +1,8 @@
+package io.github.ollama4j.models.generate;
+
+import io.github.ollama4j.models.chat.OllamaChatResponseModel;
+
+import java.util.function.Consumer;
+
+public interface OllamaTokenHandler extends Consumer<OllamaChatResponseModel> {
+}
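Because OllamaTokenHandler is simply a Consumer<OllamaChatResponseModel>, code that relied on the responseParts list removed from OllamaChatStreamObserver can collect the raw chunks itself. A hypothetical sketch (RecordingTokenHandler is illustrative, not part of this commit):

import io.github.ollama4j.models.chat.OllamaChatResponseModel;
import io.github.ollama4j.models.generate.OllamaTokenHandler;

import java.util.ArrayList;
import java.util.List;

// Hypothetical handler that records every raw streamed chunk, e.g. to inspect
// per-token metadata once the streaming call has completed.
public class RecordingTokenHandler implements OllamaTokenHandler {
    private final List<OllamaChatResponseModel> tokens = new ArrayList<>();

    @Override
    public void accept(OllamaChatResponseModel token) {
        tokens.add(token);
    }

    public List<OllamaChatResponseModel> getTokens() {
        return tokens;
    }
}

An instance can be passed straight to chatStreaming(request, new RecordingTokenHandler()) and inspected once the call returns.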
OllamaChatEndpointCaller.java

@@ -4,9 +4,8 @@ import com.fasterxml.jackson.core.JsonProcessingException;
 import com.fasterxml.jackson.core.type.TypeReference;
 import io.github.ollama4j.exceptions.OllamaBaseException;
 import io.github.ollama4j.models.chat.*;
+import io.github.ollama4j.models.generate.OllamaTokenHandler;
 import io.github.ollama4j.models.response.OllamaErrorResponse;
-import io.github.ollama4j.models.generate.OllamaStreamHandler;
-import io.github.ollama4j.tools.Tools;
 import io.github.ollama4j.utils.Utils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -29,7 +28,7 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
 
     private static final Logger LOG = LoggerFactory.getLogger(OllamaChatEndpointCaller.class);
 
-    private OllamaChatStreamObserver streamObserver;
+    private OllamaTokenHandler tokenHandler;
 
     public OllamaChatEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) {
         super(host, basicAuth, requestTimeoutSeconds, verbose);
@@ -60,8 +59,8 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
             OllamaChatMessage message = ollamaResponseModel.getMessage();
             if(message != null) {
                 responseBuffer.append(message.getContent());
-                if (streamObserver != null) {
-                    streamObserver.notify(ollamaResponseModel);
+                if (tokenHandler != null) {
+                    tokenHandler.accept(ollamaResponseModel);
                 }
             }
             return ollamaResponseModel.isDone();
@@ -71,9 +70,9 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
         }
     }
 
-    public OllamaChatResult call(OllamaChatRequest body, OllamaStreamHandler streamHandler)
+    public OllamaChatResult call(OllamaChatRequest body, OllamaTokenHandler tokenHandler)
             throws OllamaBaseException, IOException, InterruptedException {
-        streamObserver = new OllamaChatStreamObserver(streamHandler);
+        this.tokenHandler = tokenHandler;
         return callSync(body);
     }
 
@@ -86,7 +85,7 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
                 .POST(
                         body.getBodyPublisher());
         HttpRequest request = requestBuilder.build();
-        if (isVerbose()) LOG.info("Asking model: " + body.toString());
+        if (isVerbose()) LOG.info("Asking model: " + body);
         HttpResponse<InputStream> response =
                 httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
TestRealAPIs.java

@@ -321,7 +321,7 @@ class TestRealAPIs {
         assertEquals(1, function.getArguments().size());
         Object noOfDigits = function.getArguments().get("noOfDigits");
         assertNotNull(noOfDigits);
-        assertEquals("5",noOfDigits);
+        assertEquals("5", noOfDigits.toString());
         assertTrue(chatResult.getChatHistory().size()>2);
         List<OllamaChatToolCalls> finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls();
         assertNull(finalToolCalls);
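The assertion change above is a type fix: the "noOfDigits" tool argument is presumably deserialized from JSON as a number, so assertEquals("5", noOfDigits) compared a String against an Integer and could never pass; comparing against noOfDigits.toString() checks the value instead.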