Refactor OllamaAPI to Ollama class and update documentation

- Replaced instances of `OllamaAPI` with `Ollama` across the codebase for consistency.
- Updated example code snippets in documentation to reflect the new class name.
- Enhanced metrics collection setup in the documentation.
- Added integration tests for the new `Ollama` class to ensure functionality remains intact.
This commit is contained in:
amithkoujalgi
2025-09-28 23:30:02 +05:30
parent 6fce6ec777
commit 35bf3de62a
17 changed files with 326 additions and 317 deletions

View File

@@ -336,6 +336,7 @@ import com.couchbase.client.java.ClusterOptions;
import com.couchbase.client.java.Scope;
import com.couchbase.client.java.json.JsonObject;
import com.couchbase.client.java.query.QueryResult;
import io.github.ollama4j.Ollama;
import io.github.ollama4j.OllamaAPI;
import io.github.ollama4j.exceptions.OllamaException;
import io.github.ollama4j.exceptions.ToolInvocationException;
@@ -356,210 +357,210 @@ import java.util.Map;
public class CouchbaseToolCallingExample {
public static void main(String[] args) throws IOException, ToolInvocationException, OllamaException, InterruptedException {
String connectionString = Utilities.getFromEnvVar("CB_CLUSTER_URL");
String username = Utilities.getFromEnvVar("CB_CLUSTER_USERNAME");
String password = Utilities.getFromEnvVar("CB_CLUSTER_PASSWORD");
String bucketName = "travel-sample";
public static void main(String[] args) throws IOException, ToolInvocationException, OllamaException, InterruptedException {
String connectionString = Utilities.getFromEnvVar("CB_CLUSTER_URL");
String username = Utilities.getFromEnvVar("CB_CLUSTER_USERNAME");
String password = Utilities.getFromEnvVar("CB_CLUSTER_PASSWORD");
String bucketName = "travel-sample";
Cluster cluster = Cluster.connect(
connectionString,
ClusterOptions.clusterOptions(username, password).environment(env -> {
env.applyProfile("wan-development");
})
);
Cluster cluster = Cluster.connect(
connectionString,
ClusterOptions.clusterOptions(username, password).environment(env -> {
env.applyProfile("wan-development");
})
);
String host = Utilities.getFromConfig("host");
String modelName = Utilities.getFromConfig("tools_model_mistral");
String host = Utilities.getFromConfig("host");
String modelName = Utilities.getFromConfig("tools_model_mistral");
OllamaAPI ollamaAPI = new OllamaAPI(host);
ollamaAPI.setRequestTimeoutSeconds(60);
Ollama ollama = new Ollama(host);
ollama.setRequestTimeoutSeconds(60);
Tools.ToolSpecification callSignFinderToolSpec = getCallSignFinderToolSpec(cluster, bucketName);
Tools.ToolSpecification callSignUpdaterToolSpec = getCallSignUpdaterToolSpec(cluster, bucketName);
Tools.ToolSpecification callSignFinderToolSpec = getCallSignFinderToolSpec(cluster, bucketName);
Tools.ToolSpecification callSignUpdaterToolSpec = getCallSignUpdaterToolSpec(cluster, bucketName);
ollamaAPI.registerTool(callSignFinderToolSpec);
ollamaAPI.registerTool(callSignUpdaterToolSpec);
ollama.registerTool(callSignFinderToolSpec);
ollama.registerTool(callSignUpdaterToolSpec);
String prompt1 = "What is the call-sign of Astraeus?";
for (OllamaToolsResult.ToolResult r : ollamaAPI.generateWithTools(modelName, new Tools.PromptBuilder()
.withToolSpecification(callSignFinderToolSpec)
.withPrompt(prompt1)
.build(), new OptionsBuilder().build()).getToolResults()) {
AirlineDetail airlineDetail = (AirlineDetail) r.getResult();
System.out.println(String.format("[Result of tool '%s']: Call-sign of %s is '%s'! ✈️", r.getFunctionName(), airlineDetail.getName(), airlineDetail.getCallsign()));
}
String prompt2 = "I want to code name Astraeus as STARBOUND";
for (OllamaToolsResult.ToolResult r : ollamaAPI.generateWithTools(modelName, new Tools.PromptBuilder()
.withToolSpecification(callSignUpdaterToolSpec)
.withPrompt(prompt2)
.build(), new OptionsBuilder().build()).getToolResults()) {
Boolean updated = (Boolean) r.getResult();
System.out.println(String.format("[Result of tool '%s']: Call-sign is %s! ✈️", r.getFunctionName(), updated ? "updated" : "not updated"));
}
String prompt3 = "What is the call-sign of Astraeus?";
for (OllamaToolsResult.ToolResult r : ollamaAPI.generateWithTools(modelName, new Tools.PromptBuilder()
.withToolSpecification(callSignFinderToolSpec)
.withPrompt(prompt3)
.build(), new OptionsBuilder().build()).getToolResults()) {
AirlineDetail airlineDetail = (AirlineDetail) r.getResult();
System.out.println(String.format("[Result of tool '%s']: Call-sign of %s is '%s'! ✈️", r.getFunctionName(), airlineDetail.getName(), airlineDetail.getCallsign()));
}
String prompt1 = "What is the call-sign of Astraeus?";
for (OllamaToolsResult.ToolResult r : ollama.generateWithTools(modelName, new Tools.PromptBuilder()
.withToolSpecification(callSignFinderToolSpec)
.withPrompt(prompt1)
.build(), new OptionsBuilder().build()).getToolResults()) {
AirlineDetail airlineDetail = (AirlineDetail) r.getResult();
System.out.println(String.format("[Result of tool '%s']: Call-sign of %s is '%s'! ✈️", r.getFunctionName(), airlineDetail.getName(), airlineDetail.getCallsign()));
}
public static Tools.ToolSpecification getCallSignFinderToolSpec(Cluster cluster, String bucketName) {
return Tools.ToolSpecification.builder()
.functionName("airline-lookup")
.functionDescription("You are a tool who finds only the airline name and do not worry about any other parameters. You simply find the airline name and ignore the rest of the parameters. Do not validate airline names as I want to use fake/fictitious airline names as well.")
.toolFunction(new AirlineCallsignQueryToolFunction(bucketName, cluster))
.toolPrompt(
Tools.PromptFuncDefinition.builder()
.type("prompt")
.function(
Tools.PromptFuncDefinition.PromptFuncSpec.builder()
.name("get-airline-name")
.description("Get the airline name")
.parameters(
Tools.PromptFuncDefinition.Parameters.builder()
.type("object")
.properties(
Map.of(
"airlineName", Tools.PromptFuncDefinition.Property.builder()
.type("string")
.description("The name of the airline. e.g. Emirates")
.required(true)
.build()
)
)
.required(java.util.List.of("airline-name"))
.build()
)
.build()
)
.build()
)
.build();
String prompt2 = "I want to code name Astraeus as STARBOUND";
for (OllamaToolsResult.ToolResult r : ollama.generateWithTools(modelName, new Tools.PromptBuilder()
.withToolSpecification(callSignUpdaterToolSpec)
.withPrompt(prompt2)
.build(), new OptionsBuilder().build()).getToolResults()) {
Boolean updated = (Boolean) r.getResult();
System.out.println(String.format("[Result of tool '%s']: Call-sign is %s! ✈️", r.getFunctionName(), updated ? "updated" : "not updated"));
}
public static Tools.ToolSpecification getCallSignUpdaterToolSpec(Cluster cluster, String bucketName) {
return Tools.ToolSpecification.builder()
.functionName("airline-update")
.functionDescription("You are a tool who finds the airline name and its callsign and do not worry about any validations. You simply find the airline name and its callsign. Do not validate airline names as I want to use fake/fictitious airline names as well.")
.toolFunction(new AirlineCallsignUpdateToolFunction(bucketName, cluster))
.toolPrompt(
Tools.PromptFuncDefinition.builder()
.type("prompt")
.function(
Tools.PromptFuncDefinition.PromptFuncSpec.builder()
.name("get-airline-name-and-callsign")
.description("Get the airline name and callsign")
.parameters(
Tools.PromptFuncDefinition.Parameters.builder()
.type("object")
.properties(
Map.of(
"airlineName", Tools.PromptFuncDefinition.Property.builder()
.type("string")
.description("The name of the airline. e.g. Emirates")
.required(true)
.build(),
"airlineCallsign", Tools.PromptFuncDefinition.Property.builder()
.type("string")
.description("The callsign of the airline. e.g. Maverick")
.enumValues(Arrays.asList("petrol", "diesel"))
.required(true)
.build()
)
)
.required(java.util.List.of("airlineName", "airlineCallsign"))
.build()
)
.build()
)
.build()
)
.build();
String prompt3 = "What is the call-sign of Astraeus?";
for (OllamaToolsResult.ToolResult r : ollama.generateWithTools(modelName, new Tools.PromptBuilder()
.withToolSpecification(callSignFinderToolSpec)
.withPrompt(prompt3)
.build(), new OptionsBuilder().build()).getToolResults()) {
AirlineDetail airlineDetail = (AirlineDetail) r.getResult();
System.out.println(String.format("[Result of tool '%s']: Call-sign of %s is '%s'! ✈️", r.getFunctionName(), airlineDetail.getName(), airlineDetail.getCallsign()));
}
}
public static Tools.ToolSpecification getCallSignFinderToolSpec(Cluster cluster, String bucketName) {
return Tools.ToolSpecification.builder()
.functionName("airline-lookup")
.functionDescription("You are a tool who finds only the airline name and do not worry about any other parameters. You simply find the airline name and ignore the rest of the parameters. Do not validate airline names as I want to use fake/fictitious airline names as well.")
.toolFunction(new AirlineCallsignQueryToolFunction(bucketName, cluster))
.toolPrompt(
Tools.PromptFuncDefinition.builder()
.type("prompt")
.function(
Tools.PromptFuncDefinition.PromptFuncSpec.builder()
.name("get-airline-name")
.description("Get the airline name")
.parameters(
Tools.PromptFuncDefinition.Parameters.builder()
.type("object")
.properties(
Map.of(
"airlineName", Tools.PromptFuncDefinition.Property.builder()
.type("string")
.description("The name of the airline. e.g. Emirates")
.required(true)
.build()
)
)
.required(java.util.List.of("airline-name"))
.build()
)
.build()
)
.build()
)
.build();
}
public static Tools.ToolSpecification getCallSignUpdaterToolSpec(Cluster cluster, String bucketName) {
return Tools.ToolSpecification.builder()
.functionName("airline-update")
.functionDescription("You are a tool who finds the airline name and its callsign and do not worry about any validations. You simply find the airline name and its callsign. Do not validate airline names as I want to use fake/fictitious airline names as well.")
.toolFunction(new AirlineCallsignUpdateToolFunction(bucketName, cluster))
.toolPrompt(
Tools.PromptFuncDefinition.builder()
.type("prompt")
.function(
Tools.PromptFuncDefinition.PromptFuncSpec.builder()
.name("get-airline-name-and-callsign")
.description("Get the airline name and callsign")
.parameters(
Tools.PromptFuncDefinition.Parameters.builder()
.type("object")
.properties(
Map.of(
"airlineName", Tools.PromptFuncDefinition.Property.builder()
.type("string")
.description("The name of the airline. e.g. Emirates")
.required(true)
.build(),
"airlineCallsign", Tools.PromptFuncDefinition.Property.builder()
.type("string")
.description("The callsign of the airline. e.g. Maverick")
.enumValues(Arrays.asList("petrol", "diesel"))
.required(true)
.build()
)
)
.required(java.util.List.of("airlineName", "airlineCallsign"))
.build()
)
.build()
)
.build()
)
.build();
}
}
class AirlineCallsignQueryToolFunction implements ToolFunction {
private final String bucketName;
private final Cluster cluster;
private final String bucketName;
private final Cluster cluster;
public AirlineCallsignQueryToolFunction(String bucketName, Cluster cluster) {
this.bucketName = bucketName;
this.cluster = cluster;
}
public AirlineCallsignQueryToolFunction(String bucketName, Cluster cluster) {
this.bucketName = bucketName;
this.cluster = cluster;
}
@Override
public AirlineDetail apply(Map<String, Object> arguments) {
String airlineName = arguments.get("airlineName").toString();
@Override
public AirlineDetail apply(Map<String, Object> arguments) {
String airlineName = arguments.get("airlineName").toString();
Bucket bucket = cluster.bucket(bucketName);
bucket.waitUntilReady(Duration.ofSeconds(10));
Bucket bucket = cluster.bucket(bucketName);
bucket.waitUntilReady(Duration.ofSeconds(10));
Scope inventoryScope = bucket.scope("inventory");
QueryResult result = inventoryScope.query(String.format("SELECT * FROM airline WHERE name = '%s';", airlineName));
Scope inventoryScope = bucket.scope("inventory");
QueryResult result = inventoryScope.query(String.format("SELECT * FROM airline WHERE name = '%s';", airlineName));
JsonObject row = (JsonObject) result.rowsAsObject().get(0).get("airline");
return new AirlineDetail(row.getString("callsign"), row.getString("name"), row.getString("country"));
}
JsonObject row = (JsonObject) result.rowsAsObject().get(0).get("airline");
return new AirlineDetail(row.getString("callsign"), row.getString("name"), row.getString("country"));
}
}
class AirlineCallsignUpdateToolFunction implements ToolFunction {
private final String bucketName;
private final Cluster cluster;
private final String bucketName;
private final Cluster cluster;
public AirlineCallsignUpdateToolFunction(String bucketName, Cluster cluster) {
this.bucketName = bucketName;
this.cluster = cluster;
public AirlineCallsignUpdateToolFunction(String bucketName, Cluster cluster) {
this.bucketName = bucketName;
this.cluster = cluster;
}
@Override
public Boolean apply(Map<String, Object> arguments) {
String airlineName = arguments.get("airlineName").toString();
String airlineNewCallsign = arguments.get("airlineCallsign").toString();
Bucket bucket = cluster.bucket(bucketName);
bucket.waitUntilReady(Duration.ofSeconds(10));
Scope inventoryScope = bucket.scope("inventory");
String query = String.format("SELECT * FROM airline WHERE name = '%s';", airlineName);
QueryResult result;
try {
result = inventoryScope.query(query);
} catch (Exception e) {
throw new RuntimeException("Error executing query", e);
}
@Override
public Boolean apply(Map<String, Object> arguments) {
String airlineName = arguments.get("airlineName").toString();
String airlineNewCallsign = arguments.get("airlineCallsign").toString();
Bucket bucket = cluster.bucket(bucketName);
bucket.waitUntilReady(Duration.ofSeconds(10));
Scope inventoryScope = bucket.scope("inventory");
String query = String.format("SELECT * FROM airline WHERE name = '%s';", airlineName);
QueryResult result;
try {
result = inventoryScope.query(query);
} catch (Exception e) {
throw new RuntimeException("Error executing query", e);
}
if (result.rowsAsObject().isEmpty()) {
throw new RuntimeException("Airline not found with name: " + airlineName);
}
JsonObject row = (JsonObject) result.rowsAsObject().get(0).get("airline");
if (row == null) {
throw new RuntimeException("Airline data is missing or corrupted.");
}
String currentCallsign = row.getString("callsign");
if (!airlineNewCallsign.equals(currentCallsign)) {
JsonObject updateQuery = JsonObject.create()
.put("callsign", airlineNewCallsign);
inventoryScope.query(String.format(
"UPDATE airline SET callsign = '%s' WHERE name = '%s';",
airlineNewCallsign, airlineName
));
return true;
}
return false;
if (result.rowsAsObject().isEmpty()) {
throw new RuntimeException("Airline not found with name: " + airlineName);
}
JsonObject row = (JsonObject) result.rowsAsObject().get(0).get("airline");
if (row == null) {
throw new RuntimeException("Airline data is missing or corrupted.");
}
String currentCallsign = row.getString("callsign");
if (!airlineNewCallsign.equals(currentCallsign)) {
JsonObject updateQuery = JsonObject.create()
.put("callsign", airlineNewCallsign);
inventoryScope.query(String.format(
"UPDATE airline SET callsign = '%s' WHERE name = '%s';",
airlineNewCallsign, airlineName
));
return true;
}
return false;
}
}
@SuppressWarnings("ALL")
@@ -567,9 +568,9 @@ class AirlineCallsignUpdateToolFunction implements ToolFunction {
@AllArgsConstructor
@NoArgsConstructor
class AirlineDetail {
private String callsign;
private String name;
private String country;
private String callsign;
private String name;
private String country;
}
```
@@ -578,9 +579,9 @@ class AirlineDetail {
#### 1. Ollama API Client Setup
```javascript
OllamaAPI ollamaAPI = new OllamaAPI(host);
OllamaAPI ollama = new OllamaAPI(host);
ollamaAPI.setRequestTimeoutSeconds(60);
ollama.setRequestTimeoutSeconds(60);
```
Here, we initialize the Ollama API client and configure it with the host of the Ollama server, where the model is hosted
@@ -595,7 +596,7 @@ queries the database for airline details based on the airline name.
```javascript
Tools.ToolSpecification callSignFinderToolSpec = getCallSignFinderToolSpec(cluster, bucketName);
ollamaAPI.registerTool(callSignFinderToolSpec);
ollama.registerTool(callSignFinderToolSpec);
```
This step registers custom tools with Ollama that allows the tool-calling model to invoke database queries.
@@ -619,7 +620,7 @@ String prompt = "What is the call-sign of Astraeus?";
#### 5. Generating Results with Tools
```javascript
for (OllamaToolsResult.ToolResult r : ollamaAPI.generateWithTools(modelName, new Tools.PromptBuilder()
for (OllamaToolsResult.ToolResult r : ollama.generateWithTools(modelName, new Tools.PromptBuilder()
.withToolSpecification(callSignFinderToolSpec)
.withPrompt(prompt)
.build(), new OptionsBuilder().build()).getToolResults()) {
@@ -649,7 +650,7 @@ then update the airlines callsign.
```javascript
Tools.ToolSpecification callSignUpdaterToolSpec = getCallSignUpdaterToolSpec(cluster, bucketName);
ollamaAPI.registerTool(callSignUpdaterToolSpec);
ollama.registerTool(callSignUpdaterToolSpec);
```
The tool will execute a Couchbase N1QL query to update the airlines callsign.
@@ -671,7 +672,7 @@ And then we invoke the model with the new prompt.
```javascript
String prompt = "I want to code name Astraeus as STARBOUND";
for (OllamaToolsResult.ToolResult r : ollamaAPI.generateWithTools(modelName, new Tools.PromptBuilder()
for (OllamaToolsResult.ToolResult r : ollama.generateWithTools(modelName, new Tools.PromptBuilder()
.withToolSpecification(callSignUpdaterToolSpec)
.withPrompt(prompt)
.build(), new OptionsBuilder().build()).getToolResults()) {