diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java
index 20e9d3e..ec772f1 100644
--- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java
@@ -342,13 +342,24 @@ public class OllamaAPI {
* @param options the Options object - More
* details on the options
+ * @param streamHandler optional callback consumer that will be applied every time a streamed response is received. If not set, the stream parameter of the request is set to false.
* @return OllamaResult that includes response text and time taken for response
*/
- public OllamaResult generate(String model, String prompt, Options options)
+ public OllamaResult generate(String model, String prompt, Options options, OllamaStreamHandler streamHandler)
throws OllamaBaseException, IOException, InterruptedException {
OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt);
ollamaRequestModel.setOptions(options.getOptionsMap());
- return generateSyncForOllamaRequestModel(ollamaRequestModel);
+ return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler);
+ }
+
+ /**
+ * Convenience method to call Ollama API without streaming responses.
+ *
+ * Uses {@link #generate(String, String, Options, OllamaStreamHandler)}
+ */
+ public OllamaResult generate(String model, String prompt, Options options)
+ throws OllamaBaseException, IOException, InterruptedException {
+ return generate(model, prompt, options, null);
}
/**
@@ -381,10 +392,11 @@ public class OllamaAPI {
* @param options the Options object - More
* details on the options
+ * @param streamHandler optional callback consumer that will be applied every time a streamed response is received. If not set, the stream parameter of the request is set to false.
* @return OllamaResult that includes response text and time taken for response
*/
public OllamaResult generateWithImageFiles(
- String model, String prompt, List<File> imageFiles, Options options)
+ String model, String prompt, List<File> imageFiles, Options options, OllamaStreamHandler streamHandler)
throws OllamaBaseException, IOException, InterruptedException {
List<String> images = new ArrayList<>();
for (File imageFile : imageFiles) {
@@ -392,9 +404,20 @@ public class OllamaAPI {
}
OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt, images);
ollamaRequestModel.setOptions(options.getOptionsMap());
- return generateSyncForOllamaRequestModel(ollamaRequestModel);
+ return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler);
}
+ /**
+ * Convenience method to call Ollama API without streaming responses.
+ *
+ * Uses {@link #generateWithImageFiles(String, String, List, Options, OllamaStreamHandler)}
+ */
+ public OllamaResult generateWithImageFiles(
+ String model, String prompt, List<File> imageFiles, Options options)
+ throws OllamaBaseException, IOException, InterruptedException {
+ return generateWithImageFiles(model, prompt, imageFiles, options, null);
+ }
+
/**
* With one or more image URLs, ask a question to a model running on Ollama server. This is a
* sync/blocking call.
@@ -405,10 +428,11 @@ public class OllamaAPI {
* @param options the Options object - More
* details on the options
+ * @param streamHandler optional callback consumer that will be applied every time a streamed response is received. If not set, the stream parameter of the request is set to false.
* @return OllamaResult that includes response text and time taken for response
*/
public OllamaResult generateWithImageURLs(
- String model, String prompt, List<String> imageURLs, Options options)
+ String model, String prompt, List<String> imageURLs, Options options, OllamaStreamHandler streamHandler)
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
List<String> images = new ArrayList<>();
for (String imageURL : imageURLs) {
@@ -416,7 +440,18 @@ public class OllamaAPI {
}
OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt, images);
ollamaRequestModel.setOptions(options.getOptionsMap());
- return generateSyncForOllamaRequestModel(ollamaRequestModel);
+ return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler);
+ }
+
+ /**
+ * Convenience method to call Ollama API without streaming responses.
+ *
+ * Uses {@link #generateWithImageURLs(String, String, List, Options, OllamaStreamHandler)}
+ */
+ public OllamaResult generateWithImageURLs(String model, String prompt, List<String> imageURLs,
+ Options options)
+ throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
+ return generateWithImageURLs(model, prompt, imageURLs, options, null);
}
@@ -487,10 +522,19 @@ public class OllamaAPI {
return Base64.getEncoder().encodeToString(bytes);
}
- private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequestModel ollamaRequestModel)
+ private OllamaResult generateSyncForOllamaRequestModel(
+ OllamaGenerateRequestModel ollamaRequestModel, OllamaStreamHandler streamHandler)
throws OllamaBaseException, IOException, InterruptedException {
- OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
- return requestCaller.callSync(ollamaRequestModel);
+ OllamaGenerateEndpointCaller requestCaller =
+ new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
+ OllamaResult result;
+ if (streamHandler != null) {
+ ollamaRequestModel.setStream(true);
+ result = requestCaller.call(ollamaRequestModel, streamHandler);
+ } else {
+ result = requestCaller.callSync(ollamaRequestModel);
+ }
+ return result;
}
/**
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateStreamObserver.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateStreamObserver.java
new file mode 100644
index 0000000..a166bac
--- /dev/null
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateStreamObserver.java
@@ -0,0 +1,31 @@
+package io.github.amithkoujalgi.ollama4j.core.models.generate;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler;
+
+public class OllamaGenerateStreamObserver {
+
+ private final OllamaStreamHandler streamHandler;
+
+ private final List<OllamaGenerateResponseModel> responseParts = new ArrayList<>();
+
+ private String message = "";
+
+ public OllamaGenerateStreamObserver(OllamaStreamHandler streamHandler) {
+ this.streamHandler = streamHandler;
+ }
+
+ public void notify(OllamaGenerateResponseModel currentResponsePart) {
+ responseParts.add(currentResponsePart);
+ handleCurrentResponsePart(currentResponsePart);
+ }
+
+ protected void handleCurrentResponsePart(OllamaGenerateResponseModel currentResponsePart) {
+ message = message + currentResponsePart.getResponse();
+ streamHandler.accept(message);
+ }
+
+
+}
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java
index ba55159..fe7fbec 100644
--- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java
@@ -1,18 +1,25 @@
package io.github.amithkoujalgi.ollama4j.core.models.request;
+import java.io.IOException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.fasterxml.jackson.core.JsonProcessingException;
-
+import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler;
+import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth;
+import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateResponseModel;
+import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateStreamObserver;
+import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{
private static final Logger LOG = LoggerFactory.getLogger(OllamaGenerateEndpointCaller.class);
+ private OllamaGenerateStreamObserver streamObserver;
+
public OllamaGenerateEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) {
super(host, basicAuth, requestTimeoutSeconds, verbose);
}
@@ -27,6 +34,9 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{
try {
OllamaGenerateResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class);
responseBuffer.append(ollamaResponseModel.getResponse());
+ if(streamObserver != null) {
+ streamObserver.notify(ollamaResponseModel);
+ }
return ollamaResponseModel.isDone();
} catch (JsonProcessingException e) {
LOG.error("Error parsing the Ollama chat response!",e);
@@ -34,7 +44,11 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{
}
}
-
+ public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler)
+ throws OllamaBaseException, IOException, InterruptedException {
+ streamObserver = new OllamaGenerateStreamObserver(streamHandler);
+ return super.callSync(body);
+ }
}
diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java b/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java
index b4a7e1d..c124f42 100644
--- a/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java
+++ b/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java
@@ -109,6 +109,32 @@ class TestRealAPIs {
}
}
+ @Test
+ @Order(3)
+ void testAskModelWithDefaultOptionsStreamed() {
+ testEndpointReachability();
+ try {
+
+ StringBuilder sb = new StringBuilder();
+
+ OllamaResult result = ollamaAPI.generate(config.getModel(),
+ "What is the capital of France? And what's France's connection with Mona Lisa?",
+ new OptionsBuilder().build(), (s) -> {
+ LOG.info(s);
+ String substring = s.substring(sb.length());
+ LOG.info(substring);
+ sb.append(substring);
+ });
+
+ assertNotNull(result);
+ assertNotNull(result.getResponse());
+ assertFalse(result.getResponse().isEmpty());
+ assertEquals(sb.toString().trim(), result.getResponse().trim());
+ } catch (IOException | OllamaBaseException | InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
@Test
@Order(3)
void testAskModelWithOptions() {
@@ -262,6 +288,30 @@ class TestRealAPIs {
}
}
+ @Test
+ @Order(3)
+ void testAskModelWithOptionsAndImageFilesStreamed() {
+ testEndpointReachability();
+ File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg");
+ try {
+ StringBuilder sb = new StringBuilder();
+
+ OllamaResult result = ollamaAPI.generateWithImageFiles(config.getImageModel(),
+ "What is in this image?", List.of(imageFile), new OptionsBuilder().build(), (s) -> {
+ LOG.info(s);
+ String substring = s.substring(sb.length());
+ LOG.info(substring);
+ sb.append(substring);
+ });
+ assertNotNull(result);
+ assertNotNull(result.getResponse());
+ assertFalse(result.getResponse().isEmpty());
+ assertEquals(sb.toString().trim(), result.getResponse().trim());
+ } catch (IOException | OllamaBaseException | InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
@Test
@Order(3)
void testAskModelWithOptionsAndImageURLs() {