mirror of
https://github.com/amithkoujalgi/ollama4j.git
synced 2025-05-15 03:47:13 +02:00
Add integration tests and enhance test configurations
Introduced integration tests for various API functionalities, ensuring comprehensive coverage. Updated test dependencies in `pom.xml` and added handling for unknown JSON properties in the `Model` class. Also included configuration to support running unit and integration tests in the CI workflow.
This commit is contained in:
parent
7ef859bba5
commit
e7f58d4e0d
3
.github/workflows/run-tests.yml
vendored
3
.github/workflows/run-tests.yml
vendored
@ -27,3 +27,6 @@ jobs:
|
||||
|
||||
- name: Run unit tests
|
||||
run: mvn clean test -Punit-tests
|
||||
|
||||
- name: Run integration tests
|
||||
run: mvn clean verify -Pintegration-tests
|
7
pom.xml
7
pom.xml
@ -216,6 +216,13 @@
|
||||
<version>20240205</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.testcontainers</groupId>
|
||||
<artifactId>ollama</artifactId>
|
||||
<version>1.20.2</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<distributionManagement>
|
||||
|
@ -323,26 +323,58 @@ public class OllamaAPI {
|
||||
public void pullModel(String modelName) throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
|
||||
String url = this.host + "/api/pull";
|
||||
String jsonData = new ModelRequest(modelName).toString();
|
||||
HttpRequest request = getRequestBuilderDefault(new URI(url)).POST(HttpRequest.BodyPublishers.ofString(jsonData)).header("Accept", "application/json").header("Content-type", "application/json").build();
|
||||
HttpRequest request = getRequestBuilderDefault(new URI(url))
|
||||
.POST(HttpRequest.BodyPublishers.ofString(jsonData))
|
||||
.header("Accept", "application/json")
|
||||
.header("Content-type", "application/json")
|
||||
.build();
|
||||
HttpClient client = HttpClient.newHttpClient();
|
||||
HttpResponse<InputStream> response = client.send(request, HttpResponse.BodyHandlers.ofInputStream());
|
||||
int statusCode = response.statusCode();
|
||||
InputStream responseBodyStream = response.body();
|
||||
String responseString = "";
|
||||
boolean success = false; // Flag to check the pull success.
|
||||
try (BufferedReader reader = new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
|
||||
String line;
|
||||
while ((line = reader.readLine()) != null) {
|
||||
ModelPullResponse modelPullResponse = Utils.getObjectMapper().readValue(line, ModelPullResponse.class);
|
||||
if (verbose) {
|
||||
logger.info(modelPullResponse.getStatus());
|
||||
if (modelPullResponse != null && modelPullResponse.getStatus() != null) {
|
||||
if (verbose) {
|
||||
logger.info(modelName + ": " + modelPullResponse.getStatus());
|
||||
}
|
||||
// Check if status is "success" and set success flag to true.
|
||||
if ("success".equalsIgnoreCase(modelPullResponse.getStatus())) {
|
||||
success = true;
|
||||
}
|
||||
} else {
|
||||
logger.error("Received null or invalid status for model pull.");
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!success) {
|
||||
logger.error("Model pull failed or returned invalid status.");
|
||||
throw new OllamaBaseException("Model pull failed or returned invalid status.");
|
||||
}
|
||||
if (statusCode != 200) {
|
||||
throw new OllamaBaseException(statusCode + " - " + responseString);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
public String getVersion() throws URISyntaxException, IOException, InterruptedException, OllamaBaseException {
|
||||
String url = this.host + "/api/version";
|
||||
HttpClient httpClient = HttpClient.newHttpClient();
|
||||
HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build();
|
||||
HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
|
||||
int statusCode = response.statusCode();
|
||||
String responseString = response.body();
|
||||
if (statusCode == 200) {
|
||||
return Utils.getObjectMapper().readValue(responseString, OllamaVersion.class).getVersion();
|
||||
} else {
|
||||
throw new OllamaBaseException(statusCode + " - " + responseString);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Pulls a model using the specified Ollama library model tag.
|
||||
* The model is identified by a name and a tag, which are combined into a single identifier
|
||||
|
@ -2,12 +2,14 @@ package io.github.ollama4j.models.response;
|
||||
|
||||
import java.time.OffsetDateTime;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import io.github.ollama4j.utils.Utils;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
@JsonIgnoreProperties(ignoreUnknown = true)
|
||||
public class Model {
|
||||
|
||||
private String name;
|
||||
|
@ -0,0 +1,10 @@
|
||||
package io.github.ollama4j.models.response;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class OllamaVersion {
|
||||
private String version;
|
||||
}
|
@ -0,0 +1,239 @@
|
||||
package io.github.ollama4j.integrationtests;
|
||||
|
||||
import io.github.ollama4j.OllamaAPI;
|
||||
import io.github.ollama4j.exceptions.OllamaBaseException;
|
||||
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
|
||||
import io.github.ollama4j.models.chat.OllamaChatRequest;
|
||||
import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
|
||||
import io.github.ollama4j.models.chat.OllamaChatResult;
|
||||
import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel;
|
||||
import io.github.ollama4j.models.response.LibraryModel;
|
||||
import io.github.ollama4j.models.response.Model;
|
||||
import io.github.ollama4j.models.response.ModelDetail;
|
||||
import io.github.ollama4j.models.response.OllamaResult;
|
||||
import io.github.ollama4j.utils.OptionsBuilder;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Order;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.testcontainers.ollama.OllamaContainer;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.net.ConnectException;
|
||||
import java.net.URISyntaxException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
@SuppressWarnings("HttpUrlsUsage")
|
||||
public class TestAPIsTest {
|
||||
|
||||
private static OllamaContainer ollama;
|
||||
private static OllamaAPI api;
|
||||
|
||||
@BeforeAll
|
||||
public static void setUp() {
|
||||
String version = "0.5.13";
|
||||
int internalPort = 11434;
|
||||
int mappedPort = 11435;
|
||||
ollama = new OllamaContainer("ollama/ollama:" + version);
|
||||
ollama.addExposedPort(internalPort);
|
||||
List<String> portBindings = new ArrayList<>();
|
||||
portBindings.add(mappedPort + ":" + internalPort);
|
||||
ollama.setPortBindings(portBindings);
|
||||
ollama.start();
|
||||
api = new OllamaAPI("http://" + ollama.getHost() + ":" + ollama.getMappedPort(internalPort));
|
||||
api.setRequestTimeoutSeconds(60);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(1)
|
||||
void testWrongEndpoint() {
|
||||
OllamaAPI ollamaAPI = new OllamaAPI("http://wrong-host:11434");
|
||||
assertThrows(ConnectException.class, ollamaAPI::listModels);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(1)
|
||||
public void testVersionAPI() throws URISyntaxException, IOException, OllamaBaseException, InterruptedException {
|
||||
String expectedVersion = ollama.getDockerImageName().split(":")[1];
|
||||
String actualVersion = api.getVersion();
|
||||
assertEquals(expectedVersion, actualVersion, "Version should match the Docker image version");
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(2)
|
||||
public void testListModelsAPI() throws URISyntaxException, IOException, OllamaBaseException, InterruptedException {
|
||||
// Fetch the list of models
|
||||
List<Model> models = api.listModels();
|
||||
// Assert that the models list is not null
|
||||
assertNotNull(models, "Models should not be null");
|
||||
// Assert that models list is either empty or contains more than 0 models
|
||||
assertTrue(models.size() >= 0, "Models list should be empty or contain elements");
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(2)
|
||||
void testListModelsFromLibrary() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
|
||||
List<LibraryModel> models = api.listModelsFromLibrary();
|
||||
assertNotNull(models);
|
||||
assertFalse(models.isEmpty());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(3)
|
||||
public void testPullModelAPI() throws URISyntaxException, IOException, OllamaBaseException, InterruptedException {
|
||||
api.pullModel("all-minilm");
|
||||
List<Model> models = api.listModels();
|
||||
assertNotNull(models, "Models should not be null");
|
||||
assertFalse(models.isEmpty(), "Models list should contain elements");
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(4)
|
||||
void testListModelDetails() throws IOException, OllamaBaseException, URISyntaxException, InterruptedException {
|
||||
String embeddingModelMinilm = "all-minilm";
|
||||
api.pullModel(embeddingModelMinilm);
|
||||
ModelDetail modelDetails = api.getModelDetails("all-minilm");
|
||||
assertNotNull(modelDetails);
|
||||
assertTrue(modelDetails.getModelFile().contains(embeddingModelMinilm));
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(5)
|
||||
public void testGenerateEmbeddings() throws Exception {
|
||||
String embeddingModelMinilm = "all-minilm";
|
||||
api.pullModel(embeddingModelMinilm);
|
||||
OllamaEmbedResponseModel embeddings = api.embed(embeddingModelMinilm, Arrays.asList("Why is the sky blue?", "Why is the grass green?"));
|
||||
assertNotNull(embeddings, "Embeddings should not be null");
|
||||
assertFalse(embeddings.getEmbeddings().isEmpty(), "Embeddings should not be empty");
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(6)
|
||||
void testAskModelWithDefaultOptions() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
|
||||
String chatModel = "qwen2.5:0.5b";
|
||||
api.pullModel(chatModel);
|
||||
OllamaResult result =
|
||||
api.generate(
|
||||
chatModel,
|
||||
"What is the capital of France? And what's France's connection with Mona Lisa?",
|
||||
false,
|
||||
new OptionsBuilder().build());
|
||||
assertNotNull(result);
|
||||
assertNotNull(result.getResponse());
|
||||
assertFalse(result.getResponse().isEmpty());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(7)
|
||||
void testAskModelWithDefaultOptionsStreamed() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
|
||||
String chatModel = "qwen2.5:0.5b";
|
||||
api.pullModel(chatModel);
|
||||
StringBuffer sb = new StringBuffer();
|
||||
OllamaResult result = api.generate(chatModel,
|
||||
"What is the capital of France? And what's France's connection with Mona Lisa?",
|
||||
false,
|
||||
new OptionsBuilder().build(), (s) -> {
|
||||
System.out.println(s);
|
||||
String substring = s.substring(sb.toString().length(), s.length());
|
||||
System.out.println(substring);
|
||||
sb.append(substring);
|
||||
});
|
||||
|
||||
assertNotNull(result);
|
||||
assertNotNull(result.getResponse());
|
||||
assertFalse(result.getResponse().isEmpty());
|
||||
assertEquals(sb.toString().trim(), result.getResponse().trim());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(8)
|
||||
void testAskModelWithOptions() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
|
||||
String chatModel = "qwen2.5:0.5b-instruct";
|
||||
api.pullModel(chatModel);
|
||||
|
||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(chatModel);
|
||||
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, "You are a helpful assistant who can generate random person's first and last names in the format [First name, Last name].")
|
||||
.build();
|
||||
requestModel = builder.withMessages(requestModel.getMessages())
|
||||
.withMessage(OllamaChatMessageRole.USER, "Give me a cool name")
|
||||
.withOptions(new OptionsBuilder().setTemperature(0.5f).build()).build();
|
||||
OllamaChatResult chatResult = api.chat(requestModel);
|
||||
|
||||
assertNotNull(chatResult);
|
||||
assertNotNull(chatResult.getResponseModel());
|
||||
assertFalse(chatResult.getResponseModel().getMessage().getContent().isEmpty());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(9)
|
||||
void testChatWithSystemPrompt() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
|
||||
String chatModel = "llama3.2:1b";
|
||||
api.pullModel(chatModel);
|
||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(chatModel);
|
||||
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM,
|
||||
"You are a silent bot that only says 'Ssh'. Do not say anything else under any circumstances!")
|
||||
.withMessage(OllamaChatMessageRole.USER,
|
||||
"What's something that's brown and sticky?")
|
||||
.withOptions(new OptionsBuilder().setTemperature(0.8f).build())
|
||||
.build();
|
||||
|
||||
OllamaChatResult chatResult = api.chat(requestModel);
|
||||
assertNotNull(chatResult);
|
||||
assertNotNull(chatResult.getResponseModel());
|
||||
assertNotNull(chatResult.getResponseModel().getMessage());
|
||||
assertFalse(chatResult.getResponseModel().getMessage().getContent().isBlank());
|
||||
assertTrue(chatResult.getResponseModel().getMessage().getContent().contains("Ssh"));
|
||||
assertEquals(3, chatResult.getChatHistory().size());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(10)
|
||||
public void testChat() throws Exception {
|
||||
String chatModel = "qwen2.5:0.5b";
|
||||
api.pullModel(chatModel);
|
||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(chatModel);
|
||||
|
||||
// Create the initial user question
|
||||
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What is the capital of France?")
|
||||
.build();
|
||||
|
||||
// Start conversation with model
|
||||
OllamaChatResult chatResult = api.chat(requestModel);
|
||||
|
||||
assertTrue(
|
||||
chatResult.getChatHistory().stream()
|
||||
.anyMatch(chat -> chat.getContent().contains("Paris")),
|
||||
"Expected chat history to contain 'Paris'"
|
||||
);
|
||||
|
||||
// Create the next user question: second largest city
|
||||
requestModel = builder.withMessages(chatResult.getChatHistory())
|
||||
.withMessage(OllamaChatMessageRole.USER, "And what is its official language?")
|
||||
.build();
|
||||
|
||||
// Continue conversation with model
|
||||
chatResult = api.chat(requestModel);
|
||||
|
||||
assertTrue(
|
||||
chatResult.getChatHistory().stream()
|
||||
.anyMatch(chat -> chat.getContent().contains("French")),
|
||||
"Expected chat history to contain 'French'"
|
||||
);
|
||||
|
||||
// Create the next user question: the third question
|
||||
requestModel = builder.withMessages(chatResult.getChatHistory())
|
||||
.withMessage(OllamaChatMessageRole.USER, "What is the largest river in France?")
|
||||
.build();
|
||||
|
||||
// Continue conversation with the model for the third question
|
||||
chatResult = api.chat(requestModel);
|
||||
|
||||
// verify the result
|
||||
assertNotNull(chatResult, "Chat result should not be null");
|
||||
assertTrue(chatResult.getChatHistory().size() > 2, "Chat history should contain more than two messages");
|
||||
assertTrue(chatResult.getChatHistory().get(chatResult.getChatHistory().size() - 1).getContent().contains("river"), "Response should be related to river");
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user