From e2a88b06c5a7ebbb55af3501d32b5ccdf67d9894 Mon Sep 17 00:00:00 2001 From: Amith Koujalgi Date: Thu, 14 Dec 2023 20:10:18 +0530 Subject: [PATCH] Updated tests --- .../ollama4j/core/OllamaAPI.java | 11 +++- .../models/OllamaAsyncResultCallback.java | 14 ++++- .../integrationtests/TestRealAPIs.java | 57 +++++++++++++++++-- .../ollama4j/unittests/TestMockedAPIs.java | 2 +- src/test/resources/test-config.properties | 1 + 5 files changed, 77 insertions(+), 8 deletions(-) create mode 100644 src/test/resources/test-config.properties diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java index 98ed1df..6ec5d05 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java @@ -13,6 +13,7 @@ import java.net.http.HttpClient; import java.net.http.HttpRequest; import java.net.http.HttpResponse; import java.nio.charset.StandardCharsets; +import java.time.Duration; import java.util.List; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -23,6 +24,7 @@ public class OllamaAPI { private static final Logger logger = LoggerFactory.getLogger(OllamaAPI.class); private final String host; + private long requestTimeoutSeconds = 3; private boolean verbose = false; /** @@ -61,6 +63,7 @@ public class OllamaAPI { .uri(new URI(url)) .header("Accept", "application/json") .header("Content-type", "application/json") + .timeout(Duration.ofSeconds(requestTimeoutSeconds)) .GET() .build(); HttpResponse response = @@ -92,6 +95,7 @@ public class OllamaAPI { .POST(HttpRequest.BodyPublishers.ofString(jsonData)) .header("Accept", "application/json") .header("Content-type", "application/json") + .timeout(Duration.ofSeconds(requestTimeoutSeconds)) .build(); HttpClient client = HttpClient.newHttpClient(); HttpResponse response = @@ -130,6 +134,7 @@ public class OllamaAPI { .uri(URI.create(url)) .header("Accept", "application/json") .header("Content-type", "application/json") + .timeout(Duration.ofSeconds(requestTimeoutSeconds)) .POST(HttpRequest.BodyPublishers.ofString(jsonData)) .build(); HttpClient client = HttpClient.newHttpClient(); @@ -160,6 +165,7 @@ public class OllamaAPI { .uri(URI.create(url)) .header("Accept", "application/json") .header("Content-Type", "application/json") + .timeout(Duration.ofSeconds(requestTimeoutSeconds)) .POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)) .build(); HttpClient client = HttpClient.newHttpClient(); @@ -196,6 +202,7 @@ public class OllamaAPI { .method("DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)) .header("Accept", "application/json") .header("Content-type", "application/json") + .timeout(Duration.ofSeconds(requestTimeoutSeconds)) .build(); HttpClient client = HttpClient.newHttpClient(); HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); @@ -226,6 +233,7 @@ public class OllamaAPI { .uri(URI.create(url)) .header("Accept", "application/json") .header("Content-type", "application/json") + .timeout(Duration.ofSeconds(requestTimeoutSeconds)) .POST(HttpRequest.BodyPublishers.ofString(jsonData)) .build(); HttpResponse response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); @@ -259,6 +267,7 @@ public class OllamaAPI { HttpRequest.BodyPublishers.ofString( Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))) .header("Content-Type", "application/json") + .timeout(Duration.ofSeconds(requestTimeoutSeconds)) .build(); HttpResponse response = httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream()); @@ -304,7 +313,7 @@ public class OllamaAPI { HttpClient httpClient = HttpClient.newHttpClient(); URI uri = URI.create(this.host + "/api/generate"); OllamaAsyncResultCallback ollamaAsyncResultCallback = - new OllamaAsyncResultCallback(httpClient, uri, ollamaRequestModel); + new OllamaAsyncResultCallback(httpClient, uri, ollamaRequestModel, requestTimeoutSeconds); ollamaAsyncResultCallback.start(); return ollamaAsyncResultCallback; } diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java index 2f8ea32..ebf5323 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java @@ -11,6 +11,7 @@ import java.net.http.HttpClient; import java.net.http.HttpRequest; import java.net.http.HttpResponse; import java.nio.charset.StandardCharsets; +import java.time.Duration; import java.util.LinkedList; import java.util.Queue; @@ -24,17 +25,23 @@ public class OllamaAsyncResultCallback extends Thread { private boolean isDone; private boolean succeeded; + private long requestTimeoutSeconds; + private int httpStatusCode; private long responseTime = 0; public OllamaAsyncResultCallback( - HttpClient client, URI uri, OllamaRequestModel ollamaRequestModel) { + HttpClient client, + URI uri, + OllamaRequestModel ollamaRequestModel, + long requestTimeoutSeconds) { this.client = client; this.ollamaRequestModel = ollamaRequestModel; this.uri = uri; this.isDone = false; this.result = ""; this.queue.add(""); + this.requestTimeoutSeconds = requestTimeoutSeconds; } @Override @@ -47,6 +54,7 @@ public class OllamaAsyncResultCallback extends Thread { HttpRequest.BodyPublishers.ofString( Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))) .header("Content-Type", "application/json") + .timeout(Duration.ofSeconds(requestTimeoutSeconds)) .build(); HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofInputStream()); @@ -140,4 +148,8 @@ public class OllamaAsyncResultCallback extends Thread { public long getResponseTime() { return responseTime; } + + public void setRequestTimeoutSeconds(long requestTimeoutSeconds) { + this.requestTimeoutSeconds = requestTimeoutSeconds; + } } diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java b/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java index 5abed5c..271f980 100644 --- a/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java +++ b/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java @@ -1,34 +1,66 @@ package io.github.amithkoujalgi.ollama4j.integrationtests; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.*; import io.github.amithkoujalgi.ollama4j.core.OllamaAPI; import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; +import io.github.amithkoujalgi.ollama4j.core.types.OllamaModelType; import java.io.IOException; +import java.io.InputStream; import java.net.ConnectException; import java.net.URISyntaxException; +import java.net.http.HttpConnectTimeoutException; +import java.util.Properties; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Order; import org.junit.jupiter.api.Test; class TestRealAPIs { - OllamaAPI ollamaAPI; + private Properties loadProperties() { + Properties properties = new Properties(); + try (InputStream input = + getClass().getClassLoader().getResourceAsStream("test-config.properties")) { + if (input == null) { + throw new RuntimeException("Sorry, unable to find test-config.properties"); + } + properties.load(input); + return properties; + } catch (IOException e) { + throw new RuntimeException("Error loading properties", e); + } + } + @BeforeEach void setUp() { - String ollamaHost = "http://localhost:11434"; - ollamaAPI = new OllamaAPI(ollamaHost); + Properties properties = loadProperties(); + ollamaAPI = new OllamaAPI(properties.getProperty("ollama.api.url")); } @Test + @Order(1) void testWrongEndpoint() { OllamaAPI ollamaAPI = new OllamaAPI("http://wrong-host:11434"); assertThrows(ConnectException.class, ollamaAPI::listModels); } @Test + @Order(1) + void testEndpointReachability() { + try { + assertNotNull(ollamaAPI.listModels()); + } catch (HttpConnectTimeoutException e) { + fail(e.getMessage()); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Test + @Order(2) void testListModels() { + testEndpointReachability(); try { assertNotNull(ollamaAPI.listModels()); ollamaAPI.listModels().forEach(System.out::println); @@ -36,4 +68,19 @@ class TestRealAPIs { throw new RuntimeException(e); } } + + @Test + @Order(2) + void testPullModel() { + testEndpointReachability(); + try { + ollamaAPI.pullModel(OllamaModelType.LLAMA2); + boolean found = + ollamaAPI.listModels().stream() + .anyMatch(model -> model.getModelName().equals(OllamaModelType.LLAMA2)); + assertTrue(found); + } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { + throw new RuntimeException(e); + } + } } diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/TestMockedAPIs.java b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/TestMockedAPIs.java index 35496ec..3b5fafc 100644 --- a/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/TestMockedAPIs.java +++ b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/TestMockedAPIs.java @@ -114,7 +114,7 @@ class TestMockedAPIs { String model = OllamaModelType.LLAMA2; String prompt = "some prompt text"; when(ollamaAPI.askAsync(model, prompt)) - .thenReturn(new OllamaAsyncResultCallback(null, null, null)); + .thenReturn(new OllamaAsyncResultCallback(null, null, null, 3)); ollamaAPI.askAsync(model, prompt); verify(ollamaAPI, times(1)).askAsync(model, prompt); } diff --git a/src/test/resources/test-config.properties b/src/test/resources/test-config.properties new file mode 100644 index 0000000..6c2a862 --- /dev/null +++ b/src/test/resources/test-config.properties @@ -0,0 +1 @@ +ollama.api.url=http://192.168.29.223:11434 \ No newline at end of file