Updated tests

This commit is contained in:
Amith Koujalgi 2023-12-14 20:10:18 +05:30
parent 548895bbce
commit e2a88b06c5
5 changed files with 77 additions and 8 deletions

View File

@ -13,6 +13,7 @@ import java.net.http.HttpClient;
import java.net.http.HttpRequest; import java.net.http.HttpRequest;
import java.net.http.HttpResponse; import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.List; import java.util.List;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@ -23,6 +24,7 @@ public class OllamaAPI {
private static final Logger logger = LoggerFactory.getLogger(OllamaAPI.class); private static final Logger logger = LoggerFactory.getLogger(OllamaAPI.class);
private final String host; private final String host;
private long requestTimeoutSeconds = 3;
private boolean verbose = false; private boolean verbose = false;
/** /**
@ -61,6 +63,7 @@ public class OllamaAPI {
.uri(new URI(url)) .uri(new URI(url))
.header("Accept", "application/json") .header("Accept", "application/json")
.header("Content-type", "application/json") .header("Content-type", "application/json")
.timeout(Duration.ofSeconds(requestTimeoutSeconds))
.GET() .GET()
.build(); .build();
HttpResponse<String> response = HttpResponse<String> response =
@ -92,6 +95,7 @@ public class OllamaAPI {
.POST(HttpRequest.BodyPublishers.ofString(jsonData)) .POST(HttpRequest.BodyPublishers.ofString(jsonData))
.header("Accept", "application/json") .header("Accept", "application/json")
.header("Content-type", "application/json") .header("Content-type", "application/json")
.timeout(Duration.ofSeconds(requestTimeoutSeconds))
.build(); .build();
HttpClient client = HttpClient.newHttpClient(); HttpClient client = HttpClient.newHttpClient();
HttpResponse<InputStream> response = HttpResponse<InputStream> response =
@ -130,6 +134,7 @@ public class OllamaAPI {
.uri(URI.create(url)) .uri(URI.create(url))
.header("Accept", "application/json") .header("Accept", "application/json")
.header("Content-type", "application/json") .header("Content-type", "application/json")
.timeout(Duration.ofSeconds(requestTimeoutSeconds))
.POST(HttpRequest.BodyPublishers.ofString(jsonData)) .POST(HttpRequest.BodyPublishers.ofString(jsonData))
.build(); .build();
HttpClient client = HttpClient.newHttpClient(); HttpClient client = HttpClient.newHttpClient();
@ -160,6 +165,7 @@ public class OllamaAPI {
.uri(URI.create(url)) .uri(URI.create(url))
.header("Accept", "application/json") .header("Accept", "application/json")
.header("Content-Type", "application/json") .header("Content-Type", "application/json")
.timeout(Duration.ofSeconds(requestTimeoutSeconds))
.POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)) .POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8))
.build(); .build();
HttpClient client = HttpClient.newHttpClient(); HttpClient client = HttpClient.newHttpClient();
@ -196,6 +202,7 @@ public class OllamaAPI {
.method("DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)) .method("DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8))
.header("Accept", "application/json") .header("Accept", "application/json")
.header("Content-type", "application/json") .header("Content-type", "application/json")
.timeout(Duration.ofSeconds(requestTimeoutSeconds))
.build(); .build();
HttpClient client = HttpClient.newHttpClient(); HttpClient client = HttpClient.newHttpClient();
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString()); HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
@ -226,6 +233,7 @@ public class OllamaAPI {
.uri(URI.create(url)) .uri(URI.create(url))
.header("Accept", "application/json") .header("Accept", "application/json")
.header("Content-type", "application/json") .header("Content-type", "application/json")
.timeout(Duration.ofSeconds(requestTimeoutSeconds))
.POST(HttpRequest.BodyPublishers.ofString(jsonData)) .POST(HttpRequest.BodyPublishers.ofString(jsonData))
.build(); .build();
HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
@ -259,6 +267,7 @@ public class OllamaAPI {
HttpRequest.BodyPublishers.ofString( HttpRequest.BodyPublishers.ofString(
Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))) Utils.getObjectMapper().writeValueAsString(ollamaRequestModel)))
.header("Content-Type", "application/json") .header("Content-Type", "application/json")
.timeout(Duration.ofSeconds(requestTimeoutSeconds))
.build(); .build();
HttpResponse<InputStream> response = HttpResponse<InputStream> response =
httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream()); httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
@ -304,7 +313,7 @@ public class OllamaAPI {
HttpClient httpClient = HttpClient.newHttpClient(); HttpClient httpClient = HttpClient.newHttpClient();
URI uri = URI.create(this.host + "/api/generate"); URI uri = URI.create(this.host + "/api/generate");
OllamaAsyncResultCallback ollamaAsyncResultCallback = OllamaAsyncResultCallback ollamaAsyncResultCallback =
new OllamaAsyncResultCallback(httpClient, uri, ollamaRequestModel); new OllamaAsyncResultCallback(httpClient, uri, ollamaRequestModel, requestTimeoutSeconds);
ollamaAsyncResultCallback.start(); ollamaAsyncResultCallback.start();
return ollamaAsyncResultCallback; return ollamaAsyncResultCallback;
} }

View File

@ -11,6 +11,7 @@ import java.net.http.HttpClient;
import java.net.http.HttpRequest; import java.net.http.HttpRequest;
import java.net.http.HttpResponse; import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.Queue; import java.util.Queue;
@ -24,17 +25,23 @@ public class OllamaAsyncResultCallback extends Thread {
private boolean isDone; private boolean isDone;
private boolean succeeded; private boolean succeeded;
private long requestTimeoutSeconds;
private int httpStatusCode; private int httpStatusCode;
private long responseTime = 0; private long responseTime = 0;
public OllamaAsyncResultCallback( public OllamaAsyncResultCallback(
HttpClient client, URI uri, OllamaRequestModel ollamaRequestModel) { HttpClient client,
URI uri,
OllamaRequestModel ollamaRequestModel,
long requestTimeoutSeconds) {
this.client = client; this.client = client;
this.ollamaRequestModel = ollamaRequestModel; this.ollamaRequestModel = ollamaRequestModel;
this.uri = uri; this.uri = uri;
this.isDone = false; this.isDone = false;
this.result = ""; this.result = "";
this.queue.add(""); this.queue.add("");
this.requestTimeoutSeconds = requestTimeoutSeconds;
} }
@Override @Override
@ -47,6 +54,7 @@ public class OllamaAsyncResultCallback extends Thread {
HttpRequest.BodyPublishers.ofString( HttpRequest.BodyPublishers.ofString(
Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))) Utils.getObjectMapper().writeValueAsString(ollamaRequestModel)))
.header("Content-Type", "application/json") .header("Content-Type", "application/json")
.timeout(Duration.ofSeconds(requestTimeoutSeconds))
.build(); .build();
HttpResponse<InputStream> response = HttpResponse<InputStream> response =
client.send(request, HttpResponse.BodyHandlers.ofInputStream()); client.send(request, HttpResponse.BodyHandlers.ofInputStream());
@ -140,4 +148,8 @@ public class OllamaAsyncResultCallback extends Thread {
public long getResponseTime() { public long getResponseTime() {
return responseTime; return responseTime;
} }
public void setRequestTimeoutSeconds(long requestTimeoutSeconds) {
this.requestTimeoutSeconds = requestTimeoutSeconds;
}
} }

View File

@ -1,34 +1,66 @@
package io.github.amithkoujalgi.ollama4j.integrationtests; package io.github.amithkoujalgi.ollama4j.integrationtests;
import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.*;
import static org.junit.jupiter.api.Assertions.assertThrows;
import io.github.amithkoujalgi.ollama4j.core.OllamaAPI; import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.types.OllamaModelType;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream;
import java.net.ConnectException; import java.net.ConnectException;
import java.net.URISyntaxException; import java.net.URISyntaxException;
import java.net.http.HttpConnectTimeoutException;
import java.util.Properties;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Order;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
class TestRealAPIs { class TestRealAPIs {
OllamaAPI ollamaAPI; OllamaAPI ollamaAPI;
private Properties loadProperties() {
Properties properties = new Properties();
try (InputStream input =
getClass().getClassLoader().getResourceAsStream("test-config.properties")) {
if (input == null) {
throw new RuntimeException("Sorry, unable to find test-config.properties");
}
properties.load(input);
return properties;
} catch (IOException e) {
throw new RuntimeException("Error loading properties", e);
}
}
@BeforeEach @BeforeEach
void setUp() { void setUp() {
String ollamaHost = "http://localhost:11434"; Properties properties = loadProperties();
ollamaAPI = new OllamaAPI(ollamaHost); ollamaAPI = new OllamaAPI(properties.getProperty("ollama.api.url"));
} }
@Test @Test
@Order(1)
void testWrongEndpoint() { void testWrongEndpoint() {
OllamaAPI ollamaAPI = new OllamaAPI("http://wrong-host:11434"); OllamaAPI ollamaAPI = new OllamaAPI("http://wrong-host:11434");
assertThrows(ConnectException.class, ollamaAPI::listModels); assertThrows(ConnectException.class, ollamaAPI::listModels);
} }
@Test @Test
@Order(1)
void testEndpointReachability() {
try {
assertNotNull(ollamaAPI.listModels());
} catch (HttpConnectTimeoutException e) {
fail(e.getMessage());
} catch (Exception e) {
throw new RuntimeException(e);
}
}
@Test
@Order(2)
void testListModels() { void testListModels() {
testEndpointReachability();
try { try {
assertNotNull(ollamaAPI.listModels()); assertNotNull(ollamaAPI.listModels());
ollamaAPI.listModels().forEach(System.out::println); ollamaAPI.listModels().forEach(System.out::println);
@ -36,4 +68,19 @@ class TestRealAPIs {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
} }
@Test
@Order(2)
void testPullModel() {
testEndpointReachability();
try {
ollamaAPI.pullModel(OllamaModelType.LLAMA2);
boolean found =
ollamaAPI.listModels().stream()
.anyMatch(model -> model.getModelName().equals(OllamaModelType.LLAMA2));
assertTrue(found);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e);
}
}
} }

View File

@ -114,7 +114,7 @@ class TestMockedAPIs {
String model = OllamaModelType.LLAMA2; String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text"; String prompt = "some prompt text";
when(ollamaAPI.askAsync(model, prompt)) when(ollamaAPI.askAsync(model, prompt))
.thenReturn(new OllamaAsyncResultCallback(null, null, null)); .thenReturn(new OllamaAsyncResultCallback(null, null, null, 3));
ollamaAPI.askAsync(model, prompt); ollamaAPI.askAsync(model, prompt);
verify(ollamaAPI, times(1)).askAsync(model, prompt); verify(ollamaAPI, times(1)).askAsync(model, prompt);
} }

View File

@ -0,0 +1 @@
ollama.api.url=http://192.168.29.223:11434