mirror of
				https://github.com/amithkoujalgi/ollama4j.git
				synced 2025-10-31 16:40:41 +01:00 
			
		
		
		
	Add unit tests for annotations, serialization, and tool functions
- Introduced TestAnnotations.java to validate OllamaToolService and ToolProperty annotations. - Added TestBooleanToJsonFormatFlagSerializer.java to test serialization of Boolean values. - Created TestFileToBase64Serializer.java for testing byte array serialization to Base64. - Implemented TestOllamaChatMessage.java to ensure correct JSON representation of chat messages. - Developed TestOllamaChatMessageRole.java to verify role registration and custom role creation. - Added TestOllamaChatRequestBuilder.java to test message handling and request building. - Created TestOllamaRequestBody.java to validate request body serialization. - Implemented TestOllamaToolsResult.java to ensure correct transformation of tool results. - Added TestOptionsAndUtils.java to test options builder and utility functions. - Created TestReflectionalToolFunction.java to validate method invocation and argument handling. - Implemented TestToolRegistry.java to ensure tool registration and retrieval functionality. - Developed TestToolsPromptBuilder.java to verify prompt builder includes tools and prompts correctly. - Added serialization tests in TestChatRequestSerialization.java and TestEmbedRequestSerialization.java.
This commit is contained in:
		| @@ -0,0 +1,50 @@ | ||||
| package io.github.ollama4j.unittests; | ||||
|  | ||||
| import io.github.ollama4j.tools.annotations.OllamaToolService; | ||||
| import io.github.ollama4j.tools.annotations.ToolProperty; | ||||
| import io.github.ollama4j.tools.annotations.ToolSpec; | ||||
| import org.junit.jupiter.api.Test; | ||||
|  | ||||
| import java.lang.reflect.Method; | ||||
| import java.lang.reflect.Parameter; | ||||
|  | ||||
| import static org.junit.jupiter.api.Assertions.*; | ||||
|  | ||||
| class TestAnnotations { | ||||
|  | ||||
|     @OllamaToolService(providers = {SampleProvider.class}) | ||||
|     static class SampleToolService { | ||||
|     } | ||||
|  | ||||
|     static class SampleProvider { | ||||
|         @ToolSpec(name = "sum", desc = "adds two numbers") | ||||
|         public int sum(@ToolProperty(name = "a", desc = "first addend") int a, | ||||
|                        @ToolProperty(name = "b", desc = "second addend", required = false) int b) { | ||||
|             return a + b; | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     @Test | ||||
|     void testOllamaToolServiceProvidersPresent() throws Exception { | ||||
|         OllamaToolService ann = SampleToolService.class.getAnnotation(OllamaToolService.class); | ||||
|         assertNotNull(ann); | ||||
|         assertArrayEquals(new Class<?>[]{SampleProvider.class}, ann.providers()); | ||||
|     } | ||||
|  | ||||
|     @Test | ||||
|     void testToolPropertyMetadataOnParameters() throws Exception { | ||||
|         Method m = SampleProvider.class.getDeclaredMethod("sum", int.class, int.class); | ||||
|         Parameter[] params = m.getParameters(); | ||||
|         ToolProperty p0 = params[0].getAnnotation(ToolProperty.class); | ||||
|         ToolProperty p1 = params[1].getAnnotation(ToolProperty.class); | ||||
|         assertNotNull(p0); | ||||
|         assertEquals("a", p0.name()); | ||||
|         assertEquals("first addend", p0.desc()); | ||||
|         assertTrue(p0.required()); | ||||
|  | ||||
|         assertNotNull(p1); | ||||
|         assertEquals("b", p1.name()); | ||||
|         assertEquals("second addend", p1.desc()); | ||||
|         assertFalse(p1.required()); | ||||
|     } | ||||
| } | ||||
| @@ -6,10 +6,10 @@ import org.junit.jupiter.api.Test; | ||||
|  | ||||
| import static org.junit.jupiter.api.Assertions.*; | ||||
|  | ||||
| public class TestAuth { | ||||
| class TestAuth { | ||||
|  | ||||
|     @Test | ||||
|     public void testBasicAuthHeaderEncoding() { | ||||
|     void testBasicAuthHeaderEncoding() { | ||||
|         BasicAuth auth = new BasicAuth("alice", "s3cr3t"); | ||||
|         String header = auth.getAuthHeaderValue(); | ||||
|         assertTrue(header.startsWith("Basic ")); | ||||
| @@ -18,7 +18,7 @@ public class TestAuth { | ||||
|     } | ||||
|  | ||||
|     @Test | ||||
|     public void testBearerAuthHeaderFormat() { | ||||
|     void testBearerAuthHeaderFormat() { | ||||
|         BearerAuth auth = new BearerAuth("abc.def.ghi"); | ||||
|         String header = auth.getAuthHeaderValue(); | ||||
|         assertEquals("Bearer abc.def.ghi", header); | ||||
|   | ||||
| @@ -0,0 +1,43 @@ | ||||
| package io.github.ollama4j.unittests; | ||||
|  | ||||
| import com.fasterxml.jackson.annotation.JsonInclude; | ||||
| import com.fasterxml.jackson.core.JsonProcessingException; | ||||
| import com.fasterxml.jackson.databind.ObjectMapper; | ||||
| import com.fasterxml.jackson.databind.annotation.JsonSerialize; | ||||
| import io.github.ollama4j.utils.BooleanToJsonFormatFlagSerializer; | ||||
| import io.github.ollama4j.utils.Utils; | ||||
| import org.junit.jupiter.api.Test; | ||||
|  | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
|  | ||||
| class TestBooleanToJsonFormatFlagSerializer { | ||||
|  | ||||
|     static class Holder { | ||||
|         @JsonSerialize(using = BooleanToJsonFormatFlagSerializer.class) | ||||
|         public Boolean formatJson; | ||||
|     } | ||||
|  | ||||
|     @Test | ||||
|     void testSerializeTrueWritesJsonString() throws JsonProcessingException { | ||||
|         ObjectMapper mapper = Utils.getObjectMapper().copy(); | ||||
|         mapper.setSerializationInclusion(JsonInclude.Include.NON_EMPTY); | ||||
|  | ||||
|         Holder holder = new Holder(); | ||||
|         holder.formatJson = true; | ||||
|  | ||||
|         String json = mapper.writeValueAsString(holder); | ||||
|         assertEquals("{\"formatJson\":\"json\"}", json); | ||||
|     } | ||||
|  | ||||
|     @Test | ||||
|     void testSerializeFalseOmittedByIsEmpty() throws JsonProcessingException { | ||||
|         ObjectMapper mapper = Utils.getObjectMapper().copy(); | ||||
|         mapper.setSerializationInclusion(JsonInclude.Include.NON_EMPTY); | ||||
|  | ||||
|         Holder holder = new Holder(); | ||||
|         holder.formatJson = false; | ||||
|  | ||||
|         String json = mapper.writeValueAsString(holder); | ||||
|         assertEquals("{}", json); | ||||
|     } | ||||
| } | ||||
| @@ -0,0 +1,32 @@ | ||||
| package io.github.ollama4j.unittests; | ||||
|  | ||||
| import com.fasterxml.jackson.core.JsonProcessingException; | ||||
| import com.fasterxml.jackson.databind.ObjectMapper; | ||||
| import com.fasterxml.jackson.databind.annotation.JsonSerialize; | ||||
| import io.github.ollama4j.utils.FileToBase64Serializer; | ||||
| import io.github.ollama4j.utils.Utils; | ||||
| import org.junit.jupiter.api.Test; | ||||
|  | ||||
| import java.util.List; | ||||
|  | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
|  | ||||
| public class TestFileToBase64Serializer { | ||||
|  | ||||
|     static class Holder { | ||||
|         @JsonSerialize(using = FileToBase64Serializer.class) | ||||
|         public List<byte[]> images; | ||||
|     } | ||||
|  | ||||
|     @Test | ||||
|     public void testSerializeByteArraysToBase64Array() throws JsonProcessingException { | ||||
|         ObjectMapper mapper = Utils.getObjectMapper(); | ||||
|  | ||||
|         Holder holder = new Holder(); | ||||
|         holder.images = List.of("hello".getBytes(), "world".getBytes()); | ||||
|  | ||||
|         String json = mapper.writeValueAsString(holder); | ||||
|         // Base64 of "hello" = aGVsbG8=, of "world" = d29ybGQ= | ||||
|         assertEquals("{\"images\":[\"aGVsbG8=\",\"d29ybGQ=\"]}", json); | ||||
|     } | ||||
| } | ||||
| @@ -0,0 +1,22 @@ | ||||
| package io.github.ollama4j.unittests; | ||||
|  | ||||
| import io.github.ollama4j.models.chat.OllamaChatMessage; | ||||
| import io.github.ollama4j.models.chat.OllamaChatMessageRole; | ||||
| import org.json.JSONObject; | ||||
| import org.junit.jupiter.api.Test; | ||||
|  | ||||
| import static org.junit.jupiter.api.Assertions.*; | ||||
|  | ||||
| class TestOllamaChatMessage { | ||||
|  | ||||
|     @Test | ||||
|     void testToStringProducesJson() { | ||||
|         OllamaChatMessage msg = new OllamaChatMessage(OllamaChatMessageRole.USER, "hello", null, null, null); | ||||
|         String json = msg.toString(); | ||||
|         JSONObject obj = new JSONObject(json); | ||||
|         assertEquals("user", obj.getString("role")); | ||||
|         assertEquals("hello", obj.getString("content")); | ||||
|         assertTrue(obj.has("tool_calls")); | ||||
|         // thinking and images may or may not be present depending on null handling, just ensure no exception | ||||
|     } | ||||
| } | ||||
| @@ -0,0 +1,44 @@ | ||||
| package io.github.ollama4j.unittests; | ||||
|  | ||||
| import io.github.ollama4j.exceptions.RoleNotFoundException; | ||||
| import io.github.ollama4j.models.chat.OllamaChatMessageRole; | ||||
| import org.junit.jupiter.api.Test; | ||||
|  | ||||
| import java.util.List; | ||||
|  | ||||
| import static org.junit.jupiter.api.Assertions.*; | ||||
|  | ||||
| class TestOllamaChatMessageRole { | ||||
|  | ||||
|     @Test | ||||
|     void testStaticRolesRegistered() throws Exception { | ||||
|         List<OllamaChatMessageRole> roles = OllamaChatMessageRole.getRoles(); | ||||
|         assertTrue(roles.contains(OllamaChatMessageRole.SYSTEM)); | ||||
|         assertTrue(roles.contains(OllamaChatMessageRole.USER)); | ||||
|         assertTrue(roles.contains(OllamaChatMessageRole.ASSISTANT)); | ||||
|         assertTrue(roles.contains(OllamaChatMessageRole.TOOL)); | ||||
|  | ||||
|         assertEquals("system", OllamaChatMessageRole.SYSTEM.toString()); | ||||
|         assertEquals("user", OllamaChatMessageRole.USER.toString()); | ||||
|         assertEquals("assistant", OllamaChatMessageRole.ASSISTANT.toString()); | ||||
|         assertEquals("tool", OllamaChatMessageRole.TOOL.toString()); | ||||
|  | ||||
|         assertSame(OllamaChatMessageRole.SYSTEM, OllamaChatMessageRole.getRole("system")); | ||||
|         assertSame(OllamaChatMessageRole.USER, OllamaChatMessageRole.getRole("user")); | ||||
|         assertSame(OllamaChatMessageRole.ASSISTANT, OllamaChatMessageRole.getRole("assistant")); | ||||
|         assertSame(OllamaChatMessageRole.TOOL, OllamaChatMessageRole.getRole("tool")); | ||||
|     } | ||||
|  | ||||
|     @Test | ||||
|     void testCustomRoleCreationAndLookup() throws Exception { | ||||
|         OllamaChatMessageRole custom = OllamaChatMessageRole.newCustomRole("myrole"); | ||||
|         assertEquals("myrole", custom.toString()); | ||||
|         // custom roles are registered globally (per current implementation), so lookup should succeed | ||||
|         assertSame(custom, OllamaChatMessageRole.getRole("myrole")); | ||||
|     } | ||||
|  | ||||
|     @Test | ||||
|     void testGetRoleThrowsOnUnknown() { | ||||
|         assertThrows(RoleNotFoundException.class, () -> OllamaChatMessageRole.getRole("does-not-exist")); | ||||
|     } | ||||
| } | ||||
| @@ -0,0 +1,49 @@ | ||||
| package io.github.ollama4j.unittests; | ||||
|  | ||||
| import io.github.ollama4j.models.chat.OllamaChatMessage; | ||||
| import io.github.ollama4j.models.chat.OllamaChatMessageRole; | ||||
| import io.github.ollama4j.models.chat.OllamaChatRequest; | ||||
| import io.github.ollama4j.models.chat.OllamaChatRequestBuilder; | ||||
| import org.junit.jupiter.api.Test; | ||||
|  | ||||
| import java.util.Collections; | ||||
|  | ||||
| import static org.junit.jupiter.api.Assertions.*; | ||||
|  | ||||
| class TestOllamaChatRequestBuilder { | ||||
|  | ||||
|     @Test | ||||
|     void testResetClearsMessagesButKeepsModelAndThink() { | ||||
|         OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance("my-model") | ||||
|                 .withThinking(true) | ||||
|                 .withMessage(OllamaChatMessageRole.USER, "first"); | ||||
|  | ||||
|         OllamaChatRequest beforeReset = builder.build(); | ||||
|         assertEquals("my-model", beforeReset.getModel()); | ||||
|         assertTrue(beforeReset.isThink()); | ||||
|         assertEquals(1, beforeReset.getMessages().size()); | ||||
|  | ||||
|         builder.reset(); | ||||
|         OllamaChatRequest afterReset = builder.build(); | ||||
|         assertEquals("my-model", afterReset.getModel()); | ||||
|         assertTrue(afterReset.isThink()); | ||||
|         assertNotNull(afterReset.getMessages()); | ||||
|         assertEquals(0, afterReset.getMessages().size()); | ||||
|     } | ||||
|  | ||||
|     @Test | ||||
|     void testImageUrlFailuresAreIgnoredAndDoNotBreakBuild() { | ||||
|         // Provide clearly invalid URL, builder logs a warning and continues | ||||
|         OllamaChatRequest req = OllamaChatRequestBuilder.getInstance("m") | ||||
|                 .withMessage(OllamaChatMessageRole.USER, "hi", Collections.emptyList(), | ||||
|                         "ht!tp://invalid url \n not a uri") | ||||
|                 .build(); | ||||
|  | ||||
|         assertNotNull(req.getMessages()); | ||||
|         assertEquals(1, req.getMessages().size()); | ||||
|         OllamaChatMessage msg = req.getMessages().get(0); | ||||
|         // images list will be initialized only if any valid URL was added; for invalid URL list can be null | ||||
|         // We just assert that builder didn't crash and message is present with content | ||||
|         assertEquals("hi", msg.getContent()); | ||||
|     } | ||||
| } | ||||
| @@ -0,0 +1,58 @@ | ||||
| package io.github.ollama4j.unittests; | ||||
|  | ||||
| import io.github.ollama4j.utils.OllamaRequestBody; | ||||
| import io.github.ollama4j.utils.Utils; | ||||
| import org.junit.jupiter.api.Test; | ||||
|  | ||||
| import java.io.IOException; | ||||
| import java.nio.ByteBuffer; | ||||
| import java.nio.charset.StandardCharsets; | ||||
| import java.util.concurrent.Flow; | ||||
|  | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
|  | ||||
| class TestOllamaRequestBody { | ||||
|  | ||||
|     static class SimpleRequest implements OllamaRequestBody { | ||||
|         public String name; | ||||
|         public int value; | ||||
|  | ||||
|         SimpleRequest(String name, int value) { | ||||
|             this.name = name; | ||||
|             this.value = value; | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     @Test | ||||
|     void testGetBodyPublisherProducesSerializedJson() throws IOException { | ||||
|         SimpleRequest req = new SimpleRequest("abc", 123); | ||||
|  | ||||
|         var publisher = req.getBodyPublisher(); | ||||
|  | ||||
|         StringBuilder data = new StringBuilder(); | ||||
|         publisher.subscribe(new Flow.Subscriber<>() { | ||||
|             @Override | ||||
|             public void onSubscribe(Flow.Subscription subscription) { | ||||
|                 subscription.request(Long.MAX_VALUE); | ||||
|             } | ||||
|  | ||||
|             @Override | ||||
|             public void onNext(ByteBuffer item) { | ||||
|                 data.append(StandardCharsets.UTF_8.decode(item)); | ||||
|             } | ||||
|  | ||||
|             @Override | ||||
|             public void onError(Throwable throwable) { | ||||
|             } | ||||
|  | ||||
|             @Override | ||||
|             public void onComplete() { | ||||
|             } | ||||
|         }); | ||||
|  | ||||
|         // Trigger the publishing by converting it to a string via the same mapper for determinism | ||||
|         String expected = Utils.getObjectMapper().writeValueAsString(req); | ||||
|         // Due to asynchronous nature, expected content already delivered synchronously by StringPublisher | ||||
|         assertEquals(expected, data.toString()); | ||||
|     } | ||||
| } | ||||
| @@ -0,0 +1,46 @@ | ||||
| package io.github.ollama4j.unittests; | ||||
|  | ||||
| import io.github.ollama4j.models.response.OllamaResult; | ||||
| import io.github.ollama4j.tools.OllamaToolsResult; | ||||
| import io.github.ollama4j.tools.ToolFunctionCallSpec; | ||||
| import org.junit.jupiter.api.Test; | ||||
|  | ||||
| import java.util.LinkedHashMap; | ||||
| import java.util.List; | ||||
| import java.util.Map; | ||||
|  | ||||
| import static org.junit.jupiter.api.Assertions.*; | ||||
|  | ||||
| public class TestOllamaToolsResult { | ||||
|  | ||||
|     @Test | ||||
|     public void testGetToolResultsTransformsMapToList() { | ||||
|         ToolFunctionCallSpec spec1 = new ToolFunctionCallSpec("fn1", Map.of("a", 1)); | ||||
|         ToolFunctionCallSpec spec2 = new ToolFunctionCallSpec("fn2", Map.of("b", 2)); | ||||
|  | ||||
|         Map<ToolFunctionCallSpec, Object> toolMap = new LinkedHashMap<>(); | ||||
|         toolMap.put(spec1, "r1"); | ||||
|         toolMap.put(spec2, 123); | ||||
|  | ||||
|         OllamaToolsResult tr = new OllamaToolsResult(new OllamaResult("", null, 0L, 200), toolMap); | ||||
|  | ||||
|         List<OllamaToolsResult.ToolResult> list = tr.getToolResults(); | ||||
|         assertEquals(2, list.size()); | ||||
|         assertEquals("fn1", list.get(0).getFunctionName()); | ||||
|         assertEquals(Map.of("a", 1), list.get(0).getFunctionArguments()); | ||||
|         assertEquals("r1", list.get(0).getResult()); | ||||
|  | ||||
|         assertEquals("fn2", list.get(1).getFunctionName()); | ||||
|         assertEquals(Map.of("b", 2), list.get(1).getFunctionArguments()); | ||||
|         assertEquals(123, list.get(1).getResult()); | ||||
|     } | ||||
|  | ||||
|     @Test | ||||
|     public void testGetToolResultsReturnsEmptyListWhenNull() { | ||||
|         OllamaToolsResult tr = new OllamaToolsResult(); | ||||
|         tr.setToolResults(null); | ||||
|         List<OllamaToolsResult.ToolResult> list = tr.getToolResults(); | ||||
|         assertNotNull(list); | ||||
|         assertTrue(list.isEmpty()); | ||||
|     } | ||||
| } | ||||
| @@ -11,10 +11,10 @@ import java.util.Map; | ||||
|  | ||||
| import static org.junit.jupiter.api.Assertions.*; | ||||
|  | ||||
| public class TestOptionsAndUtils { | ||||
| class TestOptionsAndUtils { | ||||
|  | ||||
|     @Test | ||||
|     public void testOptionsBuilderSetsValues() { | ||||
|     void testOptionsBuilderSetsValues() { | ||||
|         Options options = new OptionsBuilder() | ||||
|                 .setMirostat(1) | ||||
|                 .setMirostatEta(0.2f) | ||||
| @@ -58,13 +58,13 @@ public class TestOptionsAndUtils { | ||||
|     } | ||||
|  | ||||
|     @Test | ||||
|     public void testOptionsBuilderRejectsUnsupportedCustomType() { | ||||
|     void testOptionsBuilderRejectsUnsupportedCustomType() { | ||||
|         OptionsBuilder builder = new OptionsBuilder(); | ||||
|         assertThrows(IllegalArgumentException.class, () -> builder.setCustomOption("bad", new Object())); | ||||
|     } | ||||
|  | ||||
|     @Test | ||||
|     public void testPromptBuilderBuildsExpectedString() { | ||||
|     void testPromptBuilderBuildsExpectedString() { | ||||
|         String prompt = new PromptBuilder() | ||||
|                 .add("Hello") | ||||
|                 .addLine(", world!") | ||||
| @@ -77,14 +77,14 @@ public class TestOptionsAndUtils { | ||||
|     } | ||||
|  | ||||
|     @Test | ||||
|     public void testUtilsGetObjectMapperSingletonAndModule() { | ||||
|     void testUtilsGetObjectMapperSingletonAndModule() { | ||||
|         assertSame(Utils.getObjectMapper(), Utils.getObjectMapper()); | ||||
|         // Basic serialization sanity check with JavaTimeModule registered | ||||
|         assertDoesNotThrow(() -> Utils.getObjectMapper().writeValueAsString(java.time.OffsetDateTime.now())); | ||||
|     } | ||||
|  | ||||
|     @Test | ||||
|     public void testGetFileFromClasspath() { | ||||
|     void testGetFileFromClasspath() { | ||||
|         File f = Utils.getFileFromClasspath("test-config.properties"); | ||||
|         assertTrue(f.exists()); | ||||
|         assertTrue(f.getName().contains("test-config.properties")); | ||||
|   | ||||
| @@ -0,0 +1,86 @@ | ||||
| package io.github.ollama4j.unittests; | ||||
|  | ||||
| import io.github.ollama4j.tools.ReflectionalToolFunction; | ||||
| import org.junit.jupiter.api.Test; | ||||
|  | ||||
| import java.lang.reflect.Method; | ||||
| import java.math.BigDecimal; | ||||
| import java.util.LinkedHashMap; | ||||
| import java.util.Map; | ||||
|  | ||||
| import static org.junit.jupiter.api.Assertions.*; | ||||
|  | ||||
| class TestReflectionalToolFunction { | ||||
|  | ||||
|     public static class SampleToolHolder { | ||||
|         public String combine(Integer i, Boolean b, BigDecimal d, String s) { | ||||
|             return String.format("i=%s,b=%s,d=%s,s=%s", i, b, d, s); | ||||
|         } | ||||
|  | ||||
|         public void alwaysThrows() { | ||||
|             throw new IllegalStateException("boom"); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     @Test | ||||
|     void testApplyInvokesMethodWithTypeCasting() throws Exception { | ||||
|         SampleToolHolder holder = new SampleToolHolder(); | ||||
|         Method method = SampleToolHolder.class.getMethod("combine", Integer.class, Boolean.class, BigDecimal.class, String.class); | ||||
|  | ||||
|         LinkedHashMap<String, String> propDef = new LinkedHashMap<>(); | ||||
|         // preserve order to match method parameters | ||||
|         propDef.put("i", "java.lang.Integer"); | ||||
|         propDef.put("b", "java.lang.Boolean"); | ||||
|         propDef.put("d", "java.math.BigDecimal"); | ||||
|         propDef.put("s", "java.lang.String"); | ||||
|  | ||||
|         ReflectionalToolFunction fn = new ReflectionalToolFunction(holder, method, propDef); | ||||
|  | ||||
|         Map<String, Object> args = Map.of( | ||||
|                 "i", "42", | ||||
|                 "b", "true", | ||||
|                 "d", "3.14", | ||||
|                 "s", 123 // not a string; should be toString()'d by implementation | ||||
|         ); | ||||
|  | ||||
|         Object result = fn.apply(args); | ||||
|         assertEquals("i=42,b=true,d=3.14,s=123", result); | ||||
|     } | ||||
|  | ||||
|     @Test | ||||
|     void testTypeCastNullsWhenClassOrValueIsNull() throws Exception { | ||||
|         SampleToolHolder holder = new SampleToolHolder(); | ||||
|         Method method = SampleToolHolder.class.getMethod("combine", Integer.class, Boolean.class, BigDecimal.class, String.class); | ||||
|  | ||||
|         LinkedHashMap<String, String> propDef = new LinkedHashMap<>(); | ||||
|         propDef.put("i", null); // className null -> expect null passed | ||||
|         propDef.put("b", "java.lang.Boolean"); | ||||
|         propDef.put("d", "java.math.BigDecimal"); | ||||
|         propDef.put("s", "java.lang.String"); | ||||
|  | ||||
|         ReflectionalToolFunction fn = new ReflectionalToolFunction(holder, method, propDef); | ||||
|  | ||||
|         Map<String, Object> args = new LinkedHashMap<>(); | ||||
|         args.put("i", "100"); // ignored -> becomes null due to null className | ||||
|         args.put("b", null); // value null -> expect null passed | ||||
|         args.put("d", "1.00"); | ||||
|         args.put("s", "ok"); | ||||
|  | ||||
|         Object result = fn.apply(args); | ||||
|         assertEquals("i=null,b=null,d=1.00,s=ok", result); | ||||
|     } | ||||
|  | ||||
|     @Test | ||||
|     void testExceptionsAreWrappedWithMeaningfulMessage() throws Exception { | ||||
|         SampleToolHolder holder = new SampleToolHolder(); | ||||
|         Method throwsMethod = SampleToolHolder.class.getMethod("alwaysThrows"); | ||||
|  | ||||
|         LinkedHashMap<String, String> propDef = new LinkedHashMap<>(); | ||||
|  | ||||
|         ReflectionalToolFunction fn = new ReflectionalToolFunction(holder, throwsMethod, propDef); | ||||
|  | ||||
|         RuntimeException ex = assertThrows(RuntimeException.class, () -> fn.apply(Map.of())); | ||||
|         assertTrue(ex.getMessage().contains("Failed to invoke tool: alwaysThrows")); | ||||
|         assertNotNull(ex.getCause()); | ||||
|     } | ||||
| } | ||||
| @@ -9,10 +9,10 @@ import java.util.Map; | ||||
|  | ||||
| import static org.junit.jupiter.api.Assertions.*; | ||||
|  | ||||
| public class TestToolRegistry { | ||||
| class TestToolRegistry { | ||||
|  | ||||
|     @Test | ||||
|     public void testAddAndGetToolFunction() { | ||||
|     void testAddAndGetToolFunction() { | ||||
|         ToolRegistry registry = new ToolRegistry(); | ||||
|         ToolFunction fn = args -> "ok:" + args.get("x"); | ||||
|  | ||||
| @@ -29,13 +29,13 @@ public class TestToolRegistry { | ||||
|     } | ||||
|  | ||||
|     @Test | ||||
|     public void testGetUnknownReturnsNull() { | ||||
|     void testGetUnknownReturnsNull() { | ||||
|         ToolRegistry registry = new ToolRegistry(); | ||||
|         assertNull(registry.getToolFunction("nope")); | ||||
|     } | ||||
|  | ||||
|     @Test | ||||
|     public void testClearRemovesAll() { | ||||
|     void testClearRemovesAll() { | ||||
|         ToolRegistry registry = new ToolRegistry(); | ||||
|         registry.addTool("a", Tools.ToolSpecification.builder().toolFunction(args -> 1).build()); | ||||
|         registry.addTool("b", Tools.ToolSpecification.builder().toolFunction(args -> 2).build()); | ||||
|   | ||||
| @@ -9,10 +9,10 @@ import java.util.Map; | ||||
|  | ||||
| import static org.junit.jupiter.api.Assertions.*; | ||||
|  | ||||
| public class TestToolsPromptBuilder { | ||||
| class TestToolsPromptBuilder { | ||||
|  | ||||
|     @Test | ||||
|     public void testPromptBuilderIncludesToolsAndPrompt() throws JsonProcessingException { | ||||
|     void testPromptBuilderIncludesToolsAndPrompt() throws JsonProcessingException { | ||||
|         Tools.PromptFuncDefinition.Property cityProp = Tools.PromptFuncDefinition.Property.builder() | ||||
|                 .type("string") | ||||
|                 .description("city name") | ||||
|   | ||||
| @@ -30,7 +30,7 @@ public abstract class AbstractSerializationTest<T> { | ||||
|     } | ||||
|  | ||||
|     protected void assertEqualsAfterUnmarshalling(T unmarshalledObject, | ||||
|         T req) { | ||||
|                                                   T req) { | ||||
|         assertEquals(req, unmarshalledObject); | ||||
|     } | ||||
| } | ||||
|   | ||||
| @@ -34,8 +34,8 @@ public class TestChatRequestSerialization extends AbstractSerializationTest<Olla | ||||
|     @Test | ||||
|     public void testRequestMultipleMessages() { | ||||
|         OllamaChatRequest req = builder.withMessage(OllamaChatMessageRole.SYSTEM, "System prompt") | ||||
|         .withMessage(OllamaChatMessageRole.USER, "Some prompt") | ||||
|         .build(); | ||||
|                 .withMessage(OllamaChatMessageRole.USER, "Some prompt") | ||||
|                 .build(); | ||||
|         String jsonRequest = serialize(req); | ||||
|         assertEqualsAfterUnmarshalling(deserialize(jsonRequest, OllamaChatRequest.class), req); | ||||
|     } | ||||
| @@ -52,19 +52,19 @@ public class TestChatRequestSerialization extends AbstractSerializationTest<Olla | ||||
|     public void testRequestWithOptions() { | ||||
|         OptionsBuilder b = new OptionsBuilder(); | ||||
|         OllamaChatRequest req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt") | ||||
|             .withOptions(b.setMirostat(1).build()) | ||||
|             .withOptions(b.setTemperature(1L).build()) | ||||
|             .withOptions(b.setMirostatEta(1L).build()) | ||||
|             .withOptions(b.setMirostatTau(1L).build()) | ||||
|             .withOptions(b.setNumGpu(1).build()) | ||||
|             .withOptions(b.setSeed(1).build()) | ||||
|             .withOptions(b.setTopK(1).build()) | ||||
|             .withOptions(b.setTopP(1).build()) | ||||
|             .withOptions(b.setMinP(1).build()) | ||||
|             .withOptions(b.setCustomOption("cust_float", 1.0f).build()) | ||||
|             .withOptions(b.setCustomOption("cust_int", 1).build()) | ||||
|             .withOptions(b.setCustomOption("cust_str", "custom").build()) | ||||
|             .build(); | ||||
|                 .withOptions(b.setMirostat(1).build()) | ||||
|                 .withOptions(b.setTemperature(1L).build()) | ||||
|                 .withOptions(b.setMirostatEta(1L).build()) | ||||
|                 .withOptions(b.setMirostatTau(1L).build()) | ||||
|                 .withOptions(b.setNumGpu(1).build()) | ||||
|                 .withOptions(b.setSeed(1).build()) | ||||
|                 .withOptions(b.setTopK(1).build()) | ||||
|                 .withOptions(b.setTopP(1).build()) | ||||
|                 .withOptions(b.setMinP(1).build()) | ||||
|                 .withOptions(b.setCustomOption("cust_float", 1.0f).build()) | ||||
|                 .withOptions(b.setCustomOption("cust_int", 1).build()) | ||||
|                 .withOptions(b.setCustomOption("cust_str", "custom").build()) | ||||
|                 .build(); | ||||
|  | ||||
|         String jsonRequest = serialize(req); | ||||
|         OllamaChatRequest deserializeRequest = deserialize(jsonRequest, OllamaChatRequest.class); | ||||
| @@ -87,9 +87,9 @@ public class TestChatRequestSerialization extends AbstractSerializationTest<Olla | ||||
|     public void testRequestWithInvalidCustomOption() { | ||||
|         OptionsBuilder b = new OptionsBuilder(); | ||||
|         assertThrowsExactly(IllegalArgumentException.class, () -> { | ||||
|                 OllamaChatRequest req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt") | ||||
|                 .withOptions(b.setCustomOption("cust_obj", new Object()).build()) | ||||
|                 .build(); | ||||
|             OllamaChatRequest req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt") | ||||
|                     .withOptions(b.setCustomOption("cust_obj", new Object()).build()) | ||||
|                     .build(); | ||||
|         }); | ||||
|     } | ||||
|  | ||||
| @@ -109,7 +109,7 @@ public class TestChatRequestSerialization extends AbstractSerializationTest<Olla | ||||
|     @Test | ||||
|     public void testWithTemplate() { | ||||
|         OllamaChatRequest req = builder.withTemplate("System Template") | ||||
|             .build(); | ||||
|                 .build(); | ||||
|         String jsonRequest = serialize(req); | ||||
|         assertEqualsAfterUnmarshalling(deserialize(jsonRequest, OllamaChatRequest.class), req); | ||||
|     } | ||||
| @@ -125,7 +125,7 @@ public class TestChatRequestSerialization extends AbstractSerializationTest<Olla | ||||
|     public void testWithKeepAlive() { | ||||
|         String expectedKeepAlive = "5m"; | ||||
|         OllamaChatRequest req = builder.withKeepAlive(expectedKeepAlive) | ||||
|             .build(); | ||||
|                 .build(); | ||||
|         String jsonRequest = serialize(req); | ||||
|         assertEquals(deserialize(jsonRequest, OllamaChatRequest.class).getKeepAlive(), expectedKeepAlive); | ||||
|     } | ||||
|   | ||||
| @@ -10,29 +10,29 @@ import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
|  | ||||
| public class TestEmbedRequestSerialization extends AbstractSerializationTest<OllamaEmbedRequestModel> { | ||||
|  | ||||
|         private OllamaEmbedRequestBuilder builder; | ||||
|     private OllamaEmbedRequestBuilder builder; | ||||
|  | ||||
|         @BeforeEach | ||||
|         public void init() { | ||||
|             builder = OllamaEmbedRequestBuilder.getInstance("DummyModel","DummyPrompt"); | ||||
|         } | ||||
|     @BeforeEach | ||||
|     public void init() { | ||||
|         builder = OllamaEmbedRequestBuilder.getInstance("DummyModel", "DummyPrompt"); | ||||
|     } | ||||
|  | ||||
|             @Test | ||||
|     @Test | ||||
|     public void testRequestOnlyMandatoryFields() { | ||||
|         OllamaEmbedRequestModel req = builder.build(); | ||||
|         String jsonRequest = serialize(req); | ||||
|         assertEqualsAfterUnmarshalling(deserialize(jsonRequest,OllamaEmbedRequestModel.class), req); | ||||
|         assertEqualsAfterUnmarshalling(deserialize(jsonRequest, OllamaEmbedRequestModel.class), req); | ||||
|     } | ||||
|  | ||||
|         @Test | ||||
|         public void testRequestWithOptions() { | ||||
|             OptionsBuilder b = new OptionsBuilder(); | ||||
|             OllamaEmbedRequestModel req = builder | ||||
|                     .withOptions(b.setMirostat(1).build()).build(); | ||||
|     @Test | ||||
|     public void testRequestWithOptions() { | ||||
|         OptionsBuilder b = new OptionsBuilder(); | ||||
|         OllamaEmbedRequestModel req = builder | ||||
|                 .withOptions(b.setMirostat(1).build()).build(); | ||||
|  | ||||
|             String jsonRequest = serialize(req); | ||||
|             OllamaEmbedRequestModel deserializeRequest = deserialize(jsonRequest,OllamaEmbedRequestModel.class); | ||||
|             assertEqualsAfterUnmarshalling(deserializeRequest, req); | ||||
|             assertEquals(1, deserializeRequest.getOptions().get("mirostat")); | ||||
|         } | ||||
|         String jsonRequest = serialize(req); | ||||
|         OllamaEmbedRequestModel deserializeRequest = deserialize(jsonRequest, OllamaEmbedRequestModel.class); | ||||
|         assertEqualsAfterUnmarshalling(deserializeRequest, req); | ||||
|         assertEquals(1, deserializeRequest.getOptions().get("mirostat")); | ||||
|     } | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 amithkoujalgi
					amithkoujalgi