diff --git a/src/test/java/io/github/ollama4j/unittests/TestAuth.java b/src/test/java/io/github/ollama4j/unittests/TestAuth.java new file mode 100644 index 0000000..c1078e4 --- /dev/null +++ b/src/test/java/io/github/ollama4j/unittests/TestAuth.java @@ -0,0 +1,26 @@ +package io.github.ollama4j.unittests; + +import io.github.ollama4j.models.request.BasicAuth; +import io.github.ollama4j.models.request.BearerAuth; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.*; + +public class TestAuth { + + @Test + public void testBasicAuthHeaderEncoding() { + BasicAuth auth = new BasicAuth("alice", "s3cr3t"); + String header = auth.getAuthHeaderValue(); + assertTrue(header.startsWith("Basic ")); + // "alice:s3cr3t" base64 is "YWxpY2U6czNjcjN0" + assertEquals("Basic YWxpY2U6czNjcjN0", header); + } + + @Test + public void testBearerAuthHeaderFormat() { + BearerAuth auth = new BearerAuth("abc.def.ghi"); + String header = auth.getAuthHeaderValue(); + assertEquals("Bearer abc.def.ghi", header); + } +} diff --git a/src/test/java/io/github/ollama4j/unittests/TestOptionsAndUtils.java b/src/test/java/io/github/ollama4j/unittests/TestOptionsAndUtils.java new file mode 100644 index 0000000..af86c56 --- /dev/null +++ b/src/test/java/io/github/ollama4j/unittests/TestOptionsAndUtils.java @@ -0,0 +1,92 @@ +package io.github.ollama4j.unittests; + +import io.github.ollama4j.utils.Options; +import io.github.ollama4j.utils.OptionsBuilder; +import io.github.ollama4j.utils.PromptBuilder; +import io.github.ollama4j.utils.Utils; +import org.junit.jupiter.api.Test; + +import java.io.File; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; + +public class TestOptionsAndUtils { + + @Test + public void testOptionsBuilderSetsValues() { + Options options = new OptionsBuilder() + .setMirostat(1) + .setMirostatEta(0.2f) + .setMirostatTau(4.5f) + .setNumCtx(1024) + .setNumGqa(8) + .setNumGpu(2) + .setNumThread(6) + .setRepeatLastN(32) + .setRepeatPenalty(1.2f) + .setTemperature(0.7f) + .setSeed(42) + .setStop("STOP") + .setTfsZ(1.5f) + .setNumPredict(256) + .setTopK(50) + .setTopP(0.95f) + .setMinP(0.05f) + .setCustomOption("custom_param", 123) + .build(); + + Map map = options.getOptionsMap(); + assertEquals(1, map.get("mirostat")); + assertEquals(0.2f, (Float) map.get("mirostat_eta"), 0.0001); + assertEquals(4.5f, (Float) map.get("mirostat_tau"), 0.0001); + assertEquals(1024, map.get("num_ctx")); + assertEquals(8, map.get("num_gqa")); + assertEquals(2, map.get("num_gpu")); + assertEquals(6, map.get("num_thread")); + assertEquals(32, map.get("repeat_last_n")); + assertEquals(1.2f, (Float) map.get("repeat_penalty"), 0.0001); + assertEquals(0.7f, (Float) map.get("temperature"), 0.0001); + assertEquals(42, map.get("seed")); + assertEquals("STOP", map.get("stop")); + assertEquals(1.5f, (Float) map.get("tfs_z"), 0.0001); + assertEquals(256, map.get("num_predict")); + assertEquals(50, map.get("top_k")); + assertEquals(0.95f, (Float) map.get("top_p"), 0.0001); + assertEquals(0.05f, (Float) map.get("min_p"), 0.0001); + assertEquals(123, map.get("custom_param")); + } + + @Test + public void testOptionsBuilderRejectsUnsupportedCustomType() { + OptionsBuilder builder = new OptionsBuilder(); + assertThrows(IllegalArgumentException.class, () -> builder.setCustomOption("bad", new Object())); + } + + @Test + public void testPromptBuilderBuildsExpectedString() { + String prompt = new PromptBuilder() + .add("Hello") + .addLine(", world!") + .addSeparator() + .add("Continue.") + .build(); + + String expected = "Hello, world!\n\n--------------------------------------------------\nContinue."; + assertEquals(expected, prompt); + } + + @Test + public void testUtilsGetObjectMapperSingletonAndModule() { + assertSame(Utils.getObjectMapper(), Utils.getObjectMapper()); + // Basic serialization sanity check with JavaTimeModule registered + assertDoesNotThrow(() -> Utils.getObjectMapper().writeValueAsString(java.time.OffsetDateTime.now())); + } + + @Test + public void testGetFileFromClasspath() { + File f = Utils.getFileFromClasspath("test-config.properties"); + assertTrue(f.exists()); + assertTrue(f.getName().contains("test-config.properties")); + } +} diff --git a/src/test/java/io/github/ollama4j/unittests/TestToolRegistry.java b/src/test/java/io/github/ollama4j/unittests/TestToolRegistry.java new file mode 100644 index 0000000..a8b631a --- /dev/null +++ b/src/test/java/io/github/ollama4j/unittests/TestToolRegistry.java @@ -0,0 +1,48 @@ +package io.github.ollama4j.unittests; + +import io.github.ollama4j.tools.ToolFunction; +import io.github.ollama4j.tools.ToolRegistry; +import io.github.ollama4j.tools.Tools; +import org.junit.jupiter.api.Test; + +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; + +public class TestToolRegistry { + + @Test + public void testAddAndGetToolFunction() { + ToolRegistry registry = new ToolRegistry(); + ToolFunction fn = args -> "ok:" + args.get("x"); + + Tools.ToolSpecification spec = Tools.ToolSpecification.builder() + .functionName("test") + .functionDescription("desc") + .toolFunction(fn) + .build(); + + registry.addTool("test", spec); + ToolFunction retrieved = registry.getToolFunction("test"); + assertNotNull(retrieved); + assertEquals("ok:42", retrieved.apply(Map.of("x", 42))); + } + + @Test + public void testGetUnknownReturnsNull() { + ToolRegistry registry = new ToolRegistry(); + assertNull(registry.getToolFunction("nope")); + } + + @Test + public void testClearRemovesAll() { + ToolRegistry registry = new ToolRegistry(); + registry.addTool("a", Tools.ToolSpecification.builder().toolFunction(args -> 1).build()); + registry.addTool("b", Tools.ToolSpecification.builder().toolFunction(args -> 2).build()); + assertFalse(registry.getRegisteredSpecs().isEmpty()); + registry.clear(); + assertTrue(registry.getRegisteredSpecs().isEmpty()); + assertNull(registry.getToolFunction("a")); + assertNull(registry.getToolFunction("b")); + } +} diff --git a/src/test/java/io/github/ollama4j/unittests/TestToolsPromptBuilder.java b/src/test/java/io/github/ollama4j/unittests/TestToolsPromptBuilder.java new file mode 100644 index 0000000..4febf87 --- /dev/null +++ b/src/test/java/io/github/ollama4j/unittests/TestToolsPromptBuilder.java @@ -0,0 +1,64 @@ +package io.github.ollama4j.unittests; + +import com.fasterxml.jackson.core.JsonProcessingException; +import io.github.ollama4j.tools.Tools; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; + +public class TestToolsPromptBuilder { + + @Test + public void testPromptBuilderIncludesToolsAndPrompt() throws JsonProcessingException { + Tools.PromptFuncDefinition.Property cityProp = Tools.PromptFuncDefinition.Property.builder() + .type("string") + .description("city name") + .required(true) + .build(); + + Tools.PromptFuncDefinition.Property unitsProp = Tools.PromptFuncDefinition.Property.builder() + .type("string") + .description("units") + .enumValues(List.of("metric", "imperial")) + .required(false) + .build(); + + Tools.PromptFuncDefinition.Parameters params = Tools.PromptFuncDefinition.Parameters.builder() + .type("object") + .properties(Map.of("city", cityProp, "units", unitsProp)) + .build(); + + Tools.PromptFuncDefinition.PromptFuncSpec spec = Tools.PromptFuncDefinition.PromptFuncSpec.builder() + .name("getWeather") + .description("Get weather for a city") + .parameters(params) + .build(); + + Tools.PromptFuncDefinition def = Tools.PromptFuncDefinition.builder() + .type("function") + .function(spec) + .build(); + + Tools.ToolSpecification toolSpec = Tools.ToolSpecification.builder() + .functionName("getWeather") + .functionDescription("Get weather for a city") + .toolPrompt(def) + .build(); + + Tools.PromptBuilder pb = new Tools.PromptBuilder() + .withToolSpecification(toolSpec) + .withPrompt("Tell me the weather."); + + String built = pb.build(); + assertTrue(built.contains("[AVAILABLE_TOOLS]")); + assertTrue(built.contains("[/AVAILABLE_TOOLS]")); + assertTrue(built.contains("[INST]")); + assertTrue(built.contains("Tell me the weather.")); + assertTrue(built.contains("\"name\":\"getWeather\"")); + assertTrue(built.contains("\"required\":[\"city\"]")); + assertTrue(built.contains("\"enum\":[\"metric\",\"imperial\"]")); + } +}