diff --git a/docs/docs/apis-generate/generate-async.md b/docs/docs/apis-generate/generate-async.md index 49f556f..d5c1c27 100644 --- a/docs/docs/apis-generate/generate-async.md +++ b/docs/docs/apis-generate/generate-async.md @@ -1,42 +1,46 @@ --- -sidebar_position: 3 +sidebar_position: 2 --- # Generate - Async This API lets you ask questions to the LLMs in a asynchronous way. -These APIs correlate to -the [completion](https://github.com/jmorganca/ollama/blob/main/docs/api.md#generate-a-completion) APIs. +This is particularly helpful when you want to issue a generate request to the LLM and collect the response in the +background (such as threads) without blocking your code until the response arrives from the model. + +This API corresponds to +the [completion](https://github.com/jmorganca/ollama/blob/main/docs/api.md#generate-a-completion) API. ```java public class Main { - public static void main(String[] args) { - + public static void main(String[] args) throws Exception { String host = "http://localhost:11434/"; - OllamaAPI ollamaAPI = new OllamaAPI(host); + ollamaAPI.setRequestTimeoutSeconds(60); + String prompt = "List all cricket world cup teams of 2019."; + OllamaAsyncResultStreamer streamer = ollamaAPI.generateAsync(OllamaModelType.LLAMA3, prompt, false); - String prompt = "Who are you?"; + // Set the poll interval according to your needs. + // The smaller the poll interval, the more frequently you receive tokens. + int pollIntervalMilliseconds = 1000; - OllamaAsyncResultCallback callback = ollamaAPI.generateAsync(OllamaModelType.LLAMA2, prompt); - - while (!callback.isComplete() || !callback.getStream().isEmpty()) { - // poll for data from the response stream - String result = callback.getStream().poll(); - if (result != null) { - System.out.print(result); + while (true) { + String tokens = streamer.getStream().poll(); + System.out.print(tokens); + if (!streamer.isAlive()) { + break; } - Thread.sleep(100); + Thread.sleep(pollIntervalMilliseconds); } + + System.out.println("\n------------------------"); + System.out.println("Complete Response:"); + System.out.println("------------------------"); + + System.out.println(streamer.getCompleteResponse()); } } ``` -You will get a response similar to: - -> I am LLaMA, an AI assistant developed by Meta AI that can understand and respond to human input in a conversational -> manner. I am trained on a massive dataset of text from the internet and can generate human-like responses to a wide -> range of topics and questions. I can be used to create chatbots, virtual assistants, and other applications that -> require -> natural language understanding and generation capabilities. \ No newline at end of file +You will get a streaming response. \ No newline at end of file diff --git a/docs/docs/apis-generate/generate-with-image-files.md b/docs/docs/apis-generate/generate-with-image-files.md index 4406981..1e1f9f9 100644 --- a/docs/docs/apis-generate/generate-with-image-files.md +++ b/docs/docs/apis-generate/generate-with-image-files.md @@ -5,8 +5,8 @@ sidebar_position: 4 # Generate - With Image Files This API lets you ask questions along with the image files to the LLMs. -These APIs correlate to -the [completion](https://github.com/jmorganca/ollama/blob/main/docs/api.md#generate-a-completion) APIs. +This API corresponds to +the [completion](https://github.com/jmorganca/ollama/blob/main/docs/api.md#generate-a-completion) API. 
:::note diff --git a/docs/docs/apis-generate/generate-with-image-urls.md b/docs/docs/apis-generate/generate-with-image-urls.md index 587e8f0..2fd3941 100644 --- a/docs/docs/apis-generate/generate-with-image-urls.md +++ b/docs/docs/apis-generate/generate-with-image-urls.md @@ -5,8 +5,8 @@ sidebar_position: 5 # Generate - With Image URLs This API lets you ask questions along with the image files to the LLMs. -These APIs correlate to -the [completion](https://github.com/jmorganca/ollama/blob/main/docs/api.md#generate-a-completion) APIs. +This API corresponds to +the [completion](https://github.com/jmorganca/ollama/blob/main/docs/api.md#generate-a-completion) API. :::note diff --git a/docs/docs/apis-generate/generate-with-tools.md b/docs/docs/apis-generate/generate-with-tools.md index 0ca142a..6b7cca6 100644 --- a/docs/docs/apis-generate/generate-with-tools.md +++ b/docs/docs/apis-generate/generate-with-tools.md @@ -1,12 +1,12 @@ --- -sidebar_position: 2 +sidebar_position: 3 --- # Generate - With Tools This API lets you perform [function calling](https://docs.mistral.ai/capabilities/function_calling/) using LLMs in a synchronous way. -This API correlates to +This API corresponds to the [generate](https://github.com/ollama/ollama/blob/main/docs/api.md#request-raw-mode) API with `raw` mode. :::note diff --git a/docs/docs/apis-generate/generate.md b/docs/docs/apis-generate/generate.md index aed106e..0469f26 100644 --- a/docs/docs/apis-generate/generate.md +++ b/docs/docs/apis-generate/generate.md @@ -5,8 +5,8 @@ sidebar_position: 1 # Generate - Sync This API lets you ask questions to the LLMs in a synchronous way. -These APIs correlate to -the [completion](https://github.com/jmorganca/ollama/blob/main/docs/api.md#generate-a-completion) APIs. +This API corresponds to +the [completion](https://github.com/jmorganca/ollama/blob/main/docs/api.md#generate-a-completion) API. Use the `OptionBuilder` to build the `Options` object with [extra parameters](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values). diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java index 80654ae..340fcb9 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java @@ -360,15 +360,15 @@ public class OllamaAPI { } /** - * Convenience method to call Ollama API without streaming responses. + * Generates response using the specified AI model and prompt (in blocking mode). *

* Uses {@link #generate(String, String, boolean, Options, OllamaStreamHandler)} * - * @param model Model to use - * @param prompt Prompt text + * @param model The name or identifier of the AI model to use for generating the response. + * @param prompt The input text or prompt to provide to the AI model. * @param raw In some cases, you may wish to bypass the templating system and provide a full prompt. In this case, you can use the raw parameter to disable templating. Also note that raw mode will not return a context. - * @param options Additional Options - * @return OllamaResult + * @param options Additional options or configurations to use when generating the response. + * @return {@link OllamaResult} */ public OllamaResult generate(String model, String prompt, boolean raw, Options options) throws OllamaBaseException, IOException, InterruptedException { @@ -376,6 +376,20 @@ public class OllamaAPI { } + /** + * Generates response using the specified AI model and prompt (in blocking mode), and then invokes a set of tools + * on the generated response. + * + * @param model The name or identifier of the AI model to use for generating the response. + * @param prompt The input text or prompt to provide to the AI model. + * @param raw In some cases, you may wish to bypass the templating system and provide a full prompt. In this case, you can use the raw parameter to disable templating. Also note that raw mode will not return a context. + * @param options Additional options or configurations to use when generating the response. + * @return {@link OllamaToolsResult} An OllamaToolsResult object containing the response from the AI model and the results of invoking the tools on that output. + * @throws OllamaBaseException If there is an error related to the Ollama API or service. + * @throws IOException If there is an error related to input/output operations. + * @throws InterruptedException If the method is interrupted while waiting for the AI model + * to generate the response or for the tools to be invoked. + */ public OllamaToolsResult generateWithTools(String model, String prompt, boolean raw, Options options) throws OllamaBaseException, IOException, InterruptedException { OllamaToolsResult toolResult = new OllamaToolsResult(); @@ -402,15 +416,15 @@ public class OllamaAPI { * @param prompt the prompt/question text * @return the ollama async result callback handle */ - public OllamaAsyncResultCallback generateAsync(String model, String prompt, boolean raw) { + public OllamaAsyncResultStreamer generateAsync(String model, String prompt, boolean raw) { OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt); ollamaRequestModel.setRaw(raw); URI uri = URI.create(this.host + "/api/generate"); - OllamaAsyncResultCallback ollamaAsyncResultCallback = - new OllamaAsyncResultCallback( + OllamaAsyncResultStreamer ollamaAsyncResultStreamer = + new OllamaAsyncResultStreamer( getRequestBuilderDefault(uri), ollamaRequestModel, requestTimeoutSeconds); - ollamaAsyncResultCallback.start(); - return ollamaAsyncResultCallback; + ollamaAsyncResultStreamer.start(); + return ollamaAsyncResultStreamer; } /** @@ -508,7 +522,7 @@ public class OllamaAPI { * Hint: the OllamaChatRequestModel#getStream() property is not implemented. 
* * @param request request object to be sent to the server - * @return + * @return {@link OllamaChatResult} * @throws OllamaBaseException any response code than 200 has been returned * @throws IOException in case the responseStream can not be read * @throws InterruptedException in case the server is not reachable or network issues happen @@ -524,7 +538,7 @@ public class OllamaAPI { * * @param request request object to be sent to the server * @param streamHandler callback handler to handle the last message from stream (caution: all previous messages from stream will be concatenated) - * @return + * @return {@link OllamaChatResult} * @throws OllamaBaseException any response code than 200 has been returned * @throws IOException in case the responseStream can not be read * @throws InterruptedException in case the server is not reachable or network issues happen @@ -541,6 +555,10 @@ public class OllamaAPI { return new OllamaChatResult(result.getResponse(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages()); } + public void registerTool(MistralTools.ToolSpecification toolSpecification) { + ToolRegistry.addFunction(toolSpecification.getFunctionName(), toolSpecification.getToolDefinition()); + } + // technical private methods // private static String encodeFileToBase64(File file) throws IOException { @@ -603,10 +621,6 @@ public class OllamaAPI { } - public void registerTool(MistralTools.ToolSpecification toolSpecification) { - ToolRegistry.addFunction(toolSpecification.getFunctionName(), toolSpecification.getToolDefinition()); - } - private Object invokeTool(ToolDef toolDef) { try { String methodName = toolDef.getName(); diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaResultStream.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaResultStream.java new file mode 100644 index 0000000..21a15b1 --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaResultStream.java @@ -0,0 +1,18 @@ +package io.github.amithkoujalgi.ollama4j.core; + +import java.util.Iterator; +import java.util.LinkedList; +import java.util.Queue; + +public class OllamaResultStream extends LinkedList implements Queue { + @Override + public String poll() { + StringBuilder tokens = new StringBuilder(); + Iterator iterator = this.listIterator(); + while (iterator.hasNext()) { + tokens.append(iterator.next()); + iterator.remove(); + } + return tokens.toString(); + } +} diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java deleted file mode 100644 index 136f1c6..0000000 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java +++ /dev/null @@ -1,143 +0,0 @@ -package io.github.amithkoujalgi.ollama4j.core.models; - -import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; -import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel; -import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateResponseModel; -import io.github.amithkoujalgi.ollama4j.core.utils.Utils; -import java.io.BufferedReader; -import java.io.IOException; -import java.io.InputStream; -import java.io.InputStreamReader; -import java.net.http.HttpClient; -import java.net.http.HttpRequest; -import java.net.http.HttpResponse; -import java.nio.charset.StandardCharsets; -import java.time.Duration; -import java.util.LinkedList; -import java.util.Queue; 
-import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.Getter; - -@Data -@EqualsAndHashCode(callSuper = true) -@SuppressWarnings("unused") -public class OllamaAsyncResultCallback extends Thread { - private final HttpRequest.Builder requestBuilder; - private final OllamaGenerateRequestModel ollamaRequestModel; - private final Queue queue = new LinkedList<>(); - private String result; - private boolean isDone; - - /** - * -- GETTER -- Returns the status of the request. Indicates if the request was successful or a - * failure. If the request was a failure, the `getResponse()` method will return the error - * message. - */ - @Getter private boolean succeeded; - - private long requestTimeoutSeconds; - - /** - * -- GETTER -- Returns the HTTP response status code for the request that was made to Ollama - * server. - */ - @Getter private int httpStatusCode; - - /** -- GETTER -- Returns the response time in milliseconds. */ - @Getter private long responseTime = 0; - - public OllamaAsyncResultCallback( - HttpRequest.Builder requestBuilder, - OllamaGenerateRequestModel ollamaRequestModel, - long requestTimeoutSeconds) { - this.requestBuilder = requestBuilder; - this.ollamaRequestModel = ollamaRequestModel; - this.isDone = false; - this.result = ""; - this.queue.add(""); - this.requestTimeoutSeconds = requestTimeoutSeconds; - } - - @Override - public void run() { - HttpClient httpClient = HttpClient.newHttpClient(); - try { - long startTime = System.currentTimeMillis(); - HttpRequest request = - requestBuilder - .POST( - HttpRequest.BodyPublishers.ofString( - Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))) - .header("Content-Type", "application/json") - .timeout(Duration.ofSeconds(requestTimeoutSeconds)) - .build(); - HttpResponse response = - httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream()); - int statusCode = response.statusCode(); - this.httpStatusCode = statusCode; - - InputStream responseBodyStream = response.body(); - try (BufferedReader reader = - new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { - String line; - StringBuilder responseBuffer = new StringBuilder(); - while ((line = reader.readLine()) != null) { - if (statusCode == 404) { - OllamaErrorResponseModel ollamaResponseModel = - Utils.getObjectMapper().readValue(line, OllamaErrorResponseModel.class); - queue.add(ollamaResponseModel.getError()); - responseBuffer.append(ollamaResponseModel.getError()); - } else { - OllamaGenerateResponseModel ollamaResponseModel = - Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class); - queue.add(ollamaResponseModel.getResponse()); - if (!ollamaResponseModel.isDone()) { - responseBuffer.append(ollamaResponseModel.getResponse()); - } - } - } - - this.isDone = true; - this.succeeded = true; - this.result = responseBuffer.toString(); - long endTime = System.currentTimeMillis(); - responseTime = endTime - startTime; - } - if (statusCode != 200) { - throw new OllamaBaseException(this.result); - } - } catch (IOException | InterruptedException | OllamaBaseException e) { - this.isDone = true; - this.succeeded = false; - this.result = "[FAILED] " + e.getMessage(); - } - } - - /** - * Returns the status of the thread. This does not indicate that the request was successful or a - * failure, rather it is just a status flag to indicate if the thread is active or ended. 
- * - * @return boolean - status - */ - public boolean isComplete() { - return isDone; - } - - /** - * Returns the final completion/response when the execution completes. Does not return intermediate results. - * - * @return String completion/response text - */ - public String getResponse() { - return result; - } - - public Queue getStream() { - return queue; - } - - public void setRequestTimeoutSeconds(long requestTimeoutSeconds) { - this.requestTimeoutSeconds = requestTimeoutSeconds; - } -} diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultStreamer.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultStreamer.java new file mode 100644 index 0000000..66dde93 --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultStreamer.java @@ -0,0 +1,124 @@ +package io.github.amithkoujalgi.ollama4j.core.models; + +import io.github.amithkoujalgi.ollama4j.core.OllamaResultStream; +import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; +import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel; +import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateResponseModel; +import io.github.amithkoujalgi.ollama4j.core.utils.Utils; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.Setter; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.nio.charset.StandardCharsets; +import java.time.Duration; + +@Data +@EqualsAndHashCode(callSuper = true) +@SuppressWarnings("unused") +public class OllamaAsyncResultStreamer extends Thread { + private final HttpRequest.Builder requestBuilder; + private final OllamaGenerateRequestModel ollamaRequestModel; + private final OllamaResultStream stream = new OllamaResultStream(); + private String completeResponse; + + + /** + * -- GETTER -- Returns the status of the request. Indicates if the request was successful or a + * failure. If the request was a failure, the `getResponse()` method will return the error + * message. + */ + @Getter + private boolean succeeded; + + @Setter + private long requestTimeoutSeconds; + + /** + * -- GETTER -- Returns the HTTP response status code for the request that was made to Ollama + * server. + */ + @Getter + private int httpStatusCode; + + /** + * -- GETTER -- Returns the response time in milliseconds. 
+ */ + @Getter + private long responseTime = 0; + + public OllamaAsyncResultStreamer( + HttpRequest.Builder requestBuilder, + OllamaGenerateRequestModel ollamaRequestModel, + long requestTimeoutSeconds) { + this.requestBuilder = requestBuilder; + this.ollamaRequestModel = ollamaRequestModel; + this.completeResponse = ""; + this.stream.add(""); + this.requestTimeoutSeconds = requestTimeoutSeconds; + } + + @Override + public void run() { + ollamaRequestModel.setStream(true); + HttpClient httpClient = HttpClient.newHttpClient(); + try { + long startTime = System.currentTimeMillis(); + HttpRequest request = + requestBuilder + .POST( + HttpRequest.BodyPublishers.ofString( + Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))) + .header("Content-Type", "application/json") + .timeout(Duration.ofSeconds(requestTimeoutSeconds)) + .build(); + HttpResponse response = + httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream()); + int statusCode = response.statusCode(); + this.httpStatusCode = statusCode; + + InputStream responseBodyStream = response.body(); + try (BufferedReader reader = + new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { + String line; + StringBuilder responseBuffer = new StringBuilder(); + while ((line = reader.readLine()) != null) { + if (statusCode == 404) { + OllamaErrorResponseModel ollamaResponseModel = + Utils.getObjectMapper().readValue(line, OllamaErrorResponseModel.class); + stream.add(ollamaResponseModel.getError()); + responseBuffer.append(ollamaResponseModel.getError()); + } else { + OllamaGenerateResponseModel ollamaResponseModel = + Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class); + String res = ollamaResponseModel.getResponse(); + stream.add(res); + if (!ollamaResponseModel.isDone()) { + responseBuffer.append(res); + } + } + } + + this.succeeded = true; + this.completeResponse = responseBuffer.toString(); + long endTime = System.currentTimeMillis(); + responseTime = endTime - startTime; + } + if (statusCode != 200) { + throw new OllamaBaseException(this.completeResponse); + } + } catch (IOException | InterruptedException | OllamaBaseException e) { + this.succeeded = false; + this.completeResponse = "[FAILED] " + e.getMessage(); + } + } + +} + diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/TestMockedAPIs.java b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/TestMockedAPIs.java index c5d60e1..3b1613f 100644 --- a/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/TestMockedAPIs.java +++ b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/TestMockedAPIs.java @@ -3,7 +3,7 @@ package io.github.amithkoujalgi.ollama4j.unittests; import io.github.amithkoujalgi.ollama4j.core.OllamaAPI; import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; import io.github.amithkoujalgi.ollama4j.core.models.ModelDetail; -import io.github.amithkoujalgi.ollama4j.core.models.OllamaAsyncResultCallback; +import io.github.amithkoujalgi.ollama4j.core.models.OllamaAsyncResultStreamer; import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult; import io.github.amithkoujalgi.ollama4j.core.types.OllamaModelType; import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder; @@ -157,7 +157,7 @@ class TestMockedAPIs { String model = OllamaModelType.LLAMA2; String prompt = "some prompt text"; when(ollamaAPI.generateAsync(model, prompt, false)) - .thenReturn(new OllamaAsyncResultCallback(null, null, 3)); + .thenReturn(new 
OllamaAsyncResultStreamer(null, null, 3)); ollamaAPI.generateAsync(model, prompt, false); verify(ollamaAPI, times(1)).generateAsync(model, prompt, false); }
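
The new `OllamaResultStream` that `OllamaAsyncResultStreamer` exposes via `getStream()` drains its queue on every `poll()`: a single call concatenates and removes all tokens buffered since the previous call, and yields an empty string when nothing is pending. Below is a minimal sketch of that behaviour, using only the class as added in this change (the demo class name is illustrative):

```java
import io.github.amithkoujalgi.ollama4j.core.OllamaResultStream;

public class ResultStreamDemo {
    public static void main(String[] args) {
        OllamaResultStream stream = new OllamaResultStream();

        // Simulate tokens arriving from the background generation thread
        stream.add("The ");
        stream.add("quick ");
        stream.add("brown fox");

        // A single poll() concatenates and removes everything queued so far
        System.out.println(stream.poll()); // prints: The quick brown fox

        // Nothing is pending any more, so the next poll() returns an empty string
        System.out.println(stream.poll()); // prints an empty line
    }
}
```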
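
Once the background thread has finished, the Lombok annotations on `OllamaAsyncResultStreamer` (`@Data` plus the explicit `@Getter` fields) should expose the request outcome. The following is a hedged sketch of checking that outcome after the stream ends; the accessor names (`isSucceeded`, `getHttpStatusCode`, `getResponseTime`, `getCompleteResponse`) are assumed from standard Lombok naming rather than spelled out in the docs above:

```java
import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaAsyncResultStreamer;
import io.github.amithkoujalgi.ollama4j.core.types.OllamaModelType;

public class AsyncOutcomeCheck {
    public static void main(String[] args) throws Exception {
        OllamaAPI ollamaAPI = new OllamaAPI("http://localhost:11434/");
        ollamaAPI.setRequestTimeoutSeconds(60);

        OllamaAsyncResultStreamer streamer =
                ollamaAPI.generateAsync(OllamaModelType.LLAMA3, "Why is the sky blue?", false);

        // Block until the background streaming thread has finished
        streamer.join();

        if (streamer.isSucceeded() && streamer.getHttpStatusCode() == 200) {
            System.out.println("Response time (ms): " + streamer.getResponseTime());
            System.out.println(streamer.getCompleteResponse());
        } else {
            // On failure, the complete response carries the "[FAILED] ..." error message
            System.err.println(streamer.getCompleteResponse());
        }
    }
}
```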