From abb76ad8671a293fe87e443ac7099489e9506540 Mon Sep 17 00:00:00 2001
From: Markus Klenke <markusklenke86@gmail.com>
Date: Fri, 16 Feb 2024 17:03:15 +0000
Subject: [PATCH] Add streaming feature to Generate APIs

---
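Usage sketch (illustrative only; the host URL, model name, and prompt below
are placeholders, not part of this patch):

    OllamaAPI api = new OllamaAPI("http://localhost:11434");
    OllamaResult result = api.generate("llama2", "Why is the sky blue?",
        new OptionsBuilder().build(),
        partialMessage -> System.out.println(partialMessage));

The handler receives the accumulated message for every streamed response
part; passing null, or calling the overloads without a handler, keeps the
previous non-streaming behaviour.
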
 .../ollama4j/core/OllamaAPI.java              | 62 ++++++++++++++++---
 .../OllamaGenerateStreamObserver.java         | 31 ++++++++++
 .../request/OllamaGenerateEndpointCaller.java | 18 +++++-
 .../integrationtests/TestRealAPIs.java        | 50 +++++++++++++++
 4 files changed, 150 insertions(+), 11 deletions(-)
 create mode 100644 src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateStreamObserver.java

diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java
index 20e9d3e..ec772f1 100644
--- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java
@@ -342,13 +342,24 @@ public class OllamaAPI {
    * @param options the Options object - <a
    *     href="https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values">More
    *     details on the options</a>
+   * @param streamHandler optional callback that is invoked for every streamed
+   *     response part; if null, the stream parameter of the request is set to false
    * @return OllamaResult that includes response text and time taken for response
    */
-  public OllamaResult generate(String model, String prompt, Options options)
+  public OllamaResult generate(String model, String prompt, Options options, OllamaStreamHandler streamHandler)
       throws OllamaBaseException, IOException, InterruptedException {
     OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt);
     ollamaRequestModel.setOptions(options.getOptionsMap());
-    return generateSyncForOllamaRequestModel(ollamaRequestModel);
+    return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler);
+  }
+
+  /**
+   * Convenience method to call the Ollama API without streaming responses.
+   *
+   * <p>Uses {@link #generate(String, String, Options, OllamaStreamHandler)}.
+   */
+  public OllamaResult generate(String model, String prompt, Options options)
+      throws OllamaBaseException, IOException, InterruptedException {
+    return generate(model, prompt, options, null);
   }
 
   /**
@@ -381,10 +392,11 @@ public class OllamaAPI {
    * @param options the Options object - <a
    *     href="https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values">More
    *     details on the options</a>
+   * @param streamHandler optional callback that is invoked for every streamed
+   *     response part; if null, the stream parameter of the request is set to false
    * @return OllamaResult that includes response text and time taken for response
    */
   public OllamaResult generateWithImageFiles(
-      String model, String prompt, List<File> imageFiles, Options options)
+      String model, String prompt, List<File> imageFiles, Options options, OllamaStreamHandler streamHandler)
       throws OllamaBaseException, IOException, InterruptedException {
     List<String> images = new ArrayList<>();
     for (File imageFile : imageFiles) {
@@ -392,9 +404,20 @@ public class OllamaAPI {
     }
     OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt, images);
     ollamaRequestModel.setOptions(options.getOptionsMap());
-    return generateSyncForOllamaRequestModel(ollamaRequestModel);
+    return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler);
   }
 
+  /**
+   * Convenience method to call the Ollama API without streaming responses.
+   *
+   * <p>Uses {@link #generateWithImageFiles(String, String, List, Options, OllamaStreamHandler)}.
+   */
+  public OllamaResult generateWithImageFiles(
+      String model, String prompt, List<File> imageFiles, Options options)
+      throws OllamaBaseException, IOException, InterruptedException {
+    return generateWithImageFiles(model, prompt, imageFiles, options, null);
+  }
+
   /**
    * With one or more image URLs, ask a question to a model running on Ollama server. This is a
    * sync/blocking call.
@@ -405,10 +428,11 @@ public class OllamaAPI {
    * @param options the Options object - <a
    *     href="https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values">More
    *     details on the options</a>
+   * @param streamHandler optional callback that is invoked for every streamed
+   *     response part; if null, the stream parameter of the request is set to false
    * @return OllamaResult that includes response text and time taken for response
    */
   public OllamaResult generateWithImageURLs(
-      String model, String prompt, List<String> imageURLs, Options options)
+      String model, String prompt, List<String> imageURLs, Options options, OllamaStreamHandler streamHandler)
       throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
     List<String> images = new ArrayList<>();
     for (String imageURL : imageURLs) {
@@ -416,7 +440,18 @@ public class OllamaAPI {
     }
     OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt, images);
     ollamaRequestModel.setOptions(options.getOptionsMap());
-    return generateSyncForOllamaRequestModel(ollamaRequestModel);
+    return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler);
+  }
+
+  /**
+   * Convenience method to call the Ollama API without streaming responses.
+   *
+   * <p>Uses {@link #generateWithImageURLs(String, String, List, Options, OllamaStreamHandler)}.
+   */
+  public OllamaResult generateWithImageURLs(String model, String prompt, List<String> imageURLs,
+      Options options)
+      throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
+    return generateWithImageURLs(model, prompt, imageURLs, options, null);
   }
 
 
@@ -487,10 +522,19 @@ public class OllamaAPI {
     return Base64.getEncoder().encodeToString(bytes);
   }
 
-  private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequestModel ollamaRequestModel)
+  private OllamaResult generateSyncForOllamaRequestModel(
+      OllamaGenerateRequestModel ollamaRequestModel, OllamaStreamHandler streamHandler)
       throws OllamaBaseException, IOException, InterruptedException {
-        OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
-        return requestCaller.callSync(ollamaRequestModel);
+    OllamaGenerateEndpointCaller requestCaller =
+        new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
+    OllamaResult result;
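+    // A non-null streamHandler turns on streaming for this request; otherwise
+    // the endpoint is called for a single blocking response.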
+    if (streamHandler != null) {
+      ollamaRequestModel.setStream(true);
+      result = requestCaller.call(ollamaRequestModel, streamHandler);
+    } else {
+      result = requestCaller.callSync(ollamaRequestModel);
+    }
+    return result;
   }
 
   /**
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateStreamObserver.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateStreamObserver.java
new file mode 100644
index 0000000..a166bac
--- /dev/null
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateStreamObserver.java
@@ -0,0 +1,31 @@
+package io.github.amithkoujalgi.ollama4j.core.models.generate;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler;
+
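+/**
+ * Aggregates streamed generate responses and forwards the accumulated message
+ * to the supplied {@link OllamaStreamHandler} for every received part.
+ */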
+public class OllamaGenerateStreamObserver {
+
+    private final OllamaStreamHandler streamHandler;
+
+    private final List<OllamaGenerateResponseModel> responseParts = new ArrayList<>();
+
+    private String message = "";
+
+    public OllamaGenerateStreamObserver(OllamaStreamHandler streamHandler) {
+        this.streamHandler = streamHandler;
+    }
+
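+    /**
+     * Records the received response part and notifies the stream handler with
+     * the updated accumulated message.
+     */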
+    public void notify(OllamaGenerateResponseModel currentResponsePart) {
+        responseParts.add(currentResponsePart);
+        handleCurrentResponsePart(currentResponsePart);
+    }
+
+    protected void handleCurrentResponsePart(OllamaGenerateResponseModel currentResponsePart) {
+        message = message + currentResponsePart.getResponse();
+        streamHandler.accept(message);
+    }
+}
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java
index ba55159..fe7fbec 100644
--- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java
@@ -1,18 +1,25 @@
 package io.github.amithkoujalgi.ollama4j.core.models.request;
 
+import java.io.IOException;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import com.fasterxml.jackson.core.JsonProcessingException;
-
+import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler;
+import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
 import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth;
+import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
 import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateResponseModel;
+import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateStreamObserver;
+import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
 import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
 
 public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{
 
     private static final Logger LOG = LoggerFactory.getLogger(OllamaGenerateEndpointCaller.class);
 
+    private OllamaGenerateStreamObserver streamObserver;
+
     public OllamaGenerateEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) {
         super(host, basicAuth, requestTimeoutSeconds, verbose);   
     }
@@ -27,6 +34,9 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{
                 try {
                     OllamaGenerateResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class);
                     responseBuffer.append(ollamaResponseModel.getResponse());
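+                    // Forward every parsed response part to the observer so a
+                    // registered stream handler sees the response incrementally.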
+                    if (streamObserver != null) {
+                        streamObserver.notify(ollamaResponseModel);
+                    }
                     return ollamaResponseModel.isDone();
                 } catch (JsonProcessingException e) {
                     LOG.error("Error parsing the Ollama chat response!",e);
@@ -34,7 +44,11 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{
                 }         
     }
 
-    
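+    /**
+     * Performs the blocking endpoint call while notifying the given stream
+     * handler for every response part parsed from the stream.
+     */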
+    public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler)
+            throws OllamaBaseException, IOException, InterruptedException {
+        streamObserver = new OllamaGenerateStreamObserver(streamHandler);
+        return super.callSync(body);
+    }
     
     
 }
diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java b/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java
index b4a7e1d..c124f42 100644
--- a/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java
+++ b/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java
@@ -109,6 +109,32 @@ class TestRealAPIs {
     }
   }
 
+  @Test
+  @Order(3)
+  void testAskModelWithDefaultOptionsStreamed() {
+    testEndpointReachability();
+    try {
+      StringBuffer sb = new StringBuffer();
+
+      OllamaResult result = ollamaAPI.generate(config.getModel(),
+          "What is the capital of France? And what's France's connection with Mona Lisa?",
+          new OptionsBuilder().build(), (s) -> {
+            LOG.info(s);
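+            // The handler receives the accumulated message; keep only the newly
+            // appended part so sb mirrors the final response.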
+            String substring = s.substring(sb.length());
+            LOG.info(substring);
+            sb.append(substring);
+          });
+
+      assertNotNull(result);
+      assertNotNull(result.getResponse());
+      assertFalse(result.getResponse().isEmpty());
+      assertEquals(sb.toString().trim(), result.getResponse().trim());
+    } catch (IOException | OllamaBaseException | InterruptedException e) {
+      throw new RuntimeException(e);
+    }
+  }
+
   @Test
   @Order(3)
   void testAskModelWithOptions() {
@@ -262,6 +288,30 @@ class TestRealAPIs {
     }
   }
 
+  @Test
+  @Order(3)
+  void testAskModelWithOptionsAndImageFilesStreamed() {
+    testEndpointReachability();
+    File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg");
+    try {
+      StringBuffer sb = new StringBuffer();
+
+      OllamaResult result = ollamaAPI.generateWithImageFiles(config.getImageModel(),
+          "What is in this image?", List.of(imageFile), new OptionsBuilder().build(), (s) -> {
+            LOG.info(s);
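+            // Same accumulated-message/delta handling as in the text-only streaming test.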
+            String substring = s.substring(sb.length());
+            LOG.info(substring);
+            sb.append(substring);
+          });
+      assertNotNull(result);
+      assertNotNull(result.getResponse());
+      assertFalse(result.getResponse().isEmpty());
+      assertEquals(sb.toString().trim(), result.getResponse().trim());
+    } catch (IOException | OllamaBaseException | InterruptedException e) {
+      throw new RuntimeException(e);
+    }
+  }
+
   @Test
   @Order(3)
   void testAskModelWithOptionsAndImageURLs() {