mirror of https://github.com/amithkoujalgi/ollama4j.git
synced 2025-05-15 03:47:13 +02:00

Updated ask and askAsync responses to include responseTime parameter
This commit is contained in:
parent 3a43d3e95c
commit 7579bbbc59
OllamaAPI.java
@@ -105,15 +105,11 @@ public class OllamaAPI {
     public ModelDetail getModelDetails(String modelName) throws IOException, OllamaBaseException, InterruptedException {
         String url = this.host + "/api/show";
         String jsonData = String.format("{\"name\": \"%s\"}", modelName);
-
         HttpRequest request = HttpRequest.newBuilder().uri(URI.create(url)).header("Accept", "application/json").header("Content-type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
-
         HttpClient client = HttpClient.newHttpClient();
         HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
-
         int statusCode = response.statusCode();
         String responseBody = response.body();
-
         if (statusCode == 200) {
             return Utils.getObjectMapper().readValue(responseBody, ModelDetail.class);
         } else {
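For orientation, a caller of this endpoint might look like the short sketch below. This is a usage sketch only, not part of the commit: the OllamaAPI(String host) constructor, the host URL, and the model name are assumptions here.

// Usage sketch (not part of the commit); host and model name are placeholders.
OllamaAPI api = new OllamaAPI("http://localhost:11434");
// getModelDetails POSTs {"name": "<model>"} to /api/show and maps the JSON reply to ModelDetail.
ModelDetail details = api.getModelDetails("llama2");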
@@ -200,8 +196,9 @@ public class OllamaAPI {
      * @param promptText the prompt/question text
      * @return the response text from the model
      */
-    public String ask(String ollamaModelType, String promptText) throws OllamaBaseException, IOException, InterruptedException {
+    public OllamaResult ask(String ollamaModelType, String promptText) throws OllamaBaseException, IOException, InterruptedException {
         OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(ollamaModelType, promptText);
+        long startTime = System.currentTimeMillis();
         HttpClient httpClient = HttpClient.newHttpClient();
         URI uri = URI.create(this.host + "/api/generate");
         HttpRequest request = HttpRequest.newBuilder(uri).POST(HttpRequest.BodyPublishers.ofString(Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))).header("Content-Type", "application/json").build();
@@ -221,7 +218,8 @@ public class OllamaAPI {
         if (statusCode != 200) {
             throw new OllamaBaseException(statusCode + " - " + responseBuffer);
         } else {
-            return responseBuffer.toString().trim();
+            long endTime = System.currentTimeMillis();
+            return new OllamaResult(responseBuffer.toString().trim(), endTime - startTime);
         }
     }

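With this change, callers of the synchronous ask() receive an OllamaResult instead of a raw String. A minimal sketch of the updated call site, assuming the OllamaAPI(String host) constructor and the OllamaModelType.LLAMA2 constant used in the tests below; ask() still throws OllamaBaseException, IOException, and InterruptedException.

// Usage sketch (not part of the commit); the constructor and host URL are assumed.
OllamaAPI api = new OllamaAPI("http://localhost:11434");
OllamaResult result = api.ask(OllamaModelType.LLAMA2, "Why is the sky blue?");
String text = result.getResponse();         // generated text, trimmed by the API
long elapsedMs = result.getResponseTime();  // endTime - startTime, in milliseconds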
OllamaAsyncResultCallback.java
@@ -23,6 +23,7 @@ public class OllamaAsyncResultCallback extends Thread {
     private final Queue<String> queue = new LinkedList<>();
     private String result;
     private boolean isDone;
+    private long responseTime = 0;

     public OllamaAsyncResultCallback(HttpClient client, URI uri, OllamaRequestModel ollamaRequestModel) {
         this.client = client;
@@ -36,6 +37,7 @@ public class OllamaAsyncResultCallback extends Thread {
     @Override
     public void run() {
         try {
+            long startTime = System.currentTimeMillis();
             HttpRequest request = HttpRequest.newBuilder(uri).POST(HttpRequest.BodyPublishers.ofString(Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))).header("Content-Type", "application/json").build();
             HttpResponse<InputStream> response = client.send(request, HttpResponse.BodyHandlers.ofInputStream());
             int statusCode = response.statusCode();
@@ -55,6 +57,8 @@ public class OllamaAsyncResultCallback extends Thread {
             reader.close();
             this.isDone = true;
             this.result = responseBuffer.toString();
+            long endTime = System.currentTimeMillis();
+            responseTime = endTime - startTime;
         }
         if (statusCode != 200) {
             throw new OllamaBaseException(statusCode + " - " + responseString);
@@ -80,4 +84,12 @@ public class OllamaAsyncResultCallback extends Thread {
     public Queue<String> getStream() {
         return queue;
     }
+
+    /**
+     * Returns the response time in milliseconds.
+     * @return response time in milliseconds
+     */
+    public long getResponseTime() {
+        return responseTime;
+    }
 }
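Because OllamaAsyncResultCallback extends Thread, a caller can wait for the worker to finish and then read the recorded timing. The sketch below assumes an askAsync(model, prompt) method on OllamaAPI that starts and returns this callback; the commit message refers to askAsync, but the method itself is not shown in these hunks.

// Usage sketch (not part of the commit); askAsync(...) is assumed from the commit message.
OllamaAsyncResultCallback callback = api.askAsync(OllamaModelType.LLAMA2, "Why is the sky blue?");
callback.join();                              // the callback extends Thread, so join() waits for run() to finish (throws InterruptedException)
long elapsedMs = callback.getResponseTime();  // set in run() as endTime - startTime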
OllamaResult.java (new file)
@@ -0,0 +1,20 @@
+package io.github.amithkoujalgi.ollama4j.core.models;
+
+@SuppressWarnings("unused")
+public class OllamaResult {
+    private String response;
+    private long responseTime = 0;
+
+    public OllamaResult(String response, long responseTime) {
+        this.response = response;
+        this.responseTime = responseTime;
+    }
+
+    public String getResponse() {
+        return response;
+    }
+
+    public long getResponseTime() {
+        return responseTime;
+    }
+}
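OllamaResult is a plain result holder: both fields are set once in the constructor and exposed only through getters, so a result can be passed around without callers mutating the response text or the timing. A tiny construction sketch, mirroring the call in OllamaAPI.ask(); the values here are placeholders.

// Construction sketch (not part of the commit); values are placeholders.
OllamaResult result = new OllamaResult("The sky is blue because ...", 1234L);
System.out.println(result.getResponse() + " (" + result.getResponseTime() + " ms)");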
TestMockedAPIs.java
@@ -4,6 +4,7 @@ import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
 import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
 import io.github.amithkoujalgi.ollama4j.core.models.ModelDetail;
 import io.github.amithkoujalgi.ollama4j.core.models.OllamaAsyncResultCallback;
+import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
 import io.github.amithkoujalgi.ollama4j.core.types.OllamaModelType;
 import org.junit.jupiter.api.Test;
 import org.mockito.Mockito;
@@ -100,7 +101,7 @@ public class TestMockedAPIs {
         String model = OllamaModelType.LLAMA2;
         String prompt = "some prompt text";
         try {
-            when(ollamaAPI.ask(model, prompt)).thenReturn("");
+            when(ollamaAPI.ask(model, prompt)).thenReturn(new OllamaResult("", 0));
             ollamaAPI.ask(model, prompt);
             verify(ollamaAPI, times(1)).ask(model, prompt);
         } catch (IOException | OllamaBaseException | InterruptedException e) {