Add tool deregistration and update integration tests

Introduces a method to deregister all tools in OllamaAPI and ToolRegistry. Updates integration tests to use new models, refactors prompts and assertions, and removes the TimeOfDay class. The WithAuth test is now fully commented out.
This commit is contained in:
amithkoujalgi 2025-08-30 12:01:32 +05:30
parent 4a69df4476
commit 3d0b3eeb7f
No known key found for this signature in database
GPG Key ID: E29A37746AF94B70
4 changed files with 465 additions and 325 deletions

View File

@ -1183,6 +1183,17 @@ public class OllamaAPI {
} }
} }
/**
* Deregisters all tools from the tool registry.
* This method removes all registered tools, effectively clearing the registry.
*/
public void deregisterTools() {
toolRegistry.clear();
if (this.verbose) {
logger.debug("All tools have been deregistered.");
}
}
/** /**
* Registers tools based on the annotations found on the methods of the caller's * Registers tools based on the annotations found on the methods of the caller's
* class and its providers. * class and its providers.

View File

@ -19,4 +19,11 @@ public class ToolRegistry {
public Collection<Tools.ToolSpecification> getRegisteredSpecs() { public Collection<Tools.ToolSpecification> getRegisteredSpecs() {
return tools.values(); return tools.values();
} }
/**
* Removes all registered tools from the registry.
*/
public void clear() {
tools.clear();
}
} }

View File

@ -1,6 +1,5 @@
package io.github.ollama4j.integrationtests; package io.github.ollama4j.integrationtests;
import com.fasterxml.jackson.annotation.JsonProperty;
import io.github.ollama4j.OllamaAPI; import io.github.ollama4j.OllamaAPI;
import io.github.ollama4j.exceptions.OllamaBaseException; import io.github.ollama4j.exceptions.OllamaBaseException;
import io.github.ollama4j.exceptions.ToolInvocationException; import io.github.ollama4j.exceptions.ToolInvocationException;
@ -16,9 +15,6 @@ import io.github.ollama4j.tools.ToolFunction;
import io.github.ollama4j.tools.Tools; import io.github.ollama4j.tools.Tools;
import io.github.ollama4j.tools.annotations.OllamaToolService; import io.github.ollama4j.tools.annotations.OllamaToolService;
import io.github.ollama4j.utils.OptionsBuilder; import io.github.ollama4j.utils.OptionsBuilder;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.MethodOrderer.OrderAnnotation; import org.junit.jupiter.api.MethodOrderer.OrderAnnotation;
import org.junit.jupiter.api.Order; import org.junit.jupiter.api.Order;
@ -49,12 +45,11 @@ public class OllamaAPIIntegrationTest {
private static final String EMBEDDING_MODEL_MINILM = "all-minilm"; private static final String EMBEDDING_MODEL_MINILM = "all-minilm";
private static final String CHAT_MODEL_QWEN_SMALL = "qwen2.5:0.5b"; private static final String CHAT_MODEL_QWEN_SMALL = "qwen2.5:0.5b";
private static final String CHAT_MODEL_INSTRUCT = "qwen2.5:0.5b-instruct"; private static final String CHAT_MODEL_INSTRUCT = "qwen2.5:0.5b-instruct";
private static final String CHAT_MODEL_SYSTEM_PROMPT = "llama3.2:1b";
private static final String CHAT_MODEL_LLAMA3 = "llama3";
private static final String IMAGE_MODEL_LLAVA = "llava"; private static final String IMAGE_MODEL_LLAVA = "llava";
private static final String THINKING_MODEL_GPT_OSS = "gpt-oss:20b"; private static final String THINKING_MODEL_GPT_OSS = "gpt-oss:20b";
private static final String THINKING_MODEL_QWEN = "qwen3:0.6b"; // private static final String THINKING_MODEL_QWEN = "qwen3:0.6b";
private static final String GEMMA = "gemma3:1b";
private static final String GEMMA_SMALLEST = "gemma3:270m";
@BeforeAll @BeforeAll
public static void setUp() { public static void setUp() {
try { try {
@ -65,7 +60,8 @@ public class OllamaAPIIntegrationTest {
LOG.info("Using external Ollama host..."); LOG.info("Using external Ollama host...");
api = new OllamaAPI(ollamaHost); api = new OllamaAPI(ollamaHost);
} else { } else {
throw new RuntimeException("USE_EXTERNAL_OLLAMA_HOST is not set so, we will be using Testcontainers Ollama host for the tests now. If you would like to use an external host, please set the env var to USE_EXTERNAL_OLLAMA_HOST=true and set the env var OLLAMA_HOST=http://localhost:11435 or a different host/port."); throw new RuntimeException(
"USE_EXTERNAL_OLLAMA_HOST is not set so, we will be using Testcontainers Ollama host for the tests now. If you would like to use an external host, please set the env var to USE_EXTERNAL_OLLAMA_HOST=true and set the env var OLLAMA_HOST=http://localhost:11435 or a different host/port.");
} }
} catch (Exception e) { } catch (Exception e) {
String ollamaVersion = "0.6.1"; String ollamaVersion = "0.6.1";
@ -104,7 +100,8 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(2) @Order(2)
public void testListModelsAPI() throws URISyntaxException, IOException, OllamaBaseException, InterruptedException { public void testListModelsAPI()
throws URISyntaxException, IOException, OllamaBaseException, InterruptedException {
// Fetch the list of models // Fetch the list of models
List<Model> models = api.listModels(); List<Model> models = api.listModels();
// Assert that the models list is not null // Assert that the models list is not null
@ -115,7 +112,8 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(2) @Order(2)
void testListModelsFromLibrary() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { void testListModelsFromLibrary()
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
List<LibraryModel> models = api.listModelsFromLibrary(); List<LibraryModel> models = api.listModelsFromLibrary();
assertNotNull(models); assertNotNull(models);
assertFalse(models.isEmpty()); assertFalse(models.isEmpty());
@ -123,7 +121,8 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(3) @Order(3)
public void testPullModelAPI() throws URISyntaxException, IOException, OllamaBaseException, InterruptedException { public void testPullModelAPI()
throws URISyntaxException, IOException, OllamaBaseException, InterruptedException {
api.pullModel(EMBEDDING_MODEL_MINILM); api.pullModel(EMBEDDING_MODEL_MINILM);
List<Model> models = api.listModels(); List<Model> models = api.listModels();
assertNotNull(models, "Models should not be null"); assertNotNull(models, "Models should not be null");
@ -143,61 +142,52 @@ public class OllamaAPIIntegrationTest {
@Order(5) @Order(5)
public void testEmbeddings() throws Exception { public void testEmbeddings() throws Exception {
api.pullModel(EMBEDDING_MODEL_MINILM); api.pullModel(EMBEDDING_MODEL_MINILM);
OllamaEmbedResponseModel embeddings = api.embed(EMBEDDING_MODEL_MINILM, Arrays.asList("Why is the sky blue?", "Why is the grass green?")); OllamaEmbedResponseModel embeddings = api.embed(EMBEDDING_MODEL_MINILM,
Arrays.asList("Why is the sky blue?", "Why is the grass green?"));
assertNotNull(embeddings, "Embeddings should not be null"); assertNotNull(embeddings, "Embeddings should not be null");
assertFalse(embeddings.getEmbeddings().isEmpty(), "Embeddings should not be empty"); assertFalse(embeddings.getEmbeddings().isEmpty(), "Embeddings should not be empty");
} }
@Test @Test
@Order(6) @Order(6)
void testAskModelWithStructuredOutput() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { void testAskModelWithStructuredOutput()
api.pullModel(CHAT_MODEL_LLAMA3); throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
api.pullModel(GEMMA_SMALLEST);
int timeHour = 6; String prompt = "The sun is shining brightly and is directly overhead at the zenith, casting my shadow over my foot, so it must be noon.";
boolean isNightTime = false;
String prompt = "The Sun is shining, and its " + timeHour + ". Its daytime.";
Map<String, Object> format = new HashMap<>(); Map<String, Object> format = new HashMap<>();
format.put("type", "object"); format.put("type", "object");
format.put("properties", new HashMap<String, Object>() { format.put("properties", new HashMap<String, Object>() {
{ {
put("timeHour", new HashMap<String, Object>() { put("isNoon", new HashMap<String, Object>() {
{
put("type", "integer");
}
});
put("isNightTime", new HashMap<String, Object>() {
{ {
put("type", "boolean"); put("type", "boolean");
} }
}); });
} }
}); });
format.put("required", Arrays.asList("timeHour", "isNightTime")); format.put("required", List.of("isNoon"));
OllamaResult result = api.generate(CHAT_MODEL_LLAMA3, prompt, format); OllamaResult result = api.generate(GEMMA_SMALLEST, prompt, format);
assertNotNull(result); assertNotNull(result);
assertNotNull(result.getResponse()); assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty()); assertFalse(result.getResponse().isEmpty());
assertEquals(timeHour, result.getStructuredResponse().get("timeHour")); assertEquals(true, result.getStructuredResponse().get("isNoon"));
assertEquals(isNightTime, result.getStructuredResponse().get("isNightTime"));
TimeOfDay timeOfDay = result.as(TimeOfDay.class);
assertEquals(timeHour, timeOfDay.getTimeHour());
assertEquals(isNightTime, timeOfDay.isNightTime());
} }
@Test @Test
@Order(6) @Order(6)
void testAskModelWithDefaultOptions() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { void testAskModelWithDefaultOptions()
api.pullModel(CHAT_MODEL_QWEN_SMALL); throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
api.pullModel(GEMMA);
boolean raw = false; boolean raw = false;
boolean thinking = false; boolean thinking = false;
OllamaResult result = api.generate(CHAT_MODEL_QWEN_SMALL, "What is the capital of France? And what's France's connection with Mona Lisa?", raw, thinking, new OptionsBuilder().build()); OllamaResult result = api.generate(GEMMA,
"What is the capital of France? And what's France's connection with Mona Lisa?", raw,
thinking, new OptionsBuilder().build());
assertNotNull(result); assertNotNull(result);
assertNotNull(result.getResponse()); assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty()); assertFalse(result.getResponse().isEmpty());
@ -205,14 +195,17 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(7) @Order(7)
void testAskModelWithDefaultOptionsStreamed() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { void testAskModelWithDefaultOptionsStreamed()
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
api.pullModel(CHAT_MODEL_QWEN_SMALL); api.pullModel(CHAT_MODEL_QWEN_SMALL);
boolean raw = false; boolean raw = false;
boolean thinking = false; boolean thinking = false;
StringBuffer sb = new StringBuffer(); StringBuffer sb = new StringBuffer();
OllamaResult result = api.generate(CHAT_MODEL_QWEN_SMALL, "What is the capital of France? And what's France's connection with Mona Lisa?", raw, thinking, new OptionsBuilder().build(), (s) -> { OllamaResult result = api.generate(CHAT_MODEL_QWEN_SMALL,
"What is the capital of France? And what's France's connection with Mona Lisa?", raw,
thinking, new OptionsBuilder().build(), (s) -> {
LOG.info(s); LOG.info(s);
String substring = s.substring(sb.toString().length(), s.length()); String substring = s.substring(sb.toString().length());
LOG.info(substring); LOG.info(substring);
sb.append(substring); sb.append(substring);
}); });
@ -225,12 +218,17 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(8) @Order(8)
void testAskModelWithOptions() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException { void testAskModelWithOptions() throws OllamaBaseException, IOException, URISyntaxException,
InterruptedException, ToolInvocationException {
api.pullModel(CHAT_MODEL_INSTRUCT); api.pullModel(CHAT_MODEL_INSTRUCT);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_INSTRUCT); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_INSTRUCT);
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, "You are a helpful assistant who can generate random person's first and last names in the format [First name, Last name].").build(); OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM,
requestModel = builder.withMessages(requestModel.getMessages()).withMessage(OllamaChatMessageRole.USER, "Give me a cool name").withOptions(new OptionsBuilder().setTemperature(0.5f).build()).build(); "You are a helpful assistant who can generate random person's first and last names in the format [First name, Last name].")
.build();
requestModel = builder.withMessages(requestModel.getMessages())
.withMessage(OllamaChatMessageRole.USER, "Give me a cool name")
.withOptions(new OptionsBuilder().setTemperature(0.5f).build()).build();
OllamaChatResult chatResult = api.chat(requestModel); OllamaChatResult chatResult = api.chat(requestModel);
assertNotNull(chatResult); assertNotNull(chatResult);
@ -240,10 +238,14 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(9) @Order(9)
void testChatWithSystemPrompt() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException { void testChatWithSystemPrompt() throws OllamaBaseException, IOException, URISyntaxException,
api.pullModel(CHAT_MODEL_LLAMA3); InterruptedException, ToolInvocationException {
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_LLAMA3); api.pullModel(THINKING_MODEL_GPT_OSS);
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, "You are a silent bot that only says 'Shush'. Do not say anything else under any circumstances!").withMessage(OllamaChatMessageRole.USER, "What's something that's brown and sticky?").withOptions(new OptionsBuilder().setTemperature(0.8f).build()).build(); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(THINKING_MODEL_GPT_OSS);
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM,
"You are a silent bot that only says 'Shush'. Do not say anything else under any circumstances!")
.withMessage(OllamaChatMessageRole.USER, "What's something that's brown and sticky?")
.withOptions(new OptionsBuilder().setTemperature(0.8f).build()).build();
OllamaChatResult chatResult = api.chat(requestModel); OllamaChatResult chatResult = api.chat(requestModel);
assertNotNull(chatResult); assertNotNull(chatResult);
@ -257,44 +259,56 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(10) @Order(10)
public void testChat() throws Exception { public void testChat() throws Exception {
api.pullModel(CHAT_MODEL_LLAMA3); api.pullModel(THINKING_MODEL_GPT_OSS);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_LLAMA3); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(THINKING_MODEL_GPT_OSS);
// Create the initial user question // Create the initial user question
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What is 1+1? Answer only in numbers.").build(); OllamaChatRequest requestModel = builder
.withMessage(OllamaChatMessageRole.USER, "What is 1+1? Answer only in numbers.")
.build();
// Start conversation with model // Start conversation with model
OllamaChatResult chatResult = api.chat(requestModel); OllamaChatResult chatResult = api.chat(requestModel);
assertTrue(chatResult.getChatHistory().stream().anyMatch(chat -> chat.getContent().contains("2")), "Expected chat history to contain '2'"); assertTrue(chatResult.getChatHistory().stream().anyMatch(chat -> chat.getContent().contains("2")),
"Expected chat history to contain '2'");
// Create the next user question: second largest city // Create the next user question: second largest city
requestModel = builder.withMessages(chatResult.getChatHistory()).withMessage(OllamaChatMessageRole.USER, "And what is its squared value?").build(); requestModel = builder.withMessages(chatResult.getChatHistory())
.withMessage(OllamaChatMessageRole.USER, "And what is its squared value?").build();
// Continue conversation with model // Continue conversation with model
chatResult = api.chat(requestModel); chatResult = api.chat(requestModel);
assertTrue(chatResult.getChatHistory().stream().anyMatch(chat -> chat.getContent().contains("4")), "Expected chat history to contain '4'"); assertTrue(chatResult.getChatHistory().stream().anyMatch(chat -> chat.getContent().contains("4")),
"Expected chat history to contain '4'");
// Create the next user question: the third question // Create the next user question: the third question
requestModel = builder.withMessages(chatResult.getChatHistory()).withMessage(OllamaChatMessageRole.USER, "What is the largest value between 2, 4 and 6?").build(); requestModel = builder.withMessages(chatResult.getChatHistory()).withMessage(OllamaChatMessageRole.USER,
"What is the largest value between 2, 4 and 6?").build();
// Continue conversation with the model for the third question // Continue conversation with the model for the third question
chatResult = api.chat(requestModel); chatResult = api.chat(requestModel);
// verify the result // verify the result
assertNotNull(chatResult, "Chat result should not be null"); assertNotNull(chatResult, "Chat result should not be null");
assertTrue(chatResult.getChatHistory().size() > 2, "Chat history should contain more than two messages"); assertTrue(chatResult.getChatHistory().size() > 2,
assertTrue(chatResult.getChatHistory().get(chatResult.getChatHistory().size() - 1).getContent().contains("6"), "Response should contain '6'"); "Chat history should contain more than two messages");
assertTrue(chatResult.getChatHistory().get(chatResult.getChatHistory().size() - 1).getContent()
.contains("6"), "Response should contain '6'");
} }
@Test @Test
@Order(10) @Order(10)
void testChatWithImageFromURL() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException, ToolInvocationException { void testChatWithImageFromURL() throws OllamaBaseException, IOException, InterruptedException,
URISyntaxException, ToolInvocationException {
api.pullModel(IMAGE_MODEL_LLAVA); api.pullModel(IMAGE_MODEL_LLAVA);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(IMAGE_MODEL_LLAVA); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(IMAGE_MODEL_LLAVA);
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?", Collections.emptyList(), "https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg").build(); OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER,
"What's in the picture?", Collections.emptyList(),
"https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg")
.build();
api.registerAnnotatedTools(new OllamaAPIIntegrationTest()); api.registerAnnotatedTools(new OllamaAPIIntegrationTest());
OllamaChatResult chatResult = api.chat(requestModel); OllamaChatResult chatResult = api.chat(requestModel);
@ -303,17 +317,21 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(10) @Order(10)
void testChatWithImageFromFileWithHistoryRecognition() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException { void testChatWithImageFromFileWithHistoryRecognition() throws OllamaBaseException, IOException,
URISyntaxException, InterruptedException, ToolInvocationException {
api.pullModel(IMAGE_MODEL_LLAVA); api.pullModel(IMAGE_MODEL_LLAVA);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(IMAGE_MODEL_LLAVA); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(IMAGE_MODEL_LLAVA);
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?", Collections.emptyList(), List.of(getImageFileFromClasspath("emoji-smile.jpeg"))).build(); OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER,
"What's in the picture?", Collections.emptyList(),
List.of(getImageFileFromClasspath("emoji-smile.jpeg"))).build();
OllamaChatResult chatResult = api.chat(requestModel); OllamaChatResult chatResult = api.chat(requestModel);
assertNotNull(chatResult); assertNotNull(chatResult);
assertNotNull(chatResult.getResponseModel()); assertNotNull(chatResult.getResponseModel());
builder.reset(); builder.reset();
requestModel = builder.withMessages(chatResult.getChatHistory()).withMessage(OllamaChatMessageRole.USER, "What's the color?").build(); requestModel = builder.withMessages(chatResult.getChatHistory())
.withMessage(OllamaChatMessageRole.USER, "What's the color?").build();
chatResult = api.chat(requestModel); chatResult = api.chat(requestModel);
assertNotNull(chatResult); assertNotNull(chatResult);
@ -322,24 +340,67 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(11) @Order(11)
void testChatWithExplicitToolDefinition() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException { void testChatWithExplicitToolDefinition() throws OllamaBaseException, IOException, URISyntaxException,
InterruptedException, ToolInvocationException {
api.pullModel(CHAT_MODEL_QWEN_SMALL); api.pullModel(CHAT_MODEL_QWEN_SMALL);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_QWEN_SMALL); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_QWEN_SMALL);
final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder().functionName("get-employee-details").functionDescription("Get employee details from the database").toolPrompt(Tools.PromptFuncDefinition.builder().type("function").function(Tools.PromptFuncDefinition.PromptFuncSpec.builder().name("get-employee-details").description("Get employee details from the database").parameters(Tools.PromptFuncDefinition.Parameters.builder().type("object").properties(new Tools.PropsBuilder().withProperty("employee-name", Tools.PromptFuncDefinition.Property.builder().type("string").description("The name of the employee, e.g. John Doe").required(true).build()).withProperty("employee-address", Tools.PromptFuncDefinition.Property.builder().type("string").description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India").required(true).build()).withProperty("employee-phone", Tools.PromptFuncDefinition.Property.builder().type("string").description("The phone number of the employee. Always return a random value. e.g. 9911002233").required(true).build()).build()).required(List.of("employee-name")).build()).build()).build()).toolFunction(arguments -> { final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder()
.functionName("get-employee-details")
.functionDescription("Get employee details from the database")
.toolPrompt(Tools.PromptFuncDefinition.builder().type("function")
.function(Tools.PromptFuncDefinition.PromptFuncSpec.builder()
.name("get-employee-details")
.description("Get employee details from the database")
.parameters(Tools.PromptFuncDefinition.Parameters
.builder().type("object")
.properties(new Tools.PropsBuilder()
.withProperty("employee-name",
Tools.PromptFuncDefinition.Property
.builder()
.type("string")
.description("The name of the employee, e.g. John Doe")
.required(true)
.build())
.withProperty("employee-address",
Tools.PromptFuncDefinition.Property
.builder()
.type("string")
.description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India")
.required(true)
.build())
.withProperty("employee-phone",
Tools.PromptFuncDefinition.Property
.builder()
.type("string")
.description("The phone number of the employee. Always return a random value. e.g. 9911002233")
.required(true)
.build())
.build())
.required(List.of("employee-name"))
.build())
.build())
.build())
.toolFunction(arguments -> {
// perform DB operations here // perform DB operations here
return String.format("Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", UUID.randomUUID(), arguments.get("employee-name"), arguments.get("employee-address"), arguments.get("employee-phone")); return String.format(
"Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}",
UUID.randomUUID(), arguments.get("employee-name"),
arguments.get("employee-address"),
arguments.get("employee-phone"));
}).build(); }).build();
api.registerTool(databaseQueryToolSpecification); api.registerTool(databaseQueryToolSpecification);
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "Give me the ID of the employee named 'Rahul Kumar'?").build(); OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER,
"Give me the ID of the employee named 'Rahul Kumar'?").build();
OllamaChatResult chatResult = api.chat(requestModel); OllamaChatResult chatResult = api.chat(requestModel);
assertNotNull(chatResult); assertNotNull(chatResult);
assertNotNull(chatResult.getResponseModel()); assertNotNull(chatResult.getResponseModel());
assertNotNull(chatResult.getResponseModel().getMessage()); assertNotNull(chatResult.getResponseModel().getMessage());
assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(), chatResult.getResponseModel().getMessage().getRole().getRoleName()); assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(),
chatResult.getResponseModel().getMessage().getRole().getRoleName());
List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls(); List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls();
assertEquals(1, toolCalls.size()); assertEquals(1, toolCalls.size());
OllamaToolCallsFunction function = toolCalls.get(0).getFunction(); OllamaToolCallsFunction function = toolCalls.get(0).getFunction();
@ -355,19 +416,24 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(12) @Order(12)
void testChatWithAnnotatedToolsAndSingleParam() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException, ToolInvocationException { void testChatWithAnnotatedToolsAndSingleParam() throws OllamaBaseException, IOException, InterruptedException,
URISyntaxException, ToolInvocationException {
api.pullModel(CHAT_MODEL_QWEN_SMALL); api.pullModel(CHAT_MODEL_QWEN_SMALL);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_QWEN_SMALL); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_QWEN_SMALL);
api.registerAnnotatedTools(); api.registerAnnotatedTools();
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "Compute the most important constant in the world using 5 digits").build(); OllamaChatRequest requestModel = builder
.withMessage(OllamaChatMessageRole.USER,
"Compute the most important constant in the world using 5 digits")
.build();
OllamaChatResult chatResult = api.chat(requestModel); OllamaChatResult chatResult = api.chat(requestModel);
assertNotNull(chatResult); assertNotNull(chatResult);
assertNotNull(chatResult.getResponseModel()); assertNotNull(chatResult.getResponseModel());
assertNotNull(chatResult.getResponseModel().getMessage()); assertNotNull(chatResult.getResponseModel().getMessage());
assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(), chatResult.getResponseModel().getMessage().getRole().getRoleName()); assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(),
chatResult.getResponseModel().getMessage().getRole().getRoleName());
List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls(); List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls();
assertEquals(1, toolCalls.size()); assertEquals(1, toolCalls.size());
OllamaToolCallsFunction function = toolCalls.get(0).getFunction(); OllamaToolCallsFunction function = toolCalls.get(0).getFunction();
@ -383,19 +449,24 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(13) @Order(13)
void testChatWithAnnotatedToolsAndMultipleParams() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException { void testChatWithAnnotatedToolsAndMultipleParams() throws OllamaBaseException, IOException, URISyntaxException,
api.pullModel(CHAT_MODEL_QWEN_SMALL); InterruptedException, ToolInvocationException {
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_QWEN_SMALL); api.pullModel(THINKING_MODEL_GPT_OSS);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(THINKING_MODEL_GPT_OSS);
api.registerAnnotatedTools(new AnnotatedTool()); api.registerAnnotatedTools(new AnnotatedTool());
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "Greet Pedro with a lot of hearts and respond to me, " + "and state how many emojis have been in your greeting").build(); OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER,
"Greet Pedro with a lot of hearts and respond to me, "
+ "and state how many emojis have been in your greeting")
.build();
OllamaChatResult chatResult = api.chat(requestModel); OllamaChatResult chatResult = api.chat(requestModel);
assertNotNull(chatResult); assertNotNull(chatResult);
assertNotNull(chatResult.getResponseModel()); assertNotNull(chatResult.getResponseModel());
assertNotNull(chatResult.getResponseModel().getMessage()); assertNotNull(chatResult.getResponseModel().getMessage());
assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(), chatResult.getResponseModel().getMessage().getRole().getRoleName()); assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(),
chatResult.getResponseModel().getMessage().getRole().getRoleName());
List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls(); List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls();
assertEquals(1, toolCalls.size()); assertEquals(1, toolCalls.size());
OllamaToolCallsFunction function = toolCalls.get(0).getFunction(); OllamaToolCallsFunction function = toolCalls.get(0).getFunction();
@ -414,20 +485,62 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(14) @Order(14)
void testChatWithToolsAndStream() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException { void testChatWithToolsAndStream() throws OllamaBaseException, IOException, URISyntaxException,
InterruptedException, ToolInvocationException {
api.pullModel(CHAT_MODEL_QWEN_SMALL); api.pullModel(CHAT_MODEL_QWEN_SMALL);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_QWEN_SMALL); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_QWEN_SMALL);
final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder().functionName("get-employee-details").functionDescription("Get employee details from the database").toolPrompt(Tools.PromptFuncDefinition.builder().type("function").function(Tools.PromptFuncDefinition.PromptFuncSpec.builder().name("get-employee-details").description("Get employee details from the database").parameters(Tools.PromptFuncDefinition.Parameters.builder().type("object").properties(new Tools.PropsBuilder().withProperty("employee-name", Tools.PromptFuncDefinition.Property.builder().type("string").description("The name of the employee, e.g. John Doe").required(true).build()).withProperty("employee-address", Tools.PromptFuncDefinition.Property.builder().type("string").description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India").required(true).build()).withProperty("employee-phone", Tools.PromptFuncDefinition.Property.builder().type("string").description("The phone number of the employee. Always return a random value. e.g. 9911002233").required(true).build()).build()).required(List.of("employee-name")).build()).build()).build()).toolFunction(new ToolFunction() { final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder()
.functionName("get-employee-details")
.functionDescription("Get employee details from the database")
.toolPrompt(Tools.PromptFuncDefinition.builder().type("function")
.function(Tools.PromptFuncDefinition.PromptFuncSpec.builder()
.name("get-employee-details")
.description("Get employee details from the database")
.parameters(Tools.PromptFuncDefinition.Parameters
.builder().type("object")
.properties(new Tools.PropsBuilder()
.withProperty("employee-name",
Tools.PromptFuncDefinition.Property
.builder()
.type("string")
.description("The name of the employee, e.g. John Doe")
.required(true)
.build())
.withProperty("employee-address",
Tools.PromptFuncDefinition.Property
.builder()
.type("string")
.description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India")
.required(true)
.build())
.withProperty("employee-phone",
Tools.PromptFuncDefinition.Property
.builder()
.type("string")
.description("The phone number of the employee. Always return a random value. e.g. 9911002233")
.required(true)
.build())
.build())
.required(List.of("employee-name"))
.build())
.build())
.build())
.toolFunction(new ToolFunction() {
@Override @Override
public Object apply(Map<String, Object> arguments) { public Object apply(Map<String, Object> arguments) {
// perform DB operations here // perform DB operations here
return String.format("Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", UUID.randomUUID(), arguments.get("employee-name"), arguments.get("employee-address"), arguments.get("employee-phone")); return String.format(
"Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}",
UUID.randomUUID(), arguments.get("employee-name"),
arguments.get("employee-address"),
arguments.get("employee-phone"));
} }
}).build(); }).build();
api.registerTool(databaseQueryToolSpecification); api.registerTool(databaseQueryToolSpecification);
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "Give me the ID of the employee named 'Rahul Kumar'?").build(); OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER,
"Give me the ID of the employee named 'Rahul Kumar'?").build();
StringBuffer sb = new StringBuffer(); StringBuffer sb = new StringBuffer();
@ -446,10 +559,15 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(15) @Order(15)
void testChatWithStream() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException { void testChatWithStream() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException,
api.pullModel(THINKING_MODEL_QWEN); ToolInvocationException {
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(THINKING_MODEL_QWEN); api.deregisterTools();
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What is the capital of France? And what's France's connection with Mona Lisa?").build(); api.pullModel(GEMMA_SMALLEST);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(GEMMA_SMALLEST);
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER,
"What is the capital of France? And what's France's connection with Mona Lisa?")
.build();
requestModel.setThink(false);
StringBuffer sb = new StringBuffer(); StringBuffer sb = new StringBuffer();
OllamaChatResult chatResult = api.chat(requestModel, (s) -> { OllamaChatResult chatResult = api.chat(requestModel, (s) -> {
@ -466,11 +584,13 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(15) @Order(15)
void testChatWithThinkingAndStream() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException { void testChatWithThinkingAndStream() throws OllamaBaseException, IOException, URISyntaxException,
api.pullModel(THINKING_MODEL_QWEN); InterruptedException, ToolInvocationException {
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(THINKING_MODEL_QWEN); api.pullModel(THINKING_MODEL_GPT_OSS);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(THINKING_MODEL_GPT_OSS);
OllamaChatRequest requestModel = builder OllamaChatRequest requestModel = builder
.withMessage(OllamaChatMessageRole.USER, "What is the capital of France? And what's France's connection with Mona Lisa?") .withMessage(OllamaChatMessageRole.USER,
"What is the capital of France? And what's France's connection with Mona Lisa?")
.withThinking(true) .withThinking(true)
.withKeepAlive("0m") .withKeepAlive("0m")
.build(); .build();
@ -485,15 +605,19 @@ public class OllamaAPIIntegrationTest {
assertNotNull(chatResult.getResponseModel()); assertNotNull(chatResult.getResponseModel());
assertNotNull(chatResult.getResponseModel().getMessage()); assertNotNull(chatResult.getResponseModel().getMessage());
assertNotNull(chatResult.getResponseModel().getMessage().getContent()); assertNotNull(chatResult.getResponseModel().getMessage().getContent());
assertEquals(sb.toString(), chatResult.getResponseModel().getMessage().getThinking() + chatResult.getResponseModel().getMessage().getContent()); assertEquals(sb.toString(), chatResult.getResponseModel().getMessage().getThinking()
+ chatResult.getResponseModel().getMessage().getContent());
} }
@Test @Test
@Order(17) @Order(17)
void testAskModelWithOptionsAndImageURLs() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { void testAskModelWithOptionsAndImageURLs()
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
api.pullModel(IMAGE_MODEL_LLAVA); api.pullModel(IMAGE_MODEL_LLAVA);
OllamaResult result = api.generateWithImageURLs(IMAGE_MODEL_LLAVA, "What is in this image?", List.of("https://i.pinimg.com/736x/f9/4e/cb/f94ecba040696a3a20b484d2e15159ec.jpg"), new OptionsBuilder().build()); OllamaResult result = api.generateWithImageURLs(IMAGE_MODEL_LLAVA, "What is in this image?",
List.of("https://i.pinimg.com/736x/f9/4e/cb/f94ecba040696a3a20b484d2e15159ec.jpg"),
new OptionsBuilder().build());
assertNotNull(result); assertNotNull(result);
assertNotNull(result.getResponse()); assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty()); assertFalse(result.getResponse().isEmpty());
@ -501,11 +625,13 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(18) @Order(18)
void testAskModelWithOptionsAndImageFiles() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { void testAskModelWithOptionsAndImageFiles()
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
api.pullModel(IMAGE_MODEL_LLAVA); api.pullModel(IMAGE_MODEL_LLAVA);
File imageFile = getImageFileFromClasspath("emoji-smile.jpeg"); File imageFile = getImageFileFromClasspath("emoji-smile.jpeg");
try { try {
OllamaResult result = api.generateWithImageFiles(IMAGE_MODEL_LLAVA, "What is in this image?", List.of(imageFile), new OptionsBuilder().build()); OllamaResult result = api.generateWithImageFiles(IMAGE_MODEL_LLAVA, "What is in this image?",
List.of(imageFile), new OptionsBuilder().build());
assertNotNull(result); assertNotNull(result);
assertNotNull(result.getResponse()); assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty()); assertFalse(result.getResponse().isEmpty());
@ -516,14 +642,16 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(20) @Order(20)
void testAskModelWithOptionsAndImageFilesStreamed() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { void testAskModelWithOptionsAndImageFilesStreamed()
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
api.pullModel(IMAGE_MODEL_LLAVA); api.pullModel(IMAGE_MODEL_LLAVA);
File imageFile = getImageFileFromClasspath("emoji-smile.jpeg"); File imageFile = getImageFileFromClasspath("emoji-smile.jpeg");
StringBuffer sb = new StringBuffer(); StringBuffer sb = new StringBuffer();
OllamaResult result = api.generateWithImageFiles(IMAGE_MODEL_LLAVA, "What is in this image?", List.of(imageFile), new OptionsBuilder().build(), (s) -> { OllamaResult result = api.generateWithImageFiles(IMAGE_MODEL_LLAVA, "What is in this image?",
List.of(imageFile), new OptionsBuilder().build(), (s) -> {
LOG.info(s); LOG.info(s);
String substring = s.substring(sb.toString().length()); String substring = s.substring(sb.toString().length());
LOG.info(substring); LOG.info(substring);
@ -537,13 +665,15 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(20) @Order(20)
void testGenerateWithThinking() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { void testGenerateWithThinking()
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
api.pullModel(THINKING_MODEL_GPT_OSS); api.pullModel(THINKING_MODEL_GPT_OSS);
boolean raw = false; boolean raw = false;
boolean thinking = true; boolean thinking = true;
OllamaResult result = api.generate(THINKING_MODEL_GPT_OSS, "Who are you?", raw, thinking, new OptionsBuilder().build(), null); OllamaResult result = api.generate(THINKING_MODEL_GPT_OSS, "Who are you?", raw, thinking,
new OptionsBuilder().build(), null);
assertNotNull(result); assertNotNull(result);
assertNotNull(result.getResponse()); assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty()); assertFalse(result.getResponse().isEmpty());
@ -553,14 +683,16 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(20) @Order(20)
void testGenerateWithThinkingAndStreamHandler() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { void testGenerateWithThinkingAndStreamHandler()
api.pullModel(THINKING_MODEL_QWEN); throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
api.pullModel(THINKING_MODEL_GPT_OSS);
boolean raw = false; boolean raw = false;
boolean thinking = true; boolean thinking = true;
StringBuffer sb = new StringBuffer(); StringBuffer sb = new StringBuffer();
OllamaResult result = api.generate(THINKING_MODEL_QWEN, "Who are you?", raw, thinking, new OptionsBuilder().build(), (s) -> { OllamaResult result = api.generate(THINKING_MODEL_GPT_OSS, "Who are you?", raw, thinking,
new OptionsBuilder().build(), (s) -> {
LOG.info(s); LOG.info(s);
String substring = s.substring(sb.toString().length()); String substring = s.substring(sb.toString().length());
sb.append(substring); sb.append(substring);
@ -578,13 +710,3 @@ public class OllamaAPIIntegrationTest {
return new File(Objects.requireNonNull(classLoader.getResource(fileName)).getFile()); return new File(Objects.requireNonNull(classLoader.getResource(fileName)).getFile());
} }
} }
@Data
@AllArgsConstructor
@NoArgsConstructor
class TimeOfDay {
@JsonProperty("timeHour")
private int timeHour;
@JsonProperty("isNightTime")
private boolean nightTime;
}

View File

@ -1,194 +1,194 @@
package io.github.ollama4j.integrationtests; //package io.github.ollama4j.integrationtests;
//
import io.github.ollama4j.OllamaAPI; //import io.github.ollama4j.OllamaAPI;
import io.github.ollama4j.exceptions.OllamaBaseException; //import io.github.ollama4j.exceptions.OllamaBaseException;
import io.github.ollama4j.models.response.OllamaResult; //import io.github.ollama4j.models.response.OllamaResult;
import io.github.ollama4j.samples.AnnotatedTool; //import io.github.ollama4j.samples.AnnotatedTool;
import io.github.ollama4j.tools.annotations.OllamaToolService; //import io.github.ollama4j.tools.annotations.OllamaToolService;
import org.junit.jupiter.api.BeforeAll; //import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.MethodOrderer.OrderAnnotation; //import org.junit.jupiter.api.MethodOrderer.OrderAnnotation;
import org.junit.jupiter.api.Order; //import org.junit.jupiter.api.Order;
import org.junit.jupiter.api.Test; //import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestMethodOrder; //import org.junit.jupiter.api.TestMethodOrder;
import org.slf4j.Logger; //import org.slf4j.Logger;
import org.slf4j.LoggerFactory; //import org.slf4j.LoggerFactory;
import org.testcontainers.containers.GenericContainer; //import org.testcontainers.containers.GenericContainer;
import org.testcontainers.containers.NginxContainer; //import org.testcontainers.containers.NginxContainer;
import org.testcontainers.containers.wait.strategy.Wait; //import org.testcontainers.containers.wait.strategy.Wait;
import org.testcontainers.ollama.OllamaContainer; //import org.testcontainers.ollama.OllamaContainer;
import org.testcontainers.utility.DockerImageName; //import org.testcontainers.utility.DockerImageName;
import org.testcontainers.utility.MountableFile; //import org.testcontainers.utility.MountableFile;
//
import java.io.File; //import java.io.File;
import java.io.FileWriter; //import java.io.FileWriter;
import java.io.IOException; //import java.io.IOException;
import java.net.URISyntaxException; //import java.net.URISyntaxException;
import java.time.Duration; //import java.time.Duration;
import java.util.Arrays; //import java.util.Arrays;
import java.util.HashMap; //import java.util.HashMap;
import java.util.Map; //import java.util.Map;
//
import static org.junit.jupiter.api.Assertions.*; //import static org.junit.jupiter.api.Assertions.*;
//
@OllamaToolService(providers = {AnnotatedTool.class}) //@OllamaToolService(providers = {AnnotatedTool.class})
@TestMethodOrder(OrderAnnotation.class) //@TestMethodOrder(OrderAnnotation.class)
@SuppressWarnings({"HttpUrlsUsage", "SpellCheckingInspection", "resource", "ResultOfMethodCallIgnored"}) //@SuppressWarnings({"HttpUrlsUsage", "SpellCheckingInspection", "resource", "ResultOfMethodCallIgnored"})
public class WithAuth { //public class WithAuth {
//
private static final Logger LOG = LoggerFactory.getLogger(WithAuth.class); // private static final Logger LOG = LoggerFactory.getLogger(WithAuth.class);
private static final int NGINX_PORT = 80; // private static final int NGINX_PORT = 80;
private static final int OLLAMA_INTERNAL_PORT = 11434; // private static final int OLLAMA_INTERNAL_PORT = 11434;
private static final String OLLAMA_VERSION = "0.6.1"; // private static final String OLLAMA_VERSION = "0.6.1";
private static final String NGINX_VERSION = "nginx:1.23.4-alpine"; // private static final String NGINX_VERSION = "nginx:1.23.4-alpine";
private static final String BEARER_AUTH_TOKEN = "secret-token"; // private static final String BEARER_AUTH_TOKEN = "secret-token";
private static final String CHAT_MODEL_LLAMA3 = "llama3"; // private static final String CHAT_MODEL_LLAMA3 = "llama3";
//
//
private static OllamaContainer ollama; // private static OllamaContainer ollama;
private static GenericContainer<?> nginx; // private static GenericContainer<?> nginx;
private static OllamaAPI api; // private static OllamaAPI api;
//
@BeforeAll // @BeforeAll
public static void setUp() { // public static void setUp() {
ollama = createOllamaContainer(); // ollama = createOllamaContainer();
ollama.start(); // ollama.start();
//
nginx = createNginxContainer(ollama.getMappedPort(OLLAMA_INTERNAL_PORT)); // nginx = createNginxContainer(ollama.getMappedPort(OLLAMA_INTERNAL_PORT));
nginx.start(); // nginx.start();
//
LOG.info("Using Testcontainer Ollama host..."); // LOG.info("Using Testcontainer Ollama host...");
//
api = new OllamaAPI("http://" + nginx.getHost() + ":" + nginx.getMappedPort(NGINX_PORT)); // api = new OllamaAPI("http://" + nginx.getHost() + ":" + nginx.getMappedPort(NGINX_PORT));
api.setRequestTimeoutSeconds(120); // api.setRequestTimeoutSeconds(120);
api.setVerbose(true); // api.setVerbose(true);
api.setNumberOfRetriesForModelPull(3); // api.setNumberOfRetriesForModelPull(3);
//
String ollamaUrl = "http://" + ollama.getHost() + ":" + ollama.getMappedPort(OLLAMA_INTERNAL_PORT); // String ollamaUrl = "http://" + ollama.getHost() + ":" + ollama.getMappedPort(OLLAMA_INTERNAL_PORT);
String nginxUrl = "http://" + nginx.getHost() + ":" + nginx.getMappedPort(NGINX_PORT); // String nginxUrl = "http://" + nginx.getHost() + ":" + nginx.getMappedPort(NGINX_PORT);
LOG.info( // LOG.info(
"The Ollama service is now accessible via the Nginx proxy with bearer-auth authentication mode.\n" + // "The Ollama service is now accessible via the Nginx proxy with bearer-auth authentication mode.\n" +
"→ Ollama URL: {}\n" + // "→ Ollama URL: {}\n" +
"→ Proxy URL: {}", // "→ Proxy URL: {}",
ollamaUrl, nginxUrl // ollamaUrl, nginxUrl
); // );
LOG.info("OllamaAPI initialized with bearer auth token: {}", BEARER_AUTH_TOKEN); // LOG.info("OllamaAPI initialized with bearer auth token: {}", BEARER_AUTH_TOKEN);
} // }
//
private static OllamaContainer createOllamaContainer() { // private static OllamaContainer createOllamaContainer() {
return new OllamaContainer("ollama/ollama:" + OLLAMA_VERSION).withExposedPorts(OLLAMA_INTERNAL_PORT); // return new OllamaContainer("ollama/ollama:" + OLLAMA_VERSION).withExposedPorts(OLLAMA_INTERNAL_PORT);
} // }
//
private static String generateNginxConfig(int ollamaPort) { // private static String generateNginxConfig(int ollamaPort) {
return String.format("events {}\n" + // return String.format("events {}\n" +
"\n" + // "\n" +
"http {\n" + // "http {\n" +
" server {\n" + // " server {\n" +
" listen 80;\n" + // " listen 80;\n" +
"\n" + // "\n" +
" location / {\n" + // " location / {\n" +
" set $auth_header $http_authorization;\n" + // " set $auth_header $http_authorization;\n" +
"\n" + // "\n" +
" if ($auth_header != \"Bearer secret-token\") {\n" + // " if ($auth_header != \"Bearer secret-token\") {\n" +
" return 401;\n" + // " return 401;\n" +
" }\n" + // " }\n" +
"\n" + // "\n" +
" proxy_pass http://host.docker.internal:%s/;\n" + // " proxy_pass http://host.docker.internal:%s/;\n" +
" proxy_set_header Host $host;\n" + // " proxy_set_header Host $host;\n" +
" proxy_set_header X-Real-IP $remote_addr;\n" + // " proxy_set_header X-Real-IP $remote_addr;\n" +
" proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;\n" + // " proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;\n" +
" proxy_set_header X-Forwarded-Proto $scheme;\n" + // " proxy_set_header X-Forwarded-Proto $scheme;\n" +
" }\n" + // " }\n" +
" }\n" + // " }\n" +
"}\n", ollamaPort); // "}\n", ollamaPort);
} // }
//
public static GenericContainer<?> createNginxContainer(int ollamaPort) { // public static GenericContainer<?> createNginxContainer(int ollamaPort) {
File nginxConf; // File nginxConf;
try { // try {
File tempDir = new File(System.getProperty("java.io.tmpdir"), "nginx-auth"); // File tempDir = new File(System.getProperty("java.io.tmpdir"), "nginx-auth");
if (!tempDir.exists()) tempDir.mkdirs(); // if (!tempDir.exists()) tempDir.mkdirs();
//
nginxConf = new File(tempDir, "nginx.conf"); // nginxConf = new File(tempDir, "nginx.conf");
try (FileWriter writer = new FileWriter(nginxConf)) { // try (FileWriter writer = new FileWriter(nginxConf)) {
writer.write(generateNginxConfig(ollamaPort)); // writer.write(generateNginxConfig(ollamaPort));
} // }
//
return new NginxContainer<>(DockerImageName.parse(NGINX_VERSION)) // return new NginxContainer<>(DockerImageName.parse(NGINX_VERSION))
.withExposedPorts(NGINX_PORT) // .withExposedPorts(NGINX_PORT)
.withCopyFileToContainer( // .withCopyFileToContainer(
MountableFile.forHostPath(nginxConf.getAbsolutePath()), // MountableFile.forHostPath(nginxConf.getAbsolutePath()),
"/etc/nginx/nginx.conf" // "/etc/nginx/nginx.conf"
) // )
.withExtraHost("host.docker.internal", "host-gateway") // .withExtraHost("host.docker.internal", "host-gateway")
.waitingFor( // .waitingFor(
Wait.forHttp("/") // Wait.forHttp("/")
.forStatusCode(401) // .forStatusCode(401)
.withStartupTimeout(Duration.ofSeconds(30)) // .withStartupTimeout(Duration.ofSeconds(30))
); // );
} catch (IOException e) { // } catch (IOException e) {
throw new RuntimeException("Failed to create nginx.conf", e); // throw new RuntimeException("Failed to create nginx.conf", e);
} // }
} // }
//
@Test // @Test
@Order(1) // @Order(1)
void testOllamaBehindProxy() throws InterruptedException { // void testOllamaBehindProxy() throws InterruptedException {
api.setBearerAuth(BEARER_AUTH_TOKEN); // api.setBearerAuth(BEARER_AUTH_TOKEN);
assertTrue(api.ping(), "Expected OllamaAPI to successfully ping through NGINX with valid auth token."); // assertTrue(api.ping(), "Expected OllamaAPI to successfully ping through NGINX with valid auth token.");
} // }
//
@Test // @Test
@Order(1) // @Order(1)
void testWithWrongToken() throws InterruptedException { // void testWithWrongToken() throws InterruptedException {
api.setBearerAuth("wrong-token"); // api.setBearerAuth("wrong-token");
assertFalse(api.ping(), "Expected OllamaAPI ping to fail through NGINX with an invalid auth token."); // assertFalse(api.ping(), "Expected OllamaAPI ping to fail through NGINX with an invalid auth token.");
} // }
//
@Test // @Test
@Order(2) // @Order(2)
void testAskModelWithStructuredOutput() // void testAskModelWithStructuredOutput()
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { // throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
api.setBearerAuth(BEARER_AUTH_TOKEN); // api.setBearerAuth(BEARER_AUTH_TOKEN);
//
api.pullModel(CHAT_MODEL_LLAMA3); // api.pullModel(CHAT_MODEL_LLAMA3);
//
int timeHour = 6; // int timeHour = 6;
boolean isNightTime = false; // boolean isNightTime = false;
//
String prompt = "The Sun is shining, and its " + timeHour + ". Its daytime."; // String prompt = "The Sun is shining, and its " + timeHour + ". Its daytime.";
//
Map<String, Object> format = new HashMap<>(); // Map<String, Object> format = new HashMap<>();
format.put("type", "object"); // format.put("type", "object");
format.put("properties", new HashMap<String, Object>() { // format.put("properties", new HashMap<String, Object>() {
{ // {
put("timeHour", new HashMap<String, Object>() { // put("timeHour", new HashMap<String, Object>() {
{ // {
put("type", "integer"); // put("type", "integer");
} // }
}); // });
put("isNightTime", new HashMap<String, Object>() { // put("isNightTime", new HashMap<String, Object>() {
{ // {
put("type", "boolean"); // put("type", "boolean");
} // }
}); // });
} // }
}); // });
format.put("required", Arrays.asList("timeHour", "isNightTime")); // format.put("required", Arrays.asList("timeHour", "isNightTime"));
//
OllamaResult result = api.generate(CHAT_MODEL_LLAMA3, prompt, format); // OllamaResult result = api.generate(CHAT_MODEL_LLAMA3, prompt, format);
//
assertNotNull(result); // assertNotNull(result);
assertNotNull(result.getResponse()); // assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty()); // assertFalse(result.getResponse().isEmpty());
//
assertEquals(timeHour, // assertEquals(timeHour,
result.getStructuredResponse().get("timeHour")); // result.getStructuredResponse().get("timeHour"));
assertEquals(isNightTime, // assertEquals(isNightTime,
result.getStructuredResponse().get("isNightTime")); // result.getStructuredResponse().get("isNightTime"));
//
TimeOfDay timeOfDay = result.as(TimeOfDay.class); // TimeOfDay timeOfDay = result.as(TimeOfDay.class);
//
assertEquals(timeHour, timeOfDay.getTimeHour()); // assertEquals(timeHour, timeOfDay.getTimeHour());
assertEquals(isNightTime, timeOfDay.isNightTime()); // assertEquals(isNightTime, timeOfDay.isNightTime());
} // }
} //}