Updated ask and askAsync responses to include responseTime parameter

This commit is contained in:
Amith Koujalgi 2023-11-16 00:48:59 +05:30
parent 3a43d3e95c
commit 7579bbbc59
4 changed files with 38 additions and 7 deletions

View File

@ -105,15 +105,11 @@ public class OllamaAPI {
public ModelDetail getModelDetails(String modelName) throws IOException, OllamaBaseException, InterruptedException {
String url = this.host + "/api/show";
String jsonData = String.format("{\"name\": \"%s\"}", modelName);
HttpRequest request = HttpRequest.newBuilder().uri(URI.create(url)).header("Accept", "application/json").header("Content-type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
HttpClient client = HttpClient.newHttpClient();
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode();
String responseBody = response.body();
if (statusCode == 200) {
return Utils.getObjectMapper().readValue(responseBody, ModelDetail.class);
} else {
@ -200,8 +196,9 @@ public class OllamaAPI {
* @param promptText the prompt/question text
* @return the response text from the model
*/
public String ask(String ollamaModelType, String promptText) throws OllamaBaseException, IOException, InterruptedException {
public OllamaResult ask(String ollamaModelType, String promptText) throws OllamaBaseException, IOException, InterruptedException {
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(ollamaModelType, promptText);
long startTime = System.currentTimeMillis();
HttpClient httpClient = HttpClient.newHttpClient();
URI uri = URI.create(this.host + "/api/generate");
HttpRequest request = HttpRequest.newBuilder(uri).POST(HttpRequest.BodyPublishers.ofString(Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))).header("Content-Type", "application/json").build();
@ -221,7 +218,8 @@ public class OllamaAPI {
if (statusCode != 200) {
throw new OllamaBaseException(statusCode + " - " + responseBuffer);
} else {
return responseBuffer.toString().trim();
long endTime = System.currentTimeMillis();
return new OllamaResult(responseBuffer.toString().trim(), endTime - startTime);
}
}

View File

@ -23,6 +23,7 @@ public class OllamaAsyncResultCallback extends Thread {
private final Queue<String> queue = new LinkedList<>();
private String result;
private boolean isDone;
private long responseTime = 0;
public OllamaAsyncResultCallback(HttpClient client, URI uri, OllamaRequestModel ollamaRequestModel) {
this.client = client;
@ -36,6 +37,7 @@ public class OllamaAsyncResultCallback extends Thread {
@Override
public void run() {
try {
long startTime = System.currentTimeMillis();
HttpRequest request = HttpRequest.newBuilder(uri).POST(HttpRequest.BodyPublishers.ofString(Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))).header("Content-Type", "application/json").build();
HttpResponse<InputStream> response = client.send(request, HttpResponse.BodyHandlers.ofInputStream());
int statusCode = response.statusCode();
@ -55,6 +57,8 @@ public class OllamaAsyncResultCallback extends Thread {
reader.close();
this.isDone = true;
this.result = responseBuffer.toString();
long endTime = System.currentTimeMillis();
responseTime = endTime - startTime;
}
if (statusCode != 200) {
throw new OllamaBaseException(statusCode + " - " + responseString);
@ -80,4 +84,12 @@ public class OllamaAsyncResultCallback extends Thread {
/**
 * Returns the live queue backing the streamed response
 * (a {@code Queue<String>} of response chunks; presumably appended to by
 * {@link #run()} as data arrives — callers poll it for partial output).
 *
 * @return the internal response stream queue
 */
public Queue<String> getStream() {
return queue;
}
/**
 * Returns the total response time in <b>milliseconds</b>.
 * The value is the {@code System.currentTimeMillis()} delta measured in
 * {@link #run()}, so "seconds" would be wrong here; it remains {@code 0}
 * until the request completes.
 *
 * @return response time in milliseconds (0 if not yet complete)
 */
public long getResponseTime() {
return responseTime;
}
}

View File

@ -0,0 +1,20 @@
package io.github.amithkoujalgi.ollama4j.core.models;
@SuppressWarnings("unused")
public class OllamaResult {
private String response;
private long responseTime = 0;
public OllamaResult(String response, long responseTime) {
this.response = response;
this.responseTime = responseTime;
}
public String getResponse() {
return response;
}
public long getResponseTime() {
return responseTime;
}
}

View File

@ -4,6 +4,7 @@ import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.ModelDetail;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaAsyncResultCallback;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
import io.github.amithkoujalgi.ollama4j.core.types.OllamaModelType;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
@ -100,7 +101,7 @@ public class TestMockedAPIs {
String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text";
try {
when(ollamaAPI.ask(model, prompt)).thenReturn("");
when(ollamaAPI.ask(model, prompt)).thenReturn(new OllamaResult("", 0));
ollamaAPI.ask(model, prompt);
verify(ollamaAPI, times(1)).ask(model, prompt);
} catch (IOException | OllamaBaseException | InterruptedException e) {