Adds first approach to annotation based tool callings using basic java reflection

This commit is contained in:
Markus Klenke
2024-12-27 22:20:34 +01:00
parent 35f5f34196
commit 5e6971cc4a
7 changed files with 304 additions and 4 deletions

View File

@@ -7,8 +7,11 @@ import io.github.ollama4j.models.response.ModelDetail;
import io.github.ollama4j.models.response.OllamaResult;
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestBuilder;
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
import io.github.ollama4j.samples.AnnotatedTool;
import io.github.ollama4j.tools.OllamaToolCallsFunction;
import io.github.ollama4j.tools.ToolFunction;
import io.github.ollama4j.tools.Tools;
import io.github.ollama4j.tools.annotations.OllamaToolService;
import io.github.ollama4j.utils.OptionsBuilder;
import lombok.Data;
import org.junit.jupiter.api.BeforeEach;
@@ -27,6 +30,8 @@ import java.util.*;
import static org.junit.jupiter.api.Assertions.*;
@OllamaToolService(providers = {AnnotatedTool.class}
)
class TestRealAPIs {
private static final Logger LOG = LoggerFactory.getLogger(TestRealAPIs.class);
@@ -229,7 +234,7 @@ class TestRealAPIs {
@Test
@Order(3)
void testChatWithTools() {
void testChatWithExplicitToolDefinition() {
testEndpointReachability();
try {
ollamaAPI.setVerbose(true);
@@ -275,9 +280,10 @@ class TestRealAPIs {
assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(),chatResult.getResponseModel().getMessage().getRole().getRoleName());
List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls();
assertEquals(1, toolCalls.size());
assertEquals("get-employee-details",toolCalls.get(0).getFunction().getName());
assertEquals(1, toolCalls.get(0).getFunction().getArguments().size());
Object employeeName = toolCalls.get(0).getFunction().getArguments().get("employee-name");
OllamaToolCallsFunction function = toolCalls.get(0).getFunction();
assertEquals("get-employee-details", function.getName());
assertEquals(1, function.getArguments().size());
Object employeeName = function.getArguments().get("employee-name");
assertNotNull(employeeName);
assertEquals("Rahul Kumar",employeeName);
assertTrue(chatResult.getChatHistory().size()>2);
@@ -288,6 +294,82 @@ class TestRealAPIs {
}
}
@Test
@Order(3)
void testChatWithAnnotatedToolsAndSingleParam() {
testEndpointReachability();
try {
ollamaAPI.setVerbose(true);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
ollamaAPI.registerAnnotatedTools();
OllamaChatRequest requestModel = builder
.withMessage(OllamaChatMessageRole.USER,
"Compute the most important constant in the world using 5 digits")
.build();
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
assertNotNull(chatResult);
assertNotNull(chatResult.getResponseModel());
assertNotNull(chatResult.getResponseModel().getMessage());
assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(),chatResult.getResponseModel().getMessage().getRole().getRoleName());
List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls();
assertEquals(1, toolCalls.size());
OllamaToolCallsFunction function = toolCalls.get(0).getFunction();
assertEquals("computeMkeConstant", function.getName());
assertEquals(1, function.getArguments().size());
Object noOfDigits = function.getArguments().get("noOfDigits");
assertNotNull(noOfDigits);
assertEquals("5",noOfDigits);
assertTrue(chatResult.getChatHistory().size()>2);
List<OllamaChatToolCalls> finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls();
assertNull(finalToolCalls);
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
}
}
@Test
@Order(3)
void testChatWithAnnotatedToolsAndMultipleParams() {
testEndpointReachability();
try {
ollamaAPI.setVerbose(true);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
ollamaAPI.registerAnnotatedTools();
OllamaChatRequest requestModel = builder
.withMessage(OllamaChatMessageRole.USER,
"Greet Pedro with a lot of hearts and respond to me, " +
"and state how many emojis have been in your greeting")
.build();
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
assertNotNull(chatResult);
assertNotNull(chatResult.getResponseModel());
assertNotNull(chatResult.getResponseModel().getMessage());
assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(),chatResult.getResponseModel().getMessage().getRole().getRoleName());
List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls();
assertEquals(1, toolCalls.size());
OllamaToolCallsFunction function = toolCalls.get(0).getFunction();
assertEquals("sayHello", function.getName());
assertEquals(2, function.getArguments().size());
Object name = function.getArguments().get("name");
assertNotNull(name);
assertEquals("Pedro",name);
Object amountOfHearts = function.getArguments().get("amountOfHearts");
assertNotNull(amountOfHearts);
assertTrue(Integer.parseInt(amountOfHearts.toString()) > 1);
assertTrue(chatResult.getChatHistory().size()>2);
List<OllamaChatToolCalls> finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls();
assertNull(finalToolCalls);
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
}
}
@Test
@Order(3)
void testChatWithToolsAndStream() {

View File

@@ -0,0 +1,21 @@
package io.github.ollama4j.samples;
import io.github.ollama4j.tools.annotations.ToolProperty;
import io.github.ollama4j.tools.annotations.ToolSpec;
import java.math.BigDecimal;
public class AnnotatedTool {
@ToolSpec(desc = "Computes the most important constant all around the globe!")
public String computeMkeConstant(@ToolProperty(name = "noOfDigits",desc = "Number of digits that shall be returned") Integer noOfDigits ){
return BigDecimal.valueOf((long)(Math.random()*1000000L),noOfDigits).toString();
}
@ToolSpec(desc = "Says hello to a friend!")
public String sayHello(@ToolProperty(name = "name",desc = "Name of the friend") String name, Integer someRandomProperty, @ToolProperty(name="amountOfHearts",desc = "amount of heart emojis that should be used", required = false) Integer amountOfHearts) {
String hearts = amountOfHearts!=null ? "".repeat(amountOfHearts) : "";
return "Hello " + name +" ("+someRandomProperty+") " + hearts;
}
}