mirror of
https://github.com/amithkoujalgi/ollama4j.git
synced 2025-09-16 03:39:05 +02:00
Remove sample prompt utilities and update integration tests
Deleted SamplePrompts.java and sample-db-prompt-template.txt as they are no longer needed. Updated OllamaAPIIntegrationTest to use a new TOOLS_MODEL constant, refactored tool registration and prompt descriptions for employee details, and improved test assertions for tool-based chat interactions.
This commit is contained in:
parent
4df59d8862
commit
97f457575d
@ -1,25 +0,0 @@
|
|||||||
package io.github.ollama4j.utils;
|
|
||||||
|
|
||||||
import io.github.ollama4j.OllamaAPI;
|
|
||||||
|
|
||||||
import java.io.InputStream;
|
|
||||||
import java.util.Scanner;
|
|
||||||
|
|
||||||
public class SamplePrompts {
|
|
||||||
public static String getSampleDatabasePromptWithQuestion(String question) throws Exception {
|
|
||||||
ClassLoader classLoader = OllamaAPI.class.getClassLoader();
|
|
||||||
InputStream inputStream = classLoader.getResourceAsStream("sample-db-prompt-template.txt");
|
|
||||||
if (inputStream != null) {
|
|
||||||
Scanner scanner = new Scanner(inputStream);
|
|
||||||
StringBuilder stringBuffer = new StringBuilder();
|
|
||||||
while (scanner.hasNextLine()) {
|
|
||||||
stringBuffer.append(scanner.nextLine()).append("\n");
|
|
||||||
}
|
|
||||||
scanner.close();
|
|
||||||
return stringBuffer.toString().replace("<question>", question);
|
|
||||||
} else {
|
|
||||||
throw new Exception("Sample database question file not found.");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
@ -1,61 +0,0 @@
|
|||||||
"""
|
|
||||||
Following is the database schema.
|
|
||||||
|
|
||||||
DROP TABLE IF EXISTS product_categories;
|
|
||||||
CREATE TABLE IF NOT EXISTS product_categories
|
|
||||||
(
|
|
||||||
category_id INTEGER PRIMARY KEY, -- Unique ID for each category
|
|
||||||
name VARCHAR(50), -- Name of the category
|
|
||||||
parent INTEGER NULL, -- Parent category - for hierarchical categories
|
|
||||||
FOREIGN KEY (parent) REFERENCES product_categories (category_id)
|
|
||||||
);
|
|
||||||
DROP TABLE IF EXISTS products;
|
|
||||||
CREATE TABLE IF NOT EXISTS products
|
|
||||||
(
|
|
||||||
product_id INTEGER PRIMARY KEY, -- Unique ID for each product
|
|
||||||
name VARCHAR(50), -- Name of the product
|
|
||||||
price DECIMAL(10, 2), -- Price of each unit of the product
|
|
||||||
quantity INTEGER, -- Current quantity in stock
|
|
||||||
category_id INTEGER, -- Unique ID for each product
|
|
||||||
FOREIGN KEY (category_id) REFERENCES product_categories (category_id)
|
|
||||||
);
|
|
||||||
DROP TABLE IF EXISTS customers;
|
|
||||||
CREATE TABLE IF NOT EXISTS customers
|
|
||||||
(
|
|
||||||
customer_id INTEGER PRIMARY KEY, -- Unique ID for each customer
|
|
||||||
name VARCHAR(50), -- Name of the customer
|
|
||||||
address VARCHAR(100) -- Mailing address of the customer
|
|
||||||
);
|
|
||||||
DROP TABLE IF EXISTS salespeople;
|
|
||||||
CREATE TABLE IF NOT EXISTS salespeople
|
|
||||||
(
|
|
||||||
salesperson_id INTEGER PRIMARY KEY, -- Unique ID for each salesperson
|
|
||||||
name VARCHAR(50), -- Name of the salesperson
|
|
||||||
region VARCHAR(50) -- Geographic sales region
|
|
||||||
);
|
|
||||||
DROP TABLE IF EXISTS sales;
|
|
||||||
CREATE TABLE IF NOT EXISTS sales
|
|
||||||
(
|
|
||||||
sale_id INTEGER PRIMARY KEY, -- Unique ID for each sale
|
|
||||||
product_id INTEGER, -- ID of product sold
|
|
||||||
customer_id INTEGER, -- ID of customer who made the purchase
|
|
||||||
salesperson_id INTEGER, -- ID of salesperson who made the sale
|
|
||||||
sale_date DATE, -- Date the sale occurred
|
|
||||||
quantity INTEGER, -- Quantity of product sold
|
|
||||||
FOREIGN KEY (product_id) REFERENCES products (product_id),
|
|
||||||
FOREIGN KEY (customer_id) REFERENCES customers (customer_id),
|
|
||||||
FOREIGN KEY (salesperson_id) REFERENCES salespeople (salesperson_id)
|
|
||||||
);
|
|
||||||
DROP TABLE IF EXISTS product_suppliers;
|
|
||||||
CREATE TABLE IF NOT EXISTS product_suppliers
|
|
||||||
(
|
|
||||||
supplier_id INTEGER PRIMARY KEY, -- Unique ID for each supplier
|
|
||||||
product_id INTEGER, -- Product ID supplied
|
|
||||||
supply_price DECIMAL(10, 2), -- Unit price charged by supplier
|
|
||||||
FOREIGN KEY (product_id) REFERENCES products (product_id)
|
|
||||||
);
|
|
||||||
|
|
||||||
|
|
||||||
Generate only a valid (syntactically correct) executable Postgres SQL query (without any explanation of the query) for the following question:
|
|
||||||
`<question>`:
|
|
||||||
"""
|
|
@ -46,6 +46,7 @@ class OllamaAPIIntegrationTest {
|
|||||||
private static final String VISION_MODEL = "moondream:1.8b";
|
private static final String VISION_MODEL = "moondream:1.8b";
|
||||||
private static final String THINKING_TOOL_MODEL = "gpt-oss:20b";
|
private static final String THINKING_TOOL_MODEL = "gpt-oss:20b";
|
||||||
private static final String GENERAL_PURPOSE_MODEL = "gemma3:270m";
|
private static final String GENERAL_PURPOSE_MODEL = "gemma3:270m";
|
||||||
|
private static final String TOOLS_MODEL = "mistral:7b";
|
||||||
|
|
||||||
@BeforeAll
|
@BeforeAll
|
||||||
static void setUp() {
|
static void setUp() {
|
||||||
@ -309,16 +310,17 @@ class OllamaAPIIntegrationTest {
|
|||||||
@Order(11)
|
@Order(11)
|
||||||
void testChatWithExplicitToolDefinition() throws OllamaBaseException, IOException, URISyntaxException,
|
void testChatWithExplicitToolDefinition() throws OllamaBaseException, IOException, URISyntaxException,
|
||||||
InterruptedException, ToolInvocationException {
|
InterruptedException, ToolInvocationException {
|
||||||
api.pullModel(THINKING_TOOL_MODEL);
|
String theToolModel = TOOLS_MODEL;
|
||||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(THINKING_TOOL_MODEL);
|
api.pullModel(theToolModel);
|
||||||
|
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(theToolModel);
|
||||||
|
|
||||||
final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder()
|
final Tools.ToolSpecification employeeDetailsToolSpecification = Tools.ToolSpecification.builder()
|
||||||
.functionName("get-employee-details")
|
.functionName("get-employee-details")
|
||||||
.functionDescription("Get employee details from the database")
|
.functionDescription("Tool to get details of a person or an employee")
|
||||||
.toolPrompt(Tools.PromptFuncDefinition.builder().type("function")
|
.toolPrompt(Tools.PromptFuncDefinition.builder().type("function")
|
||||||
.function(Tools.PromptFuncDefinition.PromptFuncSpec.builder()
|
.function(Tools.PromptFuncDefinition.PromptFuncSpec.builder()
|
||||||
.name("get-employee-details")
|
.name("get-employee-details")
|
||||||
.description("Get employee details from the database")
|
.description("Tool to get details of a person or an employee")
|
||||||
.parameters(Tools.PromptFuncDefinition.Parameters
|
.parameters(Tools.PromptFuncDefinition.Parameters
|
||||||
.builder().type("object")
|
.builder().type("object")
|
||||||
.properties(new Tools.PropsBuilder()
|
.properties(new Tools.PropsBuilder()
|
||||||
@ -358,10 +360,10 @@ class OllamaAPIIntegrationTest {
|
|||||||
arguments.get("employee-phone"));
|
arguments.get("employee-phone"));
|
||||||
}).build();
|
}).build();
|
||||||
|
|
||||||
api.registerTool(databaseQueryToolSpecification);
|
api.registerTool(employeeDetailsToolSpecification);
|
||||||
|
|
||||||
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER,
|
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER,
|
||||||
"Give me the address of the employee named 'Rahul Kumar'?").build();
|
"Give me the ID of the employee named Rahul Kumar.").build();
|
||||||
requestModel.setOptions(new OptionsBuilder().setTemperature(0.9f).build().getOptionsMap());
|
requestModel.setOptions(new OptionsBuilder().setTemperature(0.9f).build().getOptionsMap());
|
||||||
|
|
||||||
OllamaChatResult chatResult = api.chat(requestModel);
|
OllamaChatResult chatResult = api.chat(requestModel);
|
||||||
@ -387,8 +389,9 @@ class OllamaAPIIntegrationTest {
|
|||||||
@Order(12)
|
@Order(12)
|
||||||
void testChatWithAnnotatedToolsAndSingleParam() throws OllamaBaseException, IOException, InterruptedException,
|
void testChatWithAnnotatedToolsAndSingleParam() throws OllamaBaseException, IOException, InterruptedException,
|
||||||
URISyntaxException, ToolInvocationException {
|
URISyntaxException, ToolInvocationException {
|
||||||
api.pullModel(THINKING_TOOL_MODEL);
|
String theToolModel = TOOLS_MODEL;
|
||||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(THINKING_TOOL_MODEL);
|
api.pullModel(theToolModel);
|
||||||
|
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(theToolModel);
|
||||||
|
|
||||||
api.registerAnnotatedTools();
|
api.registerAnnotatedTools();
|
||||||
|
|
||||||
@ -420,8 +423,9 @@ class OllamaAPIIntegrationTest {
|
|||||||
@Order(13)
|
@Order(13)
|
||||||
void testChatWithAnnotatedToolsAndMultipleParams() throws OllamaBaseException, IOException, URISyntaxException,
|
void testChatWithAnnotatedToolsAndMultipleParams() throws OllamaBaseException, IOException, URISyntaxException,
|
||||||
InterruptedException, ToolInvocationException {
|
InterruptedException, ToolInvocationException {
|
||||||
api.pullModel(THINKING_TOOL_MODEL);
|
String theToolModel = TOOLS_MODEL;
|
||||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(THINKING_TOOL_MODEL);
|
api.pullModel(theToolModel);
|
||||||
|
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(theToolModel);
|
||||||
|
|
||||||
api.registerAnnotatedTools(new AnnotatedTool());
|
api.registerAnnotatedTools(new AnnotatedTool());
|
||||||
|
|
||||||
@ -455,15 +459,18 @@ class OllamaAPIIntegrationTest {
|
|||||||
@Order(14)
|
@Order(14)
|
||||||
void testChatWithToolsAndStream() throws OllamaBaseException, IOException, URISyntaxException,
|
void testChatWithToolsAndStream() throws OllamaBaseException, IOException, URISyntaxException,
|
||||||
InterruptedException, ToolInvocationException {
|
InterruptedException, ToolInvocationException {
|
||||||
api.pullModel(THINKING_TOOL_MODEL);
|
String theToolModel = TOOLS_MODEL;
|
||||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(THINKING_TOOL_MODEL);
|
api.pullModel(theToolModel);
|
||||||
final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder()
|
|
||||||
|
String expectedEmployeeID = UUID.randomUUID().toString();
|
||||||
|
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(theToolModel);
|
||||||
|
final Tools.ToolSpecification employeeDetailsToolSpecification = Tools.ToolSpecification.builder()
|
||||||
.functionName("get-employee-details")
|
.functionName("get-employee-details")
|
||||||
.functionDescription("Get employee details from the database")
|
.functionDescription("Tool to get details for a person or an employee")
|
||||||
.toolPrompt(Tools.PromptFuncDefinition.builder().type("function")
|
.toolPrompt(Tools.PromptFuncDefinition.builder().type("function")
|
||||||
.function(Tools.PromptFuncDefinition.PromptFuncSpec.builder()
|
.function(Tools.PromptFuncDefinition.PromptFuncSpec.builder()
|
||||||
.name("get-employee-details")
|
.name("get-employee-details")
|
||||||
.description("Get employee details from the database")
|
.description("Tool to get details for a person or an employee")
|
||||||
.parameters(Tools.PromptFuncDefinition.Parameters
|
.parameters(Tools.PromptFuncDefinition.Parameters
|
||||||
.builder().type("object")
|
.builder().type("object")
|
||||||
.properties(new Tools.PropsBuilder()
|
.properties(new Tools.PropsBuilder()
|
||||||
@ -478,14 +485,14 @@ class OllamaAPIIntegrationTest {
|
|||||||
Tools.PromptFuncDefinition.Property
|
Tools.PromptFuncDefinition.Property
|
||||||
.builder()
|
.builder()
|
||||||
.type("string")
|
.type("string")
|
||||||
.description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India")
|
.description("The address of the employee, Always gives a random address. For example, Roy St, Bengaluru, India")
|
||||||
.required(true)
|
.required(true)
|
||||||
.build())
|
.build())
|
||||||
.withProperty("employee-phone",
|
.withProperty("employee-phone",
|
||||||
Tools.PromptFuncDefinition.Property
|
Tools.PromptFuncDefinition.Property
|
||||||
.builder()
|
.builder()
|
||||||
.type("string")
|
.type("string")
|
||||||
.description("The phone number of the employee. Always return a random value. e.g. 9911002233")
|
.description("The phone number of the employee. Always gives a random phone number. For example, 9911002233")
|
||||||
.required(true)
|
.required(true)
|
||||||
.build())
|
.build())
|
||||||
.build())
|
.build())
|
||||||
@ -499,30 +506,33 @@ class OllamaAPIIntegrationTest {
|
|||||||
// perform DB operations here
|
// perform DB operations here
|
||||||
return String.format(
|
return String.format(
|
||||||
"Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}",
|
"Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}",
|
||||||
UUID.randomUUID(), arguments.get("employee-name"),
|
expectedEmployeeID, arguments.get("employee-name"),
|
||||||
arguments.get("employee-address"),
|
arguments.get("employee-address"),
|
||||||
arguments.get("employee-phone"));
|
arguments.get("employee-phone"));
|
||||||
}
|
}
|
||||||
}).build();
|
}).build();
|
||||||
|
|
||||||
api.registerTool(databaseQueryToolSpecification);
|
api.registerTool(employeeDetailsToolSpecification);
|
||||||
|
|
||||||
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER,
|
OllamaChatRequest requestModel = builder
|
||||||
"Give me the address of the employee named 'Rahul Kumar'?").build();
|
.withMessage(OllamaChatMessageRole.USER, "Find the ID of employee Rahul Kumar")
|
||||||
|
.withKeepAlive("0m")
|
||||||
|
.withOptions(new OptionsBuilder().setTemperature(0.9f).build())
|
||||||
|
.build();
|
||||||
|
|
||||||
StringBuffer sb = new StringBuffer();
|
StringBuffer sb = new StringBuffer();
|
||||||
|
|
||||||
OllamaChatResult chatResult = api.chat(requestModel, (s) -> {
|
OllamaChatResult chatResult = api.chat(requestModel, (s) -> {
|
||||||
LOG.info(s);
|
|
||||||
String substring = s.substring(sb.toString().length());
|
String substring = s.substring(sb.toString().length());
|
||||||
LOG.info(substring);
|
|
||||||
sb.append(substring);
|
sb.append(substring);
|
||||||
|
LOG.info(substring);
|
||||||
});
|
});
|
||||||
assertNotNull(chatResult);
|
assertNotNull(chatResult);
|
||||||
assertNotNull(chatResult.getResponseModel());
|
assertNotNull(chatResult.getResponseModel());
|
||||||
assertNotNull(chatResult.getResponseModel().getMessage());
|
assertNotNull(chatResult.getResponseModel().getMessage());
|
||||||
assertNotNull(chatResult.getResponseModel().getMessage().getContent());
|
assertNotNull(chatResult.getResponseModel().getMessage().getContent());
|
||||||
assertEquals(sb.toString(), chatResult.getResponseModel().getMessage().getContent());
|
assertTrue(sb.toString().toLowerCase().contains(expectedEmployeeID));
|
||||||
|
assertTrue(chatResult.getResponseModel().getMessage().getContent().toLowerCase().contains(expectedEmployeeID));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
Loading…
x
Reference in New Issue
Block a user