mirror of
https://github.com/amithkoujalgi/ollama4j.git
synced 2025-09-16 03:39:05 +02:00
Add unit tests for annotations, serialization, and tool functions
- Introduced TestAnnotations.java to validate OllamaToolService and ToolProperty annotations. - Added TestBooleanToJsonFormatFlagSerializer.java to test serialization of Boolean values. - Created TestFileToBase64Serializer.java for testing byte array serialization to Base64. - Implemented TestOllamaChatMessage.java to ensure correct JSON representation of chat messages. - Developed TestOllamaChatMessageRole.java to verify role registration and custom role creation. - Added TestOllamaChatRequestBuilder.java to test message handling and request building. - Created TestOllamaRequestBody.java to validate request body serialization. - Implemented TestOllamaToolsResult.java to ensure correct transformation of tool results. - Added TestOptionsAndUtils.java to test options builder and utility functions. - Created TestReflectionalToolFunction.java to validate method invocation and argument handling. - Implemented TestToolRegistry.java to ensure tool registration and retrieval functionality. - Developed TestToolsPromptBuilder.java to verify prompt builder includes tools and prompts correctly. - Added serialization tests in TestChatRequestSerialization.java and TestEmbedRequestSerialization.java.
This commit is contained in:
parent
fddd753a48
commit
9036d9e7c6
@ -0,0 +1,50 @@
|
|||||||
|
package io.github.ollama4j.unittests;
|
||||||
|
|
||||||
|
import io.github.ollama4j.tools.annotations.OllamaToolService;
|
||||||
|
import io.github.ollama4j.tools.annotations.ToolProperty;
|
||||||
|
import io.github.ollama4j.tools.annotations.ToolSpec;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import java.lang.reflect.Method;
|
||||||
|
import java.lang.reflect.Parameter;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
|
class TestAnnotations {
|
||||||
|
|
||||||
|
@OllamaToolService(providers = {SampleProvider.class})
|
||||||
|
static class SampleToolService {
|
||||||
|
}
|
||||||
|
|
||||||
|
static class SampleProvider {
|
||||||
|
@ToolSpec(name = "sum", desc = "adds two numbers")
|
||||||
|
public int sum(@ToolProperty(name = "a", desc = "first addend") int a,
|
||||||
|
@ToolProperty(name = "b", desc = "second addend", required = false) int b) {
|
||||||
|
return a + b;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testOllamaToolServiceProvidersPresent() throws Exception {
|
||||||
|
OllamaToolService ann = SampleToolService.class.getAnnotation(OllamaToolService.class);
|
||||||
|
assertNotNull(ann);
|
||||||
|
assertArrayEquals(new Class<?>[]{SampleProvider.class}, ann.providers());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testToolPropertyMetadataOnParameters() throws Exception {
|
||||||
|
Method m = SampleProvider.class.getDeclaredMethod("sum", int.class, int.class);
|
||||||
|
Parameter[] params = m.getParameters();
|
||||||
|
ToolProperty p0 = params[0].getAnnotation(ToolProperty.class);
|
||||||
|
ToolProperty p1 = params[1].getAnnotation(ToolProperty.class);
|
||||||
|
assertNotNull(p0);
|
||||||
|
assertEquals("a", p0.name());
|
||||||
|
assertEquals("first addend", p0.desc());
|
||||||
|
assertTrue(p0.required());
|
||||||
|
|
||||||
|
assertNotNull(p1);
|
||||||
|
assertEquals("b", p1.name());
|
||||||
|
assertEquals("second addend", p1.desc());
|
||||||
|
assertFalse(p1.required());
|
||||||
|
}
|
||||||
|
}
|
@ -6,10 +6,10 @@ import org.junit.jupiter.api.Test;
|
|||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
public class TestAuth {
|
class TestAuth {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBasicAuthHeaderEncoding() {
|
void testBasicAuthHeaderEncoding() {
|
||||||
BasicAuth auth = new BasicAuth("alice", "s3cr3t");
|
BasicAuth auth = new BasicAuth("alice", "s3cr3t");
|
||||||
String header = auth.getAuthHeaderValue();
|
String header = auth.getAuthHeaderValue();
|
||||||
assertTrue(header.startsWith("Basic "));
|
assertTrue(header.startsWith("Basic "));
|
||||||
@ -18,7 +18,7 @@ public class TestAuth {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBearerAuthHeaderFormat() {
|
void testBearerAuthHeaderFormat() {
|
||||||
BearerAuth auth = new BearerAuth("abc.def.ghi");
|
BearerAuth auth = new BearerAuth("abc.def.ghi");
|
||||||
String header = auth.getAuthHeaderValue();
|
String header = auth.getAuthHeaderValue();
|
||||||
assertEquals("Bearer abc.def.ghi", header);
|
assertEquals("Bearer abc.def.ghi", header);
|
||||||
|
@ -0,0 +1,43 @@
|
|||||||
|
package io.github.ollama4j.unittests;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||||
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
|
||||||
|
import io.github.ollama4j.utils.BooleanToJsonFormatFlagSerializer;
|
||||||
|
import io.github.ollama4j.utils.Utils;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
|
class TestBooleanToJsonFormatFlagSerializer {
|
||||||
|
|
||||||
|
static class Holder {
|
||||||
|
@JsonSerialize(using = BooleanToJsonFormatFlagSerializer.class)
|
||||||
|
public Boolean formatJson;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testSerializeTrueWritesJsonString() throws JsonProcessingException {
|
||||||
|
ObjectMapper mapper = Utils.getObjectMapper().copy();
|
||||||
|
mapper.setSerializationInclusion(JsonInclude.Include.NON_EMPTY);
|
||||||
|
|
||||||
|
Holder holder = new Holder();
|
||||||
|
holder.formatJson = true;
|
||||||
|
|
||||||
|
String json = mapper.writeValueAsString(holder);
|
||||||
|
assertEquals("{\"formatJson\":\"json\"}", json);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testSerializeFalseOmittedByIsEmpty() throws JsonProcessingException {
|
||||||
|
ObjectMapper mapper = Utils.getObjectMapper().copy();
|
||||||
|
mapper.setSerializationInclusion(JsonInclude.Include.NON_EMPTY);
|
||||||
|
|
||||||
|
Holder holder = new Holder();
|
||||||
|
holder.formatJson = false;
|
||||||
|
|
||||||
|
String json = mapper.writeValueAsString(holder);
|
||||||
|
assertEquals("{}", json);
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,32 @@
|
|||||||
|
package io.github.ollama4j.unittests;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
|
||||||
|
import io.github.ollama4j.utils.FileToBase64Serializer;
|
||||||
|
import io.github.ollama4j.utils.Utils;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
|
public class TestFileToBase64Serializer {
|
||||||
|
|
||||||
|
static class Holder {
|
||||||
|
@JsonSerialize(using = FileToBase64Serializer.class)
|
||||||
|
public List<byte[]> images;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testSerializeByteArraysToBase64Array() throws JsonProcessingException {
|
||||||
|
ObjectMapper mapper = Utils.getObjectMapper();
|
||||||
|
|
||||||
|
Holder holder = new Holder();
|
||||||
|
holder.images = List.of("hello".getBytes(), "world".getBytes());
|
||||||
|
|
||||||
|
String json = mapper.writeValueAsString(holder);
|
||||||
|
// Base64 of "hello" = aGVsbG8=, of "world" = d29ybGQ=
|
||||||
|
assertEquals("{\"images\":[\"aGVsbG8=\",\"d29ybGQ=\"]}", json);
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,22 @@
|
|||||||
|
package io.github.ollama4j.unittests;
|
||||||
|
|
||||||
|
import io.github.ollama4j.models.chat.OllamaChatMessage;
|
||||||
|
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
|
||||||
|
import org.json.JSONObject;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
|
class TestOllamaChatMessage {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testToStringProducesJson() {
|
||||||
|
OllamaChatMessage msg = new OllamaChatMessage(OllamaChatMessageRole.USER, "hello", null, null, null);
|
||||||
|
String json = msg.toString();
|
||||||
|
JSONObject obj = new JSONObject(json);
|
||||||
|
assertEquals("user", obj.getString("role"));
|
||||||
|
assertEquals("hello", obj.getString("content"));
|
||||||
|
assertTrue(obj.has("tool_calls"));
|
||||||
|
// thinking and images may or may not be present depending on null handling, just ensure no exception
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,44 @@
|
|||||||
|
package io.github.ollama4j.unittests;
|
||||||
|
|
||||||
|
import io.github.ollama4j.exceptions.RoleNotFoundException;
|
||||||
|
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
|
class TestOllamaChatMessageRole {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testStaticRolesRegistered() throws Exception {
|
||||||
|
List<OllamaChatMessageRole> roles = OllamaChatMessageRole.getRoles();
|
||||||
|
assertTrue(roles.contains(OllamaChatMessageRole.SYSTEM));
|
||||||
|
assertTrue(roles.contains(OllamaChatMessageRole.USER));
|
||||||
|
assertTrue(roles.contains(OllamaChatMessageRole.ASSISTANT));
|
||||||
|
assertTrue(roles.contains(OllamaChatMessageRole.TOOL));
|
||||||
|
|
||||||
|
assertEquals("system", OllamaChatMessageRole.SYSTEM.toString());
|
||||||
|
assertEquals("user", OllamaChatMessageRole.USER.toString());
|
||||||
|
assertEquals("assistant", OllamaChatMessageRole.ASSISTANT.toString());
|
||||||
|
assertEquals("tool", OllamaChatMessageRole.TOOL.toString());
|
||||||
|
|
||||||
|
assertSame(OllamaChatMessageRole.SYSTEM, OllamaChatMessageRole.getRole("system"));
|
||||||
|
assertSame(OllamaChatMessageRole.USER, OllamaChatMessageRole.getRole("user"));
|
||||||
|
assertSame(OllamaChatMessageRole.ASSISTANT, OllamaChatMessageRole.getRole("assistant"));
|
||||||
|
assertSame(OllamaChatMessageRole.TOOL, OllamaChatMessageRole.getRole("tool"));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testCustomRoleCreationAndLookup() throws Exception {
|
||||||
|
OllamaChatMessageRole custom = OllamaChatMessageRole.newCustomRole("myrole");
|
||||||
|
assertEquals("myrole", custom.toString());
|
||||||
|
// custom roles are registered globally (per current implementation), so lookup should succeed
|
||||||
|
assertSame(custom, OllamaChatMessageRole.getRole("myrole"));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testGetRoleThrowsOnUnknown() {
|
||||||
|
assertThrows(RoleNotFoundException.class, () -> OllamaChatMessageRole.getRole("does-not-exist"));
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,49 @@
|
|||||||
|
package io.github.ollama4j.unittests;
|
||||||
|
|
||||||
|
import io.github.ollama4j.models.chat.OllamaChatMessage;
|
||||||
|
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
|
||||||
|
import io.github.ollama4j.models.chat.OllamaChatRequest;
|
||||||
|
import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
|
class TestOllamaChatRequestBuilder {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testResetClearsMessagesButKeepsModelAndThink() {
|
||||||
|
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance("my-model")
|
||||||
|
.withThinking(true)
|
||||||
|
.withMessage(OllamaChatMessageRole.USER, "first");
|
||||||
|
|
||||||
|
OllamaChatRequest beforeReset = builder.build();
|
||||||
|
assertEquals("my-model", beforeReset.getModel());
|
||||||
|
assertTrue(beforeReset.isThink());
|
||||||
|
assertEquals(1, beforeReset.getMessages().size());
|
||||||
|
|
||||||
|
builder.reset();
|
||||||
|
OllamaChatRequest afterReset = builder.build();
|
||||||
|
assertEquals("my-model", afterReset.getModel());
|
||||||
|
assertTrue(afterReset.isThink());
|
||||||
|
assertNotNull(afterReset.getMessages());
|
||||||
|
assertEquals(0, afterReset.getMessages().size());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testImageUrlFailuresAreIgnoredAndDoNotBreakBuild() {
|
||||||
|
// Provide clearly invalid URL, builder logs a warning and continues
|
||||||
|
OllamaChatRequest req = OllamaChatRequestBuilder.getInstance("m")
|
||||||
|
.withMessage(OllamaChatMessageRole.USER, "hi", Collections.emptyList(),
|
||||||
|
"ht!tp://invalid url \n not a uri")
|
||||||
|
.build();
|
||||||
|
|
||||||
|
assertNotNull(req.getMessages());
|
||||||
|
assertEquals(1, req.getMessages().size());
|
||||||
|
OllamaChatMessage msg = req.getMessages().get(0);
|
||||||
|
// images list will be initialized only if any valid URL was added; for invalid URL list can be null
|
||||||
|
// We just assert that builder didn't crash and message is present with content
|
||||||
|
assertEquals("hi", msg.getContent());
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,58 @@
|
|||||||
|
package io.github.ollama4j.unittests;
|
||||||
|
|
||||||
|
import io.github.ollama4j.utils.OllamaRequestBody;
|
||||||
|
import io.github.ollama4j.utils.Utils;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
|
import java.nio.charset.StandardCharsets;
|
||||||
|
import java.util.concurrent.Flow;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
|
class TestOllamaRequestBody {
|
||||||
|
|
||||||
|
static class SimpleRequest implements OllamaRequestBody {
|
||||||
|
public String name;
|
||||||
|
public int value;
|
||||||
|
|
||||||
|
SimpleRequest(String name, int value) {
|
||||||
|
this.name = name;
|
||||||
|
this.value = value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testGetBodyPublisherProducesSerializedJson() throws IOException {
|
||||||
|
SimpleRequest req = new SimpleRequest("abc", 123);
|
||||||
|
|
||||||
|
var publisher = req.getBodyPublisher();
|
||||||
|
|
||||||
|
StringBuilder data = new StringBuilder();
|
||||||
|
publisher.subscribe(new Flow.Subscriber<>() {
|
||||||
|
@Override
|
||||||
|
public void onSubscribe(Flow.Subscription subscription) {
|
||||||
|
subscription.request(Long.MAX_VALUE);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onNext(ByteBuffer item) {
|
||||||
|
data.append(StandardCharsets.UTF_8.decode(item));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onError(Throwable throwable) {
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onComplete() {
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Trigger the publishing by converting it to a string via the same mapper for determinism
|
||||||
|
String expected = Utils.getObjectMapper().writeValueAsString(req);
|
||||||
|
// Due to asynchronous nature, expected content already delivered synchronously by StringPublisher
|
||||||
|
assertEquals(expected, data.toString());
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,46 @@
|
|||||||
|
package io.github.ollama4j.unittests;
|
||||||
|
|
||||||
|
import io.github.ollama4j.models.response.OllamaResult;
|
||||||
|
import io.github.ollama4j.tools.OllamaToolsResult;
|
||||||
|
import io.github.ollama4j.tools.ToolFunctionCallSpec;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import java.util.LinkedHashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
|
public class TestOllamaToolsResult {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testGetToolResultsTransformsMapToList() {
|
||||||
|
ToolFunctionCallSpec spec1 = new ToolFunctionCallSpec("fn1", Map.of("a", 1));
|
||||||
|
ToolFunctionCallSpec spec2 = new ToolFunctionCallSpec("fn2", Map.of("b", 2));
|
||||||
|
|
||||||
|
Map<ToolFunctionCallSpec, Object> toolMap = new LinkedHashMap<>();
|
||||||
|
toolMap.put(spec1, "r1");
|
||||||
|
toolMap.put(spec2, 123);
|
||||||
|
|
||||||
|
OllamaToolsResult tr = new OllamaToolsResult(new OllamaResult("", null, 0L, 200), toolMap);
|
||||||
|
|
||||||
|
List<OllamaToolsResult.ToolResult> list = tr.getToolResults();
|
||||||
|
assertEquals(2, list.size());
|
||||||
|
assertEquals("fn1", list.get(0).getFunctionName());
|
||||||
|
assertEquals(Map.of("a", 1), list.get(0).getFunctionArguments());
|
||||||
|
assertEquals("r1", list.get(0).getResult());
|
||||||
|
|
||||||
|
assertEquals("fn2", list.get(1).getFunctionName());
|
||||||
|
assertEquals(Map.of("b", 2), list.get(1).getFunctionArguments());
|
||||||
|
assertEquals(123, list.get(1).getResult());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testGetToolResultsReturnsEmptyListWhenNull() {
|
||||||
|
OllamaToolsResult tr = new OllamaToolsResult();
|
||||||
|
tr.setToolResults(null);
|
||||||
|
List<OllamaToolsResult.ToolResult> list = tr.getToolResults();
|
||||||
|
assertNotNull(list);
|
||||||
|
assertTrue(list.isEmpty());
|
||||||
|
}
|
||||||
|
}
|
@ -11,10 +11,10 @@ import java.util.Map;
|
|||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
public class TestOptionsAndUtils {
|
class TestOptionsAndUtils {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testOptionsBuilderSetsValues() {
|
void testOptionsBuilderSetsValues() {
|
||||||
Options options = new OptionsBuilder()
|
Options options = new OptionsBuilder()
|
||||||
.setMirostat(1)
|
.setMirostat(1)
|
||||||
.setMirostatEta(0.2f)
|
.setMirostatEta(0.2f)
|
||||||
@ -58,13 +58,13 @@ public class TestOptionsAndUtils {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testOptionsBuilderRejectsUnsupportedCustomType() {
|
void testOptionsBuilderRejectsUnsupportedCustomType() {
|
||||||
OptionsBuilder builder = new OptionsBuilder();
|
OptionsBuilder builder = new OptionsBuilder();
|
||||||
assertThrows(IllegalArgumentException.class, () -> builder.setCustomOption("bad", new Object()));
|
assertThrows(IllegalArgumentException.class, () -> builder.setCustomOption("bad", new Object()));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testPromptBuilderBuildsExpectedString() {
|
void testPromptBuilderBuildsExpectedString() {
|
||||||
String prompt = new PromptBuilder()
|
String prompt = new PromptBuilder()
|
||||||
.add("Hello")
|
.add("Hello")
|
||||||
.addLine(", world!")
|
.addLine(", world!")
|
||||||
@ -77,14 +77,14 @@ public class TestOptionsAndUtils {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testUtilsGetObjectMapperSingletonAndModule() {
|
void testUtilsGetObjectMapperSingletonAndModule() {
|
||||||
assertSame(Utils.getObjectMapper(), Utils.getObjectMapper());
|
assertSame(Utils.getObjectMapper(), Utils.getObjectMapper());
|
||||||
// Basic serialization sanity check with JavaTimeModule registered
|
// Basic serialization sanity check with JavaTimeModule registered
|
||||||
assertDoesNotThrow(() -> Utils.getObjectMapper().writeValueAsString(java.time.OffsetDateTime.now()));
|
assertDoesNotThrow(() -> Utils.getObjectMapper().writeValueAsString(java.time.OffsetDateTime.now()));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testGetFileFromClasspath() {
|
void testGetFileFromClasspath() {
|
||||||
File f = Utils.getFileFromClasspath("test-config.properties");
|
File f = Utils.getFileFromClasspath("test-config.properties");
|
||||||
assertTrue(f.exists());
|
assertTrue(f.exists());
|
||||||
assertTrue(f.getName().contains("test-config.properties"));
|
assertTrue(f.getName().contains("test-config.properties"));
|
||||||
|
@ -0,0 +1,86 @@
|
|||||||
|
package io.github.ollama4j.unittests;
|
||||||
|
|
||||||
|
import io.github.ollama4j.tools.ReflectionalToolFunction;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import java.lang.reflect.Method;
|
||||||
|
import java.math.BigDecimal;
|
||||||
|
import java.util.LinkedHashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
|
class TestReflectionalToolFunction {
|
||||||
|
|
||||||
|
public static class SampleToolHolder {
|
||||||
|
public String combine(Integer i, Boolean b, BigDecimal d, String s) {
|
||||||
|
return String.format("i=%s,b=%s,d=%s,s=%s", i, b, d, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void alwaysThrows() {
|
||||||
|
throw new IllegalStateException("boom");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testApplyInvokesMethodWithTypeCasting() throws Exception {
|
||||||
|
SampleToolHolder holder = new SampleToolHolder();
|
||||||
|
Method method = SampleToolHolder.class.getMethod("combine", Integer.class, Boolean.class, BigDecimal.class, String.class);
|
||||||
|
|
||||||
|
LinkedHashMap<String, String> propDef = new LinkedHashMap<>();
|
||||||
|
// preserve order to match method parameters
|
||||||
|
propDef.put("i", "java.lang.Integer");
|
||||||
|
propDef.put("b", "java.lang.Boolean");
|
||||||
|
propDef.put("d", "java.math.BigDecimal");
|
||||||
|
propDef.put("s", "java.lang.String");
|
||||||
|
|
||||||
|
ReflectionalToolFunction fn = new ReflectionalToolFunction(holder, method, propDef);
|
||||||
|
|
||||||
|
Map<String, Object> args = Map.of(
|
||||||
|
"i", "42",
|
||||||
|
"b", "true",
|
||||||
|
"d", "3.14",
|
||||||
|
"s", 123 // not a string; should be toString()'d by implementation
|
||||||
|
);
|
||||||
|
|
||||||
|
Object result = fn.apply(args);
|
||||||
|
assertEquals("i=42,b=true,d=3.14,s=123", result);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testTypeCastNullsWhenClassOrValueIsNull() throws Exception {
|
||||||
|
SampleToolHolder holder = new SampleToolHolder();
|
||||||
|
Method method = SampleToolHolder.class.getMethod("combine", Integer.class, Boolean.class, BigDecimal.class, String.class);
|
||||||
|
|
||||||
|
LinkedHashMap<String, String> propDef = new LinkedHashMap<>();
|
||||||
|
propDef.put("i", null); // className null -> expect null passed
|
||||||
|
propDef.put("b", "java.lang.Boolean");
|
||||||
|
propDef.put("d", "java.math.BigDecimal");
|
||||||
|
propDef.put("s", "java.lang.String");
|
||||||
|
|
||||||
|
ReflectionalToolFunction fn = new ReflectionalToolFunction(holder, method, propDef);
|
||||||
|
|
||||||
|
Map<String, Object> args = new LinkedHashMap<>();
|
||||||
|
args.put("i", "100"); // ignored -> becomes null due to null className
|
||||||
|
args.put("b", null); // value null -> expect null passed
|
||||||
|
args.put("d", "1.00");
|
||||||
|
args.put("s", "ok");
|
||||||
|
|
||||||
|
Object result = fn.apply(args);
|
||||||
|
assertEquals("i=null,b=null,d=1.00,s=ok", result);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testExceptionsAreWrappedWithMeaningfulMessage() throws Exception {
|
||||||
|
SampleToolHolder holder = new SampleToolHolder();
|
||||||
|
Method throwsMethod = SampleToolHolder.class.getMethod("alwaysThrows");
|
||||||
|
|
||||||
|
LinkedHashMap<String, String> propDef = new LinkedHashMap<>();
|
||||||
|
|
||||||
|
ReflectionalToolFunction fn = new ReflectionalToolFunction(holder, throwsMethod, propDef);
|
||||||
|
|
||||||
|
RuntimeException ex = assertThrows(RuntimeException.class, () -> fn.apply(Map.of()));
|
||||||
|
assertTrue(ex.getMessage().contains("Failed to invoke tool: alwaysThrows"));
|
||||||
|
assertNotNull(ex.getCause());
|
||||||
|
}
|
||||||
|
}
|
@ -9,10 +9,10 @@ import java.util.Map;
|
|||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
public class TestToolRegistry {
|
class TestToolRegistry {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testAddAndGetToolFunction() {
|
void testAddAndGetToolFunction() {
|
||||||
ToolRegistry registry = new ToolRegistry();
|
ToolRegistry registry = new ToolRegistry();
|
||||||
ToolFunction fn = args -> "ok:" + args.get("x");
|
ToolFunction fn = args -> "ok:" + args.get("x");
|
||||||
|
|
||||||
@ -29,13 +29,13 @@ public class TestToolRegistry {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testGetUnknownReturnsNull() {
|
void testGetUnknownReturnsNull() {
|
||||||
ToolRegistry registry = new ToolRegistry();
|
ToolRegistry registry = new ToolRegistry();
|
||||||
assertNull(registry.getToolFunction("nope"));
|
assertNull(registry.getToolFunction("nope"));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testClearRemovesAll() {
|
void testClearRemovesAll() {
|
||||||
ToolRegistry registry = new ToolRegistry();
|
ToolRegistry registry = new ToolRegistry();
|
||||||
registry.addTool("a", Tools.ToolSpecification.builder().toolFunction(args -> 1).build());
|
registry.addTool("a", Tools.ToolSpecification.builder().toolFunction(args -> 1).build());
|
||||||
registry.addTool("b", Tools.ToolSpecification.builder().toolFunction(args -> 2).build());
|
registry.addTool("b", Tools.ToolSpecification.builder().toolFunction(args -> 2).build());
|
||||||
|
@ -9,10 +9,10 @@ import java.util.Map;
|
|||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
public class TestToolsPromptBuilder {
|
class TestToolsPromptBuilder {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testPromptBuilderIncludesToolsAndPrompt() throws JsonProcessingException {
|
void testPromptBuilderIncludesToolsAndPrompt() throws JsonProcessingException {
|
||||||
Tools.PromptFuncDefinition.Property cityProp = Tools.PromptFuncDefinition.Property.builder()
|
Tools.PromptFuncDefinition.Property cityProp = Tools.PromptFuncDefinition.Property.builder()
|
||||||
.type("string")
|
.type("string")
|
||||||
.description("city name")
|
.description("city name")
|
||||||
|
@ -30,7 +30,7 @@ public abstract class AbstractSerializationTest<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
protected void assertEqualsAfterUnmarshalling(T unmarshalledObject,
|
protected void assertEqualsAfterUnmarshalling(T unmarshalledObject,
|
||||||
T req) {
|
T req) {
|
||||||
assertEquals(req, unmarshalledObject);
|
assertEquals(req, unmarshalledObject);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -34,8 +34,8 @@ public class TestChatRequestSerialization extends AbstractSerializationTest<Olla
|
|||||||
@Test
|
@Test
|
||||||
public void testRequestMultipleMessages() {
|
public void testRequestMultipleMessages() {
|
||||||
OllamaChatRequest req = builder.withMessage(OllamaChatMessageRole.SYSTEM, "System prompt")
|
OllamaChatRequest req = builder.withMessage(OllamaChatMessageRole.SYSTEM, "System prompt")
|
||||||
.withMessage(OllamaChatMessageRole.USER, "Some prompt")
|
.withMessage(OllamaChatMessageRole.USER, "Some prompt")
|
||||||
.build();
|
.build();
|
||||||
String jsonRequest = serialize(req);
|
String jsonRequest = serialize(req);
|
||||||
assertEqualsAfterUnmarshalling(deserialize(jsonRequest, OllamaChatRequest.class), req);
|
assertEqualsAfterUnmarshalling(deserialize(jsonRequest, OllamaChatRequest.class), req);
|
||||||
}
|
}
|
||||||
@ -52,19 +52,19 @@ public class TestChatRequestSerialization extends AbstractSerializationTest<Olla
|
|||||||
public void testRequestWithOptions() {
|
public void testRequestWithOptions() {
|
||||||
OptionsBuilder b = new OptionsBuilder();
|
OptionsBuilder b = new OptionsBuilder();
|
||||||
OllamaChatRequest req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt")
|
OllamaChatRequest req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt")
|
||||||
.withOptions(b.setMirostat(1).build())
|
.withOptions(b.setMirostat(1).build())
|
||||||
.withOptions(b.setTemperature(1L).build())
|
.withOptions(b.setTemperature(1L).build())
|
||||||
.withOptions(b.setMirostatEta(1L).build())
|
.withOptions(b.setMirostatEta(1L).build())
|
||||||
.withOptions(b.setMirostatTau(1L).build())
|
.withOptions(b.setMirostatTau(1L).build())
|
||||||
.withOptions(b.setNumGpu(1).build())
|
.withOptions(b.setNumGpu(1).build())
|
||||||
.withOptions(b.setSeed(1).build())
|
.withOptions(b.setSeed(1).build())
|
||||||
.withOptions(b.setTopK(1).build())
|
.withOptions(b.setTopK(1).build())
|
||||||
.withOptions(b.setTopP(1).build())
|
.withOptions(b.setTopP(1).build())
|
||||||
.withOptions(b.setMinP(1).build())
|
.withOptions(b.setMinP(1).build())
|
||||||
.withOptions(b.setCustomOption("cust_float", 1.0f).build())
|
.withOptions(b.setCustomOption("cust_float", 1.0f).build())
|
||||||
.withOptions(b.setCustomOption("cust_int", 1).build())
|
.withOptions(b.setCustomOption("cust_int", 1).build())
|
||||||
.withOptions(b.setCustomOption("cust_str", "custom").build())
|
.withOptions(b.setCustomOption("cust_str", "custom").build())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
String jsonRequest = serialize(req);
|
String jsonRequest = serialize(req);
|
||||||
OllamaChatRequest deserializeRequest = deserialize(jsonRequest, OllamaChatRequest.class);
|
OllamaChatRequest deserializeRequest = deserialize(jsonRequest, OllamaChatRequest.class);
|
||||||
@ -87,9 +87,9 @@ public class TestChatRequestSerialization extends AbstractSerializationTest<Olla
|
|||||||
public void testRequestWithInvalidCustomOption() {
|
public void testRequestWithInvalidCustomOption() {
|
||||||
OptionsBuilder b = new OptionsBuilder();
|
OptionsBuilder b = new OptionsBuilder();
|
||||||
assertThrowsExactly(IllegalArgumentException.class, () -> {
|
assertThrowsExactly(IllegalArgumentException.class, () -> {
|
||||||
OllamaChatRequest req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt")
|
OllamaChatRequest req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt")
|
||||||
.withOptions(b.setCustomOption("cust_obj", new Object()).build())
|
.withOptions(b.setCustomOption("cust_obj", new Object()).build())
|
||||||
.build();
|
.build();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -109,7 +109,7 @@ public class TestChatRequestSerialization extends AbstractSerializationTest<Olla
|
|||||||
@Test
|
@Test
|
||||||
public void testWithTemplate() {
|
public void testWithTemplate() {
|
||||||
OllamaChatRequest req = builder.withTemplate("System Template")
|
OllamaChatRequest req = builder.withTemplate("System Template")
|
||||||
.build();
|
.build();
|
||||||
String jsonRequest = serialize(req);
|
String jsonRequest = serialize(req);
|
||||||
assertEqualsAfterUnmarshalling(deserialize(jsonRequest, OllamaChatRequest.class), req);
|
assertEqualsAfterUnmarshalling(deserialize(jsonRequest, OllamaChatRequest.class), req);
|
||||||
}
|
}
|
||||||
@ -125,7 +125,7 @@ public class TestChatRequestSerialization extends AbstractSerializationTest<Olla
|
|||||||
public void testWithKeepAlive() {
|
public void testWithKeepAlive() {
|
||||||
String expectedKeepAlive = "5m";
|
String expectedKeepAlive = "5m";
|
||||||
OllamaChatRequest req = builder.withKeepAlive(expectedKeepAlive)
|
OllamaChatRequest req = builder.withKeepAlive(expectedKeepAlive)
|
||||||
.build();
|
.build();
|
||||||
String jsonRequest = serialize(req);
|
String jsonRequest = serialize(req);
|
||||||
assertEquals(deserialize(jsonRequest, OllamaChatRequest.class).getKeepAlive(), expectedKeepAlive);
|
assertEquals(deserialize(jsonRequest, OllamaChatRequest.class).getKeepAlive(), expectedKeepAlive);
|
||||||
}
|
}
|
||||||
|
@ -10,29 +10,29 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
|
|||||||
|
|
||||||
public class TestEmbedRequestSerialization extends AbstractSerializationTest<OllamaEmbedRequestModel> {
|
public class TestEmbedRequestSerialization extends AbstractSerializationTest<OllamaEmbedRequestModel> {
|
||||||
|
|
||||||
private OllamaEmbedRequestBuilder builder;
|
private OllamaEmbedRequestBuilder builder;
|
||||||
|
|
||||||
@BeforeEach
|
@BeforeEach
|
||||||
public void init() {
|
public void init() {
|
||||||
builder = OllamaEmbedRequestBuilder.getInstance("DummyModel","DummyPrompt");
|
builder = OllamaEmbedRequestBuilder.getInstance("DummyModel", "DummyPrompt");
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRequestOnlyMandatoryFields() {
|
public void testRequestOnlyMandatoryFields() {
|
||||||
OllamaEmbedRequestModel req = builder.build();
|
OllamaEmbedRequestModel req = builder.build();
|
||||||
String jsonRequest = serialize(req);
|
String jsonRequest = serialize(req);
|
||||||
assertEqualsAfterUnmarshalling(deserialize(jsonRequest,OllamaEmbedRequestModel.class), req);
|
assertEqualsAfterUnmarshalling(deserialize(jsonRequest, OllamaEmbedRequestModel.class), req);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRequestWithOptions() {
|
public void testRequestWithOptions() {
|
||||||
OptionsBuilder b = new OptionsBuilder();
|
OptionsBuilder b = new OptionsBuilder();
|
||||||
OllamaEmbedRequestModel req = builder
|
OllamaEmbedRequestModel req = builder
|
||||||
.withOptions(b.setMirostat(1).build()).build();
|
.withOptions(b.setMirostat(1).build()).build();
|
||||||
|
|
||||||
String jsonRequest = serialize(req);
|
String jsonRequest = serialize(req);
|
||||||
OllamaEmbedRequestModel deserializeRequest = deserialize(jsonRequest,OllamaEmbedRequestModel.class);
|
OllamaEmbedRequestModel deserializeRequest = deserialize(jsonRequest, OllamaEmbedRequestModel.class);
|
||||||
assertEqualsAfterUnmarshalling(deserializeRequest, req);
|
assertEqualsAfterUnmarshalling(deserializeRequest, req);
|
||||||
assertEquals(1, deserializeRequest.getOptions().get("mirostat"));
|
assertEquals(1, deserializeRequest.getOptions().get("mirostat"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user