mirror of
https://github.com/amithkoujalgi/ollama4j.git
synced 2025-05-15 03:47:13 +02:00
Updated ask
and askAsync
responses to include responseTime
parameter
This commit is contained in:
parent
3a43d3e95c
commit
7579bbbc59
@ -105,15 +105,11 @@ public class OllamaAPI {
|
|||||||
public ModelDetail getModelDetails(String modelName) throws IOException, OllamaBaseException, InterruptedException {
|
public ModelDetail getModelDetails(String modelName) throws IOException, OllamaBaseException, InterruptedException {
|
||||||
String url = this.host + "/api/show";
|
String url = this.host + "/api/show";
|
||||||
String jsonData = String.format("{\"name\": \"%s\"}", modelName);
|
String jsonData = String.format("{\"name\": \"%s\"}", modelName);
|
||||||
|
|
||||||
HttpRequest request = HttpRequest.newBuilder().uri(URI.create(url)).header("Accept", "application/json").header("Content-type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
|
HttpRequest request = HttpRequest.newBuilder().uri(URI.create(url)).header("Accept", "application/json").header("Content-type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
|
||||||
|
|
||||||
HttpClient client = HttpClient.newHttpClient();
|
HttpClient client = HttpClient.newHttpClient();
|
||||||
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
|
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
|
||||||
|
|
||||||
int statusCode = response.statusCode();
|
int statusCode = response.statusCode();
|
||||||
String responseBody = response.body();
|
String responseBody = response.body();
|
||||||
|
|
||||||
if (statusCode == 200) {
|
if (statusCode == 200) {
|
||||||
return Utils.getObjectMapper().readValue(responseBody, ModelDetail.class);
|
return Utils.getObjectMapper().readValue(responseBody, ModelDetail.class);
|
||||||
} else {
|
} else {
|
||||||
@ -200,8 +196,9 @@ public class OllamaAPI {
|
|||||||
* @param promptText the prompt/question text
|
* @param promptText the prompt/question text
|
||||||
* @return the response text from the model
|
* @return the response text from the model
|
||||||
*/
|
*/
|
||||||
public String ask(String ollamaModelType, String promptText) throws OllamaBaseException, IOException, InterruptedException {
|
public OllamaResult ask(String ollamaModelType, String promptText) throws OllamaBaseException, IOException, InterruptedException {
|
||||||
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(ollamaModelType, promptText);
|
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(ollamaModelType, promptText);
|
||||||
|
long startTime = System.currentTimeMillis();
|
||||||
HttpClient httpClient = HttpClient.newHttpClient();
|
HttpClient httpClient = HttpClient.newHttpClient();
|
||||||
URI uri = URI.create(this.host + "/api/generate");
|
URI uri = URI.create(this.host + "/api/generate");
|
||||||
HttpRequest request = HttpRequest.newBuilder(uri).POST(HttpRequest.BodyPublishers.ofString(Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))).header("Content-Type", "application/json").build();
|
HttpRequest request = HttpRequest.newBuilder(uri).POST(HttpRequest.BodyPublishers.ofString(Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))).header("Content-Type", "application/json").build();
|
||||||
@ -221,7 +218,8 @@ public class OllamaAPI {
|
|||||||
if (statusCode != 200) {
|
if (statusCode != 200) {
|
||||||
throw new OllamaBaseException(statusCode + " - " + responseBuffer);
|
throw new OllamaBaseException(statusCode + " - " + responseBuffer);
|
||||||
} else {
|
} else {
|
||||||
return responseBuffer.toString().trim();
|
long endTime = System.currentTimeMillis();
|
||||||
|
return new OllamaResult(responseBuffer.toString().trim(), endTime - startTime);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -23,6 +23,7 @@ public class OllamaAsyncResultCallback extends Thread {
|
|||||||
private final Queue<String> queue = new LinkedList<>();
|
private final Queue<String> queue = new LinkedList<>();
|
||||||
private String result;
|
private String result;
|
||||||
private boolean isDone;
|
private boolean isDone;
|
||||||
|
private long responseTime = 0;
|
||||||
|
|
||||||
public OllamaAsyncResultCallback(HttpClient client, URI uri, OllamaRequestModel ollamaRequestModel) {
|
public OllamaAsyncResultCallback(HttpClient client, URI uri, OllamaRequestModel ollamaRequestModel) {
|
||||||
this.client = client;
|
this.client = client;
|
||||||
@ -36,6 +37,7 @@ public class OllamaAsyncResultCallback extends Thread {
|
|||||||
@Override
|
@Override
|
||||||
public void run() {
|
public void run() {
|
||||||
try {
|
try {
|
||||||
|
long startTime = System.currentTimeMillis();
|
||||||
HttpRequest request = HttpRequest.newBuilder(uri).POST(HttpRequest.BodyPublishers.ofString(Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))).header("Content-Type", "application/json").build();
|
HttpRequest request = HttpRequest.newBuilder(uri).POST(HttpRequest.BodyPublishers.ofString(Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))).header("Content-Type", "application/json").build();
|
||||||
HttpResponse<InputStream> response = client.send(request, HttpResponse.BodyHandlers.ofInputStream());
|
HttpResponse<InputStream> response = client.send(request, HttpResponse.BodyHandlers.ofInputStream());
|
||||||
int statusCode = response.statusCode();
|
int statusCode = response.statusCode();
|
||||||
@ -55,6 +57,8 @@ public class OllamaAsyncResultCallback extends Thread {
|
|||||||
reader.close();
|
reader.close();
|
||||||
this.isDone = true;
|
this.isDone = true;
|
||||||
this.result = responseBuffer.toString();
|
this.result = responseBuffer.toString();
|
||||||
|
long endTime = System.currentTimeMillis();
|
||||||
|
responseTime = endTime - startTime;
|
||||||
}
|
}
|
||||||
if (statusCode != 200) {
|
if (statusCode != 200) {
|
||||||
throw new OllamaBaseException(statusCode + " - " + responseString);
|
throw new OllamaBaseException(statusCode + " - " + responseString);
|
||||||
@ -80,4 +84,12 @@ public class OllamaAsyncResultCallback extends Thread {
|
|||||||
public Queue<String> getStream() {
|
public Queue<String> getStream() {
|
||||||
return queue;
|
return queue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the response time in seconds.
|
||||||
|
* @return response time in seconds
|
||||||
|
*/
|
||||||
|
public long getResponseTime() {
|
||||||
|
return responseTime;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,20 @@
|
|||||||
|
package io.github.amithkoujalgi.ollama4j.core.models;
|
||||||
|
|
||||||
|
@SuppressWarnings("unused")
|
||||||
|
public class OllamaResult {
|
||||||
|
private String response;
|
||||||
|
private long responseTime = 0;
|
||||||
|
|
||||||
|
public OllamaResult(String response, long responseTime) {
|
||||||
|
this.response = response;
|
||||||
|
this.responseTime = responseTime;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getResponse() {
|
||||||
|
return response;
|
||||||
|
}
|
||||||
|
|
||||||
|
public long getResponseTime() {
|
||||||
|
return responseTime;
|
||||||
|
}
|
||||||
|
}
|
@ -4,6 +4,7 @@ import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
|
|||||||
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
|
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
|
||||||
import io.github.amithkoujalgi.ollama4j.core.models.ModelDetail;
|
import io.github.amithkoujalgi.ollama4j.core.models.ModelDetail;
|
||||||
import io.github.amithkoujalgi.ollama4j.core.models.OllamaAsyncResultCallback;
|
import io.github.amithkoujalgi.ollama4j.core.models.OllamaAsyncResultCallback;
|
||||||
|
import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
|
||||||
import io.github.amithkoujalgi.ollama4j.core.types.OllamaModelType;
|
import io.github.amithkoujalgi.ollama4j.core.types.OllamaModelType;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.mockito.Mockito;
|
import org.mockito.Mockito;
|
||||||
@ -100,7 +101,7 @@ public class TestMockedAPIs {
|
|||||||
String model = OllamaModelType.LLAMA2;
|
String model = OllamaModelType.LLAMA2;
|
||||||
String prompt = "some prompt text";
|
String prompt = "some prompt text";
|
||||||
try {
|
try {
|
||||||
when(ollamaAPI.ask(model, prompt)).thenReturn("");
|
when(ollamaAPI.ask(model, prompt)).thenReturn(new OllamaResult("", 0));
|
||||||
ollamaAPI.ask(model, prompt);
|
ollamaAPI.ask(model, prompt);
|
||||||
verify(ollamaAPI, times(1)).ask(model, prompt);
|
verify(ollamaAPI, times(1)).ask(model, prompt);
|
||||||
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user