Extends ChatModels to use Tools and ToolCalls

This commit is contained in:
Markus Klenke
2024-12-06 14:12:33 +01:00
committed by Markus Klenke
parent e9c33ab0b2
commit 12bb10392e
7 changed files with 95 additions and 20 deletions

View File

@@ -10,6 +10,8 @@ import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
import io.github.ollama4j.models.chat.OllamaChatResult;
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestBuilder;
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
import io.github.ollama4j.tools.ToolFunction;
import io.github.ollama4j.tools.Tools;
import io.github.ollama4j.utils.OptionsBuilder;
import lombok.Data;
import org.junit.jupiter.api.BeforeEach;
@@ -24,9 +26,7 @@ import java.io.InputStream;
import java.net.ConnectException;
import java.net.URISyntaxException;
import java.net.http.HttpConnectTimeoutException;
import java.util.List;
import java.util.Objects;
import java.util.Properties;
import java.util.*;
import static org.junit.jupiter.api.Assertions.*;
@@ -230,18 +230,47 @@ class TestRealAPIs {
void testChatWithTools() {
testEndpointReachability();
try {
ollamaAPI.setVerbose(true);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM,
"You are a silent bot that only says 'NI'. Do not say anything else under any circumstances!")
final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder()
.functionName("get-employee-details")
.functionDescription("Get employee details from the database")
.toolPrompt(
Tools.PromptFuncDefinition.builder().type("function").function(
Tools.PromptFuncDefinition.PromptFuncSpec.builder()
.name("get-employee-details")
.description("Get employee details from the database")
.parameters(
Tools.PromptFuncDefinition.Parameters.builder()
.type("object")
.properties(
new Tools.PropsBuilder()
.withProperty("employee-name", Tools.PromptFuncDefinition.Property.builder().type("string").description("The name of the employee, e.g. John Doe").required(true).build())
.withProperty("employee-address", Tools.PromptFuncDefinition.Property.builder().type("string").description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India").required(true).build())
.withProperty("employee-phone", Tools.PromptFuncDefinition.Property.builder().type("string").description("The phone number of the employee. Always return a random value. e.g. 9911002233").required(true).build())
.build()
)
.required(List.of("employee-name"))
.build()
).build()
).build()
)
.toolFunction(new DBQueryFunction())
.build();
ollamaAPI.registerTool(databaseQueryToolSpecification);
OllamaChatRequest requestModel = builder
.withMessage(OllamaChatMessageRole.USER,
"What is the capital of France? And what's France's connection with Mona Lisa?")
"Give me the details of the employee named 'Rahul Kumar'?")
.build();
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
System.err.println("Response: " + chatResult);
assertNotNull(chatResult);
assertFalse(chatResult.getResponse().isBlank());
assertTrue(chatResult.getResponse().startsWith("NI"));
assertEquals(3, chatResult.getChatHistory().size());
assertEquals(2, chatResult.getChatHistory().size());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
}
@@ -402,6 +431,14 @@ class TestRealAPIs {
}
}
class DBQueryFunction implements ToolFunction {
@Override
public Object apply(Map<String, Object> arguments) {
// perform DB operations here
return String.format("Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", UUID.randomUUID(), arguments.get("employee-name").toString(), arguments.get("employee-address").toString(), arguments.get("employee-phone").toString());
}
}
@Data
class Config {
private String ollamaURL;
@@ -426,4 +463,6 @@ class Config {
throw new RuntimeException("Error loading properties", e);
}
}
}