From 7579bbbc596c970ec8d163c1028aa8b4a1cdf4e6 Mon Sep 17 00:00:00 2001 From: Amith Koujalgi Date: Thu, 16 Nov 2023 00:48:59 +0530 Subject: [PATCH] Updated `ask` and `askAsync` responses to include `responseTime` parameter --- .../ollama4j/core/OllamaAPI.java | 10 ++++------ .../models/OllamaAsyncResultCallback.java | 12 +++++++++++ .../ollama4j/core/models/OllamaResult.java | 20 +++++++++++++++++++ .../ollama4j/TestMockedAPIs.java | 3 ++- 4 files changed, 38 insertions(+), 7 deletions(-) create mode 100644 src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaResult.java diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java index 5e69f89..0b85644 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java @@ -105,15 +105,11 @@ public class OllamaAPI { public ModelDetail getModelDetails(String modelName) throws IOException, OllamaBaseException, InterruptedException { String url = this.host + "/api/show"; String jsonData = String.format("{\"name\": \"%s\"}", modelName); - HttpRequest request = HttpRequest.newBuilder().uri(URI.create(url)).header("Accept", "application/json").header("Content-type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build(); - HttpClient client = HttpClient.newHttpClient(); HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); - int statusCode = response.statusCode(); String responseBody = response.body(); - if (statusCode == 200) { return Utils.getObjectMapper().readValue(responseBody, ModelDetail.class); } else { @@ -200,8 +196,9 @@ public class OllamaAPI { * @param promptText the prompt/question text * @return the response text from the model */ - public String ask(String ollamaModelType, String promptText) throws OllamaBaseException, IOException, InterruptedException { + public 
OllamaResult ask(String ollamaModelType, String promptText) throws OllamaBaseException, IOException, InterruptedException { OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(ollamaModelType, promptText); + long startTime = System.currentTimeMillis(); HttpClient httpClient = HttpClient.newHttpClient(); URI uri = URI.create(this.host + "/api/generate"); HttpRequest request = HttpRequest.newBuilder(uri).POST(HttpRequest.BodyPublishers.ofString(Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))).header("Content-Type", "application/json").build(); @@ -221,7 +218,8 @@ public class OllamaAPI { if (statusCode != 200) { throw new OllamaBaseException(statusCode + " - " + responseBuffer); } else { - return responseBuffer.toString().trim(); + long endTime = System.currentTimeMillis(); + return new OllamaResult(responseBuffer.toString().trim(), endTime - startTime); } } diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java index 7524eaf..0d92d5c 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java @@ -23,6 +23,7 @@ public class OllamaAsyncResultCallback extends Thread { private final Queue queue = new LinkedList<>(); private String result; private boolean isDone; + private long responseTime = 0; public OllamaAsyncResultCallback(HttpClient client, URI uri, OllamaRequestModel ollamaRequestModel) { this.client = client; @@ -36,6 +37,7 @@ public class OllamaAsyncResultCallback extends Thread { @Override public void run() { try { + long startTime = System.currentTimeMillis(); HttpRequest request = HttpRequest.newBuilder(uri).POST(HttpRequest.BodyPublishers.ofString(Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))).header("Content-Type", 
"application/json").build(); HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofInputStream()); int statusCode = response.statusCode(); @@ -55,6 +57,8 @@ public class OllamaAsyncResultCallback extends Thread { reader.close(); this.isDone = true; this.result = responseBuffer.toString(); + long endTime = System.currentTimeMillis(); + responseTime = endTime - startTime; } if (statusCode != 200) { throw new OllamaBaseException(statusCode + " - " + responseString); @@ -80,4 +84,12 @@ public class OllamaAsyncResultCallback extends Thread { public Queue getStream() { return queue; } + + /** + * Returns the response time in milliseconds. + * @return response time in milliseconds + */ + public long getResponseTime() { + return responseTime; + } } diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaResult.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaResult.java new file mode 100644 index 0000000..ac5f67b --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaResult.java @@ -0,0 +1,20 @@ +package io.github.amithkoujalgi.ollama4j.core.models; + +@SuppressWarnings("unused") +public class OllamaResult { + private String response; + private long responseTime = 0; + + public OllamaResult(String response, long responseTime) { + this.response = response; + this.responseTime = responseTime; + } + + public String getResponse() { + return response; + } + + public long getResponseTime() { + return responseTime; + } +} diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/TestMockedAPIs.java b/src/test/java/io/github/amithkoujalgi/ollama4j/TestMockedAPIs.java index 704d49a..f9a2f09 100644 --- a/src/test/java/io/github/amithkoujalgi/ollama4j/TestMockedAPIs.java +++ b/src/test/java/io/github/amithkoujalgi/ollama4j/TestMockedAPIs.java @@ -4,6 +4,7 @@ import io.github.amithkoujalgi.ollama4j.core.OllamaAPI; import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; import
io.github.amithkoujalgi.ollama4j.core.models.ModelDetail; import io.github.amithkoujalgi.ollama4j.core.models.OllamaAsyncResultCallback; +import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult; import io.github.amithkoujalgi.ollama4j.core.types.OllamaModelType; import org.junit.jupiter.api.Test; import org.mockito.Mockito; @@ -100,7 +101,7 @@ public class TestMockedAPIs { String model = OllamaModelType.LLAMA2; String prompt = "some prompt text"; try { - when(ollamaAPI.ask(model, prompt)).thenReturn(""); + when(ollamaAPI.ask(model, prompt)).thenReturn(new OllamaResult("", 0)); ollamaAPI.ask(model, prompt); verify(ollamaAPI, times(1)).ask(model, prompt); } catch (IOException | OllamaBaseException | InterruptedException e) {