mirror of
https://github.com/amithkoujalgi/ollama4j.git
synced 2025-11-04 18:40:40 +01:00
Adds first approach to annotation based tool callings using basic java reflection
This commit is contained in:
@@ -7,8 +7,11 @@ import io.github.ollama4j.models.response.ModelDetail;
|
||||
import io.github.ollama4j.models.response.OllamaResult;
|
||||
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestBuilder;
|
||||
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
|
||||
import io.github.ollama4j.samples.AnnotatedTool;
|
||||
import io.github.ollama4j.tools.OllamaToolCallsFunction;
|
||||
import io.github.ollama4j.tools.ToolFunction;
|
||||
import io.github.ollama4j.tools.Tools;
|
||||
import io.github.ollama4j.tools.annotations.OllamaToolService;
|
||||
import io.github.ollama4j.utils.OptionsBuilder;
|
||||
import lombok.Data;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
@@ -27,6 +30,8 @@ import java.util.*;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
@OllamaToolService(providers = {AnnotatedTool.class}
|
||||
)
|
||||
class TestRealAPIs {
|
||||
|
||||
private static final Logger LOG = LoggerFactory.getLogger(TestRealAPIs.class);
|
||||
@@ -229,7 +234,7 @@ class TestRealAPIs {
|
||||
|
||||
@Test
|
||||
@Order(3)
|
||||
void testChatWithTools() {
|
||||
void testChatWithExplicitToolDefinition() {
|
||||
testEndpointReachability();
|
||||
try {
|
||||
ollamaAPI.setVerbose(true);
|
||||
@@ -275,9 +280,10 @@ class TestRealAPIs {
|
||||
assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(),chatResult.getResponseModel().getMessage().getRole().getRoleName());
|
||||
List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls();
|
||||
assertEquals(1, toolCalls.size());
|
||||
assertEquals("get-employee-details",toolCalls.get(0).getFunction().getName());
|
||||
assertEquals(1, toolCalls.get(0).getFunction().getArguments().size());
|
||||
Object employeeName = toolCalls.get(0).getFunction().getArguments().get("employee-name");
|
||||
OllamaToolCallsFunction function = toolCalls.get(0).getFunction();
|
||||
assertEquals("get-employee-details", function.getName());
|
||||
assertEquals(1, function.getArguments().size());
|
||||
Object employeeName = function.getArguments().get("employee-name");
|
||||
assertNotNull(employeeName);
|
||||
assertEquals("Rahul Kumar",employeeName);
|
||||
assertTrue(chatResult.getChatHistory().size()>2);
|
||||
@@ -288,6 +294,82 @@ class TestRealAPIs {
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(3)
|
||||
void testChatWithAnnotatedToolsAndSingleParam() {
|
||||
testEndpointReachability();
|
||||
try {
|
||||
ollamaAPI.setVerbose(true);
|
||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
|
||||
|
||||
ollamaAPI.registerAnnotatedTools();
|
||||
|
||||
OllamaChatRequest requestModel = builder
|
||||
.withMessage(OllamaChatMessageRole.USER,
|
||||
"Compute the most important constant in the world using 5 digits")
|
||||
.build();
|
||||
|
||||
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
|
||||
assertNotNull(chatResult);
|
||||
assertNotNull(chatResult.getResponseModel());
|
||||
assertNotNull(chatResult.getResponseModel().getMessage());
|
||||
assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(),chatResult.getResponseModel().getMessage().getRole().getRoleName());
|
||||
List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls();
|
||||
assertEquals(1, toolCalls.size());
|
||||
OllamaToolCallsFunction function = toolCalls.get(0).getFunction();
|
||||
assertEquals("computeMkeConstant", function.getName());
|
||||
assertEquals(1, function.getArguments().size());
|
||||
Object noOfDigits = function.getArguments().get("noOfDigits");
|
||||
assertNotNull(noOfDigits);
|
||||
assertEquals("5",noOfDigits);
|
||||
assertTrue(chatResult.getChatHistory().size()>2);
|
||||
List<OllamaChatToolCalls> finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls();
|
||||
assertNull(finalToolCalls);
|
||||
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||
fail(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(3)
|
||||
void testChatWithAnnotatedToolsAndMultipleParams() {
|
||||
testEndpointReachability();
|
||||
try {
|
||||
ollamaAPI.setVerbose(true);
|
||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
|
||||
|
||||
ollamaAPI.registerAnnotatedTools();
|
||||
|
||||
OllamaChatRequest requestModel = builder
|
||||
.withMessage(OllamaChatMessageRole.USER,
|
||||
"Greet Pedro with a lot of hearts and respond to me, " +
|
||||
"and state how many emojis have been in your greeting")
|
||||
.build();
|
||||
|
||||
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
|
||||
assertNotNull(chatResult);
|
||||
assertNotNull(chatResult.getResponseModel());
|
||||
assertNotNull(chatResult.getResponseModel().getMessage());
|
||||
assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(),chatResult.getResponseModel().getMessage().getRole().getRoleName());
|
||||
List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls();
|
||||
assertEquals(1, toolCalls.size());
|
||||
OllamaToolCallsFunction function = toolCalls.get(0).getFunction();
|
||||
assertEquals("sayHello", function.getName());
|
||||
assertEquals(2, function.getArguments().size());
|
||||
Object name = function.getArguments().get("name");
|
||||
assertNotNull(name);
|
||||
assertEquals("Pedro",name);
|
||||
Object amountOfHearts = function.getArguments().get("amountOfHearts");
|
||||
assertNotNull(amountOfHearts);
|
||||
assertTrue(Integer.parseInt(amountOfHearts.toString()) > 1);
|
||||
assertTrue(chatResult.getChatHistory().size()>2);
|
||||
List<OllamaChatToolCalls> finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls();
|
||||
assertNull(finalToolCalls);
|
||||
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||
fail(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(3)
|
||||
void testChatWithToolsAndStream() {
|
||||
|
||||
21
src/test/java/io/github/ollama4j/samples/AnnotatedTool.java
Normal file
21
src/test/java/io/github/ollama4j/samples/AnnotatedTool.java
Normal file
@@ -0,0 +1,21 @@
|
||||
package io.github.ollama4j.samples;
|
||||
|
||||
import io.github.ollama4j.tools.annotations.ToolProperty;
|
||||
import io.github.ollama4j.tools.annotations.ToolSpec;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
|
||||
public class AnnotatedTool {
|
||||
|
||||
@ToolSpec(desc = "Computes the most important constant all around the globe!")
|
||||
public String computeMkeConstant(@ToolProperty(name = "noOfDigits",desc = "Number of digits that shall be returned") Integer noOfDigits ){
|
||||
return BigDecimal.valueOf((long)(Math.random()*1000000L),noOfDigits).toString();
|
||||
}
|
||||
|
||||
@ToolSpec(desc = "Says hello to a friend!")
|
||||
public String sayHello(@ToolProperty(name = "name",desc = "Name of the friend") String name, Integer someRandomProperty, @ToolProperty(name="amountOfHearts",desc = "amount of heart emojis that should be used", required = false) Integer amountOfHearts) {
|
||||
String hearts = amountOfHearts!=null ? "♡".repeat(amountOfHearts) : "";
|
||||
return "Hello " + name +" ("+someRandomProperty+") " + hearts;
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user