mirror of
https://github.com/amithkoujalgi/ollama4j.git
synced 2025-05-15 20:07:10 +02:00
Adds streaming feature to Generate APIs
This commit is contained in:
parent
cf4e7a96e8
commit
abb76ad867
@ -342,13 +342,24 @@ public class OllamaAPI {
|
|||||||
* @param options the Options object - <a
|
* @param options the Options object - <a
|
||||||
* href="https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values">More
|
* href="https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values">More
|
||||||
* details on the options</a>
|
* details on the options</a>
|
||||||
|
* @param streamHandler optional callback consumer that will be applied every time a streamed response is received. If not set, the stream parameter of the request is set to false.
|
||||||
* @return OllamaResult that includes response text and time taken for response
|
* @return OllamaResult that includes response text and time taken for response
|
||||||
*/
|
*/
|
||||||
public OllamaResult generate(String model, String prompt, Options options)
|
public OllamaResult generate(String model, String prompt, Options options, OllamaStreamHandler streamHandler)
|
||||||
throws OllamaBaseException, IOException, InterruptedException {
|
throws OllamaBaseException, IOException, InterruptedException {
|
||||||
OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt);
|
OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt);
|
||||||
ollamaRequestModel.setOptions(options.getOptionsMap());
|
ollamaRequestModel.setOptions(options.getOptionsMap());
|
||||||
return generateSyncForOllamaRequestModel(ollamaRequestModel);
|
return generateSyncForOllamaRequestModel(ollamaRequestModel,streamHandler);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convenience method to call Ollama API without streaming responses.
|
||||||
|
*
|
||||||
|
* Uses {@link #generate(String, String, Options, OllamaStreamHandler)}
|
||||||
|
*/
|
||||||
|
public OllamaResult generate(String model, String prompt, Options options)
|
||||||
|
throws OllamaBaseException, IOException, InterruptedException {
|
||||||
|
return generate(model, prompt, options,null);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -381,10 +392,11 @@ public class OllamaAPI {
|
|||||||
* @param options the Options object - <a
|
* @param options the Options object - <a
|
||||||
* href="https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values">More
|
* href="https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values">More
|
||||||
* details on the options</a>
|
* details on the options</a>
|
||||||
|
* @param streamHandler optional callback consumer that will be applied every time a streamed response is received. If not set, the stream parameter of the request is set to false.
|
||||||
* @return OllamaResult that includes response text and time taken for response
|
* @return OllamaResult that includes response text and time taken for response
|
||||||
*/
|
*/
|
||||||
public OllamaResult generateWithImageFiles(
|
public OllamaResult generateWithImageFiles(
|
||||||
String model, String prompt, List<File> imageFiles, Options options)
|
String model, String prompt, List<File> imageFiles, Options options, OllamaStreamHandler streamHandler)
|
||||||
throws OllamaBaseException, IOException, InterruptedException {
|
throws OllamaBaseException, IOException, InterruptedException {
|
||||||
List<String> images = new ArrayList<>();
|
List<String> images = new ArrayList<>();
|
||||||
for (File imageFile : imageFiles) {
|
for (File imageFile : imageFiles) {
|
||||||
@ -392,9 +404,20 @@ public class OllamaAPI {
|
|||||||
}
|
}
|
||||||
OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt, images);
|
OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt, images);
|
||||||
ollamaRequestModel.setOptions(options.getOptionsMap());
|
ollamaRequestModel.setOptions(options.getOptionsMap());
|
||||||
return generateSyncForOllamaRequestModel(ollamaRequestModel);
|
return generateSyncForOllamaRequestModel(ollamaRequestModel,streamHandler);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convenience method to call Ollama API without streaming responses.
|
||||||
|
*
|
||||||
|
* Uses {@link #generateWithImageFiles(String, String, List, Options, OllamaStreamHandler)}
|
||||||
|
*/
|
||||||
|
public OllamaResult generateWithImageFiles(
|
||||||
|
String model, String prompt, List<File> imageFiles, Options options)
|
||||||
|
throws OllamaBaseException, IOException, InterruptedException{
|
||||||
|
return generateWithImageFiles(model, prompt, imageFiles, options, null);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* With one or more image URLs, ask a question to a model running on Ollama server. This is a
|
* With one or more image URLs, ask a question to a model running on Ollama server. This is a
|
||||||
* sync/blocking call.
|
* sync/blocking call.
|
||||||
@ -405,10 +428,11 @@ public class OllamaAPI {
|
|||||||
* @param options the Options object - <a
|
* @param options the Options object - <a
|
||||||
* href="https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values">More
|
* href="https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values">More
|
||||||
* details on the options</a>
|
* details on the options</a>
|
||||||
|
* @param streamHandler optional callback consumer that will be applied every time a streamed response is received. If not set, the stream parameter of the request is set to false.
|
||||||
* @return OllamaResult that includes response text and time taken for response
|
* @return OllamaResult that includes response text and time taken for response
|
||||||
*/
|
*/
|
||||||
public OllamaResult generateWithImageURLs(
|
public OllamaResult generateWithImageURLs(
|
||||||
String model, String prompt, List<String> imageURLs, Options options)
|
String model, String prompt, List<String> imageURLs, Options options, OllamaStreamHandler streamHandler)
|
||||||
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
|
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
|
||||||
List<String> images = new ArrayList<>();
|
List<String> images = new ArrayList<>();
|
||||||
for (String imageURL : imageURLs) {
|
for (String imageURL : imageURLs) {
|
||||||
@ -416,7 +440,18 @@ public class OllamaAPI {
|
|||||||
}
|
}
|
||||||
OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt, images);
|
OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt, images);
|
||||||
ollamaRequestModel.setOptions(options.getOptionsMap());
|
ollamaRequestModel.setOptions(options.getOptionsMap());
|
||||||
return generateSyncForOllamaRequestModel(ollamaRequestModel);
|
return generateSyncForOllamaRequestModel(ollamaRequestModel,streamHandler);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convenience method to call Ollama API without streaming responses.
|
||||||
|
*
|
||||||
|
* Uses {@link #generateWithImageURLs(String, String, List, Options, OllamaStreamHandler)}
|
||||||
|
*/
|
||||||
|
public OllamaResult generateWithImageURLs(String model, String prompt, List<String> imageURLs,
|
||||||
|
Options options)
|
||||||
|
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
|
||||||
|
return generateWithImageURLs(model, prompt, imageURLs, options, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -487,10 +522,19 @@ public class OllamaAPI {
|
|||||||
return Base64.getEncoder().encodeToString(bytes);
|
return Base64.getEncoder().encodeToString(bytes);
|
||||||
}
|
}
|
||||||
|
|
||||||
private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequestModel ollamaRequestModel)
|
private OllamaResult generateSyncForOllamaRequestModel(
|
||||||
|
OllamaGenerateRequestModel ollamaRequestModel, OllamaStreamHandler streamHandler)
|
||||||
throws OllamaBaseException, IOException, InterruptedException {
|
throws OllamaBaseException, IOException, InterruptedException {
|
||||||
OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
|
OllamaGenerateEndpointCaller requestCaller =
|
||||||
return requestCaller.callSync(ollamaRequestModel);
|
new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
|
||||||
|
OllamaResult result;
|
||||||
|
if (streamHandler != null) {
|
||||||
|
ollamaRequestModel.setStream(true);
|
||||||
|
result = requestCaller.call(ollamaRequestModel, streamHandler);
|
||||||
|
} else {
|
||||||
|
result = requestCaller.callSync(ollamaRequestModel);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -0,0 +1,31 @@
|
|||||||
|
package io.github.amithkoujalgi.ollama4j.core.models.generate;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler;
|
||||||
|
|
||||||
|
public class OllamaGenerateStreamObserver {
|
||||||
|
|
||||||
|
private OllamaStreamHandler streamHandler;
|
||||||
|
|
||||||
|
private List<OllamaGenerateResponseModel> responseParts = new ArrayList<>();
|
||||||
|
|
||||||
|
private String message = "";
|
||||||
|
|
||||||
|
public OllamaGenerateStreamObserver(OllamaStreamHandler streamHandler) {
|
||||||
|
this.streamHandler = streamHandler;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void notify(OllamaGenerateResponseModel currentResponsePart){
|
||||||
|
responseParts.add(currentResponsePart);
|
||||||
|
handleCurrentResponsePart(currentResponsePart);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected void handleCurrentResponsePart(OllamaGenerateResponseModel currentResponsePart){
|
||||||
|
message = message + currentResponsePart.getResponse();
|
||||||
|
streamHandler.accept(message);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
@ -1,18 +1,25 @@
|
|||||||
package io.github.amithkoujalgi.ollama4j.core.models.request;
|
package io.github.amithkoujalgi.ollama4j.core.models.request;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
|
import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler;
|
||||||
|
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
|
||||||
import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth;
|
import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth;
|
||||||
|
import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
|
||||||
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateResponseModel;
|
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateResponseModel;
|
||||||
|
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateStreamObserver;
|
||||||
|
import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
|
||||||
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
|
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
|
||||||
|
|
||||||
public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{
|
public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{
|
||||||
|
|
||||||
private static final Logger LOG = LoggerFactory.getLogger(OllamaGenerateEndpointCaller.class);
|
private static final Logger LOG = LoggerFactory.getLogger(OllamaGenerateEndpointCaller.class);
|
||||||
|
|
||||||
|
private OllamaGenerateStreamObserver streamObserver;
|
||||||
|
|
||||||
public OllamaGenerateEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) {
|
public OllamaGenerateEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) {
|
||||||
super(host, basicAuth, requestTimeoutSeconds, verbose);
|
super(host, basicAuth, requestTimeoutSeconds, verbose);
|
||||||
}
|
}
|
||||||
@ -27,6 +34,9 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{
|
|||||||
try {
|
try {
|
||||||
OllamaGenerateResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class);
|
OllamaGenerateResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class);
|
||||||
responseBuffer.append(ollamaResponseModel.getResponse());
|
responseBuffer.append(ollamaResponseModel.getResponse());
|
||||||
|
if(streamObserver != null) {
|
||||||
|
streamObserver.notify(ollamaResponseModel);
|
||||||
|
}
|
||||||
return ollamaResponseModel.isDone();
|
return ollamaResponseModel.isDone();
|
||||||
} catch (JsonProcessingException e) {
|
} catch (JsonProcessingException e) {
|
||||||
LOG.error("Error parsing the Ollama chat response!",e);
|
LOG.error("Error parsing the Ollama chat response!",e);
|
||||||
@ -34,7 +44,11 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler)
|
||||||
|
throws OllamaBaseException, IOException, InterruptedException {
|
||||||
|
streamObserver = new OllamaGenerateStreamObserver(streamHandler);
|
||||||
|
return super.callSync(body);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -109,6 +109,32 @@ class TestRealAPIs {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@Order(3)
|
||||||
|
void testAskModelWithDefaultOptionsStreamed() {
|
||||||
|
testEndpointReachability();
|
||||||
|
try {
|
||||||
|
|
||||||
|
StringBuffer sb = new StringBuffer("");
|
||||||
|
|
||||||
|
OllamaResult result = ollamaAPI.generate(config.getModel(),
|
||||||
|
"What is the capital of France? And what's France's connection with Mona Lisa?",
|
||||||
|
new OptionsBuilder().build(), (s) -> {
|
||||||
|
LOG.info(s);
|
||||||
|
String substring = s.substring(sb.toString().length(), s.length());
|
||||||
|
LOG.info(substring);
|
||||||
|
sb.append(substring);
|
||||||
|
});
|
||||||
|
|
||||||
|
assertNotNull(result);
|
||||||
|
assertNotNull(result.getResponse());
|
||||||
|
assertFalse(result.getResponse().isEmpty());
|
||||||
|
assertEquals(sb.toString().trim(), result.getResponse().trim());
|
||||||
|
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@Order(3)
|
@Order(3)
|
||||||
void testAskModelWithOptions() {
|
void testAskModelWithOptions() {
|
||||||
@ -262,6 +288,30 @@ class TestRealAPIs {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@Order(3)
|
||||||
|
void testAskModelWithOptionsAndImageFilesStreamed() {
|
||||||
|
testEndpointReachability();
|
||||||
|
File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg");
|
||||||
|
try {
|
||||||
|
StringBuffer sb = new StringBuffer("");
|
||||||
|
|
||||||
|
OllamaResult result = ollamaAPI.generateWithImageFiles(config.getImageModel(),
|
||||||
|
"What is in this image?", List.of(imageFile), new OptionsBuilder().build(), (s) -> {
|
||||||
|
LOG.info(s);
|
||||||
|
String substring = s.substring(sb.toString().length(), s.length());
|
||||||
|
LOG.info(substring);
|
||||||
|
sb.append(substring);
|
||||||
|
});
|
||||||
|
assertNotNull(result);
|
||||||
|
assertNotNull(result.getResponse());
|
||||||
|
assertFalse(result.getResponse().isEmpty());
|
||||||
|
assertEquals(sb.toString().trim(), result.getResponse().trim());
|
||||||
|
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@Order(3)
|
@Order(3)
|
||||||
void testAskModelWithOptionsAndImageURLs() {
|
void testAskModelWithOptionsAndImageURLs() {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user