Adds streaming feature to Generate APIs

This commit is contained in:
Markus Klenke 2024-02-16 17:03:15 +00:00
parent cf4e7a96e8
commit abb76ad867
4 changed files with 150 additions and 11 deletions

View File

@ -342,13 +342,24 @@ public class OllamaAPI {
* @param options the Options object - <a * @param options the Options object - <a
* href="https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values">More * href="https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values">More
* details on the options</a> * details on the options</a>
* @param streamHandler optional callback consumer that will be applied every time a streamed response is received. If not set, the stream parameter of the request is set to false.
* @return OllamaResult that includes response text and time taken for response * @return OllamaResult that includes response text and time taken for response
*/ */
public OllamaResult generate(String model, String prompt, Options options) public OllamaResult generate(String model, String prompt, Options options, OllamaStreamHandler streamHandler)
throws OllamaBaseException, IOException, InterruptedException { throws OllamaBaseException, IOException, InterruptedException {
OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt); OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt);
ollamaRequestModel.setOptions(options.getOptionsMap()); ollamaRequestModel.setOptions(options.getOptionsMap());
return generateSyncForOllamaRequestModel(ollamaRequestModel); return generateSyncForOllamaRequestModel(ollamaRequestModel,streamHandler);
}
/**
* Convenience method to call Ollama API without streaming responses.
*
* Uses {@link #generate(String, String, Options, OllamaStreamHandler)}
*/
public OllamaResult generate(String model, String prompt, Options options)
throws OllamaBaseException, IOException, InterruptedException {
return generate(model, prompt, options,null);
} }
/** /**
@ -381,10 +392,11 @@ public class OllamaAPI {
* @param options the Options object - <a * @param options the Options object - <a
* href="https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values">More * href="https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values">More
* details on the options</a> * details on the options</a>
* @param streamHandler optional callback consumer that will be applied every time a streamed response is received. If not set, the stream parameter of the request is set to false.
* @return OllamaResult that includes response text and time taken for response * @return OllamaResult that includes response text and time taken for response
*/ */
public OllamaResult generateWithImageFiles( public OllamaResult generateWithImageFiles(
String model, String prompt, List<File> imageFiles, Options options) String model, String prompt, List<File> imageFiles, Options options, OllamaStreamHandler streamHandler)
throws OllamaBaseException, IOException, InterruptedException { throws OllamaBaseException, IOException, InterruptedException {
List<String> images = new ArrayList<>(); List<String> images = new ArrayList<>();
for (File imageFile : imageFiles) { for (File imageFile : imageFiles) {
@ -392,7 +404,18 @@ public class OllamaAPI {
} }
OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt, images); OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt, images);
ollamaRequestModel.setOptions(options.getOptionsMap()); ollamaRequestModel.setOptions(options.getOptionsMap());
return generateSyncForOllamaRequestModel(ollamaRequestModel); return generateSyncForOllamaRequestModel(ollamaRequestModel,streamHandler);
}
/**
* Convenience method to call Ollama API without streaming responses.
*
* Uses {@link #generateWithImageFiles(String, String, List, Options, OllamaStreamHandler)}
*/
public OllamaResult generateWithImageFiles(
String model, String prompt, List<File> imageFiles, Options options)
throws OllamaBaseException, IOException, InterruptedException{
return generateWithImageFiles(model, prompt, imageFiles, options, null);
} }
/** /**
@ -405,10 +428,11 @@ public class OllamaAPI {
* @param options the Options object - <a * @param options the Options object - <a
* href="https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values">More * href="https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values">More
* details on the options</a> * details on the options</a>
* @param streamHandler optional callback consumer that will be applied every time a streamed response is received. If not set, the stream parameter of the request is set to false.
* @return OllamaResult that includes response text and time taken for response * @return OllamaResult that includes response text and time taken for response
*/ */
public OllamaResult generateWithImageURLs( public OllamaResult generateWithImageURLs(
String model, String prompt, List<String> imageURLs, Options options) String model, String prompt, List<String> imageURLs, Options options, OllamaStreamHandler streamHandler)
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
List<String> images = new ArrayList<>(); List<String> images = new ArrayList<>();
for (String imageURL : imageURLs) { for (String imageURL : imageURLs) {
@ -416,7 +440,18 @@ public class OllamaAPI {
} }
OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt, images); OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt, images);
ollamaRequestModel.setOptions(options.getOptionsMap()); ollamaRequestModel.setOptions(options.getOptionsMap());
return generateSyncForOllamaRequestModel(ollamaRequestModel); return generateSyncForOllamaRequestModel(ollamaRequestModel,streamHandler);
}
/**
* Convenience method to call Ollama API without streaming responses.
*
* Uses {@link #generateWithImageURLs(String, String, List, Options, OllamaStreamHandler)}
*/
public OllamaResult generateWithImageURLs(String model, String prompt, List<String> imageURLs,
Options options)
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
return generateWithImageURLs(model, prompt, imageURLs, options, null);
} }
@ -487,10 +522,19 @@ public class OllamaAPI {
return Base64.getEncoder().encodeToString(bytes); return Base64.getEncoder().encodeToString(bytes);
} }
private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequestModel ollamaRequestModel) private OllamaResult generateSyncForOllamaRequestModel(
OllamaGenerateRequestModel ollamaRequestModel, OllamaStreamHandler streamHandler)
throws OllamaBaseException, IOException, InterruptedException { throws OllamaBaseException, IOException, InterruptedException {
OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose); OllamaGenerateEndpointCaller requestCaller =
return requestCaller.callSync(ollamaRequestModel); new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
OllamaResult result;
if (streamHandler != null) {
ollamaRequestModel.setStream(true);
result = requestCaller.call(ollamaRequestModel, streamHandler);
} else {
result = requestCaller.callSync(ollamaRequestModel);
}
return result;
} }
/** /**

View File

@ -0,0 +1,31 @@
package io.github.amithkoujalgi.ollama4j.core.models.generate;
import java.util.ArrayList;
import java.util.List;
import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler;
public class OllamaGenerateStreamObserver {
private OllamaStreamHandler streamHandler;
private List<OllamaGenerateResponseModel> responseParts = new ArrayList<>();
private String message = "";
public OllamaGenerateStreamObserver(OllamaStreamHandler streamHandler) {
this.streamHandler = streamHandler;
}
public void notify(OllamaGenerateResponseModel currentResponsePart){
responseParts.add(currentResponsePart);
handleCurrentResponsePart(currentResponsePart);
}
protected void handleCurrentResponsePart(OllamaGenerateResponseModel currentResponsePart){
message = message + currentResponsePart.getResponse();
streamHandler.accept(message);
}
}

View File

@ -1,18 +1,25 @@
package io.github.amithkoujalgi.ollama4j.core.models.request; package io.github.amithkoujalgi.ollama4j.core.models.request;
import java.io.IOException;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth; import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateResponseModel; import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateResponseModel;
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateStreamObserver;
import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils; import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{
private static final Logger LOG = LoggerFactory.getLogger(OllamaGenerateEndpointCaller.class); private static final Logger LOG = LoggerFactory.getLogger(OllamaGenerateEndpointCaller.class);
private OllamaGenerateStreamObserver streamObserver;
public OllamaGenerateEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) { public OllamaGenerateEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) {
super(host, basicAuth, requestTimeoutSeconds, verbose); super(host, basicAuth, requestTimeoutSeconds, verbose);
} }
@ -27,6 +34,9 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{
try { try {
OllamaGenerateResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class); OllamaGenerateResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class);
responseBuffer.append(ollamaResponseModel.getResponse()); responseBuffer.append(ollamaResponseModel.getResponse());
if(streamObserver != null) {
streamObserver.notify(ollamaResponseModel);
}
return ollamaResponseModel.isDone(); return ollamaResponseModel.isDone();
} catch (JsonProcessingException e) { } catch (JsonProcessingException e) {
LOG.error("Error parsing the Ollama chat response!",e); LOG.error("Error parsing the Ollama chat response!",e);
@ -34,7 +44,11 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{
} }
} }
public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler)
throws OllamaBaseException, IOException, InterruptedException {
streamObserver = new OllamaGenerateStreamObserver(streamHandler);
return super.callSync(body);
}
} }

View File

@ -109,6 +109,32 @@ class TestRealAPIs {
} }
} }
@Test
@Order(3)
void testAskModelWithDefaultOptionsStreamed() {
testEndpointReachability();
try {
StringBuffer sb = new StringBuffer("");
OllamaResult result = ollamaAPI.generate(config.getModel(),
"What is the capital of France? And what's France's connection with Mona Lisa?",
new OptionsBuilder().build(), (s) -> {
LOG.info(s);
String substring = s.substring(sb.toString().length(), s.length());
LOG.info(substring);
sb.append(substring);
});
assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
assertEquals(sb.toString().trim(), result.getResponse().trim());
} catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e);
}
}
@Test @Test
@Order(3) @Order(3)
void testAskModelWithOptions() { void testAskModelWithOptions() {
@ -262,6 +288,30 @@ class TestRealAPIs {
} }
} }
@Test
@Order(3)
void testAskModelWithOptionsAndImageFilesStreamed() {
testEndpointReachability();
File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg");
try {
StringBuffer sb = new StringBuffer("");
OllamaResult result = ollamaAPI.generateWithImageFiles(config.getImageModel(),
"What is in this image?", List.of(imageFile), new OptionsBuilder().build(), (s) -> {
LOG.info(s);
String substring = s.substring(sb.toString().length(), s.length());
LOG.info(substring);
sb.append(substring);
});
assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
assertEquals(sb.toString().trim(), result.getResponse().trim());
} catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e);
}
}
@Test @Test
@Order(3) @Order(3)
void testAskModelWithOptionsAndImageURLs() { void testAskModelWithOptionsAndImageURLs() {