mirror of
https://github.com/amithkoujalgi/ollama4j.git
synced 2025-10-14 09:28:58 +02:00
Refactor OllamaAPI to Ollama class and update documentation
- Replaced instances of `OllamaAPI` with `Ollama` across the codebase for consistency. - Updated example code snippets in documentation to reflect the new class name. - Enhanced metrics collection setup in the documentation. - Added integration tests for the new `Ollama` class to ensure functionality remains intact.
This commit is contained in:
parent
6fce6ec777
commit
35bf3de62a
@ -17,11 +17,14 @@ The metrics integration provides the following metrics:
|
|||||||
### 1. Enable Metrics Collection
|
### 1. Enable Metrics Collection
|
||||||
|
|
||||||
```java
|
```java
|
||||||
|
import io.github.ollama4j.Ollama;
|
||||||
import io.github.ollama4j.OllamaAPI;
|
import io.github.ollama4j.OllamaAPI;
|
||||||
|
|
||||||
// Create API instance with metrics enabled
|
// Create API instance with metrics enabled
|
||||||
OllamaAPI ollamaAPI = new OllamaAPI();
|
Ollama ollama = new Ollama();
|
||||||
ollamaAPI.setMetricsEnabled(true);
|
ollamaAPI.
|
||||||
|
|
||||||
|
setMetricsEnabled(true);
|
||||||
```
|
```
|
||||||
|
|
||||||
### 2. Start Metrics Server
|
### 2. Start Metrics Server
|
||||||
@ -38,11 +41,11 @@ System.out.println("Metrics available at: http://localhost:8080/metrics");
|
|||||||
|
|
||||||
```java
|
```java
|
||||||
// All API calls are automatically instrumented
|
// All API calls are automatically instrumented
|
||||||
boolean isReachable = ollamaAPI.ping();
|
boolean isReachable = ollama.ping();
|
||||||
|
|
||||||
Map<String, Object> format = new HashMap<>();
|
Map<String, Object> format = new HashMap<>();
|
||||||
format.put("type", "json");
|
format.put("type", "json");
|
||||||
OllamaResult result = ollamaAPI.generateWithFormat(
|
OllamaResult result = ollama.generateWithFormat(
|
||||||
"llama2",
|
"llama2",
|
||||||
"Generate a JSON object",
|
"Generate a JSON object",
|
||||||
format
|
format
|
||||||
@ -100,13 +103,13 @@ ollama_tokens_generated_total{model_name="llama2"} 150.0
|
|||||||
### Enable/Disable Metrics
|
### Enable/Disable Metrics
|
||||||
|
|
||||||
```java
|
```java
|
||||||
OllamaAPI ollamaAPI = new OllamaAPI();
|
OllamaAPI ollama = new OllamaAPI();
|
||||||
|
|
||||||
// Enable metrics collection
|
// Enable metrics collection
|
||||||
ollamaAPI.setMetricsEnabled(true);
|
ollama.setMetricsEnabled(true);
|
||||||
|
|
||||||
// Disable metrics collection (default)
|
// Disable metrics collection (default)
|
||||||
ollamaAPI.setMetricsEnabled(false);
|
ollama.setMetricsEnabled(false);
|
||||||
```
|
```
|
||||||
|
|
||||||
### Custom Metrics Server
|
### Custom Metrics Server
|
||||||
@ -149,14 +152,14 @@ You can create Grafana dashboards using the metrics. Some useful queries:
|
|||||||
|
|
||||||
- Metrics collection adds minimal overhead (~1-2% in most cases)
|
- Metrics collection adds minimal overhead (~1-2% in most cases)
|
||||||
- Metrics are collected asynchronously and don't block API calls
|
- Metrics are collected asynchronously and don't block API calls
|
||||||
- You can disable metrics in production if needed: `ollamaAPI.setMetricsEnabled(false)`
|
- You can disable metrics in production if needed: `ollama.setMetricsEnabled(false)`
|
||||||
- The metrics server uses minimal resources
|
- The metrics server uses minimal resources
|
||||||
|
|
||||||
## Troubleshooting
|
## Troubleshooting
|
||||||
|
|
||||||
### Metrics Not Appearing
|
### Metrics Not Appearing
|
||||||
|
|
||||||
1. Ensure metrics are enabled: `ollamaAPI.setMetricsEnabled(true)`
|
1. Ensure metrics are enabled: `ollama.setMetricsEnabled(true)`
|
||||||
2. Check that the metrics server is running: `http://localhost:8080/metrics`
|
2. Check that the metrics server is running: `http://localhost:8080/metrics`
|
||||||
3. Verify API calls are being made (metrics only appear after API usage)
|
3. Verify API calls are being made (metrics only appear after API usage)
|
||||||
|
|
||||||
|
@ -336,6 +336,7 @@ import com.couchbase.client.java.ClusterOptions;
|
|||||||
import com.couchbase.client.java.Scope;
|
import com.couchbase.client.java.Scope;
|
||||||
import com.couchbase.client.java.json.JsonObject;
|
import com.couchbase.client.java.json.JsonObject;
|
||||||
import com.couchbase.client.java.query.QueryResult;
|
import com.couchbase.client.java.query.QueryResult;
|
||||||
|
import io.github.ollama4j.Ollama;
|
||||||
import io.github.ollama4j.OllamaAPI;
|
import io.github.ollama4j.OllamaAPI;
|
||||||
import io.github.ollama4j.exceptions.OllamaException;
|
import io.github.ollama4j.exceptions.OllamaException;
|
||||||
import io.github.ollama4j.exceptions.ToolInvocationException;
|
import io.github.ollama4j.exceptions.ToolInvocationException;
|
||||||
@ -356,210 +357,210 @@ import java.util.Map;
|
|||||||
|
|
||||||
public class CouchbaseToolCallingExample {
|
public class CouchbaseToolCallingExample {
|
||||||
|
|
||||||
public static void main(String[] args) throws IOException, ToolInvocationException, OllamaException, InterruptedException {
|
public static void main(String[] args) throws IOException, ToolInvocationException, OllamaException, InterruptedException {
|
||||||
String connectionString = Utilities.getFromEnvVar("CB_CLUSTER_URL");
|
String connectionString = Utilities.getFromEnvVar("CB_CLUSTER_URL");
|
||||||
String username = Utilities.getFromEnvVar("CB_CLUSTER_USERNAME");
|
String username = Utilities.getFromEnvVar("CB_CLUSTER_USERNAME");
|
||||||
String password = Utilities.getFromEnvVar("CB_CLUSTER_PASSWORD");
|
String password = Utilities.getFromEnvVar("CB_CLUSTER_PASSWORD");
|
||||||
String bucketName = "travel-sample";
|
String bucketName = "travel-sample";
|
||||||
|
|
||||||
Cluster cluster = Cluster.connect(
|
Cluster cluster = Cluster.connect(
|
||||||
connectionString,
|
connectionString,
|
||||||
ClusterOptions.clusterOptions(username, password).environment(env -> {
|
ClusterOptions.clusterOptions(username, password).environment(env -> {
|
||||||
env.applyProfile("wan-development");
|
env.applyProfile("wan-development");
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
|
|
||||||
String host = Utilities.getFromConfig("host");
|
String host = Utilities.getFromConfig("host");
|
||||||
String modelName = Utilities.getFromConfig("tools_model_mistral");
|
String modelName = Utilities.getFromConfig("tools_model_mistral");
|
||||||
|
|
||||||
OllamaAPI ollamaAPI = new OllamaAPI(host);
|
Ollama ollama = new Ollama(host);
|
||||||
ollamaAPI.setRequestTimeoutSeconds(60);
|
ollama.setRequestTimeoutSeconds(60);
|
||||||
|
|
||||||
Tools.ToolSpecification callSignFinderToolSpec = getCallSignFinderToolSpec(cluster, bucketName);
|
Tools.ToolSpecification callSignFinderToolSpec = getCallSignFinderToolSpec(cluster, bucketName);
|
||||||
Tools.ToolSpecification callSignUpdaterToolSpec = getCallSignUpdaterToolSpec(cluster, bucketName);
|
Tools.ToolSpecification callSignUpdaterToolSpec = getCallSignUpdaterToolSpec(cluster, bucketName);
|
||||||
|
|
||||||
ollamaAPI.registerTool(callSignFinderToolSpec);
|
ollama.registerTool(callSignFinderToolSpec);
|
||||||
ollamaAPI.registerTool(callSignUpdaterToolSpec);
|
ollama.registerTool(callSignUpdaterToolSpec);
|
||||||
|
|
||||||
String prompt1 = "What is the call-sign of Astraeus?";
|
String prompt1 = "What is the call-sign of Astraeus?";
|
||||||
for (OllamaToolsResult.ToolResult r : ollamaAPI.generateWithTools(modelName, new Tools.PromptBuilder()
|
for (OllamaToolsResult.ToolResult r : ollama.generateWithTools(modelName, new Tools.PromptBuilder()
|
||||||
.withToolSpecification(callSignFinderToolSpec)
|
.withToolSpecification(callSignFinderToolSpec)
|
||||||
.withPrompt(prompt1)
|
.withPrompt(prompt1)
|
||||||
.build(), new OptionsBuilder().build()).getToolResults()) {
|
.build(), new OptionsBuilder().build()).getToolResults()) {
|
||||||
AirlineDetail airlineDetail = (AirlineDetail) r.getResult();
|
AirlineDetail airlineDetail = (AirlineDetail) r.getResult();
|
||||||
System.out.println(String.format("[Result of tool '%s']: Call-sign of %s is '%s'! ✈️", r.getFunctionName(), airlineDetail.getName(), airlineDetail.getCallsign()));
|
System.out.println(String.format("[Result of tool '%s']: Call-sign of %s is '%s'! ✈️", r.getFunctionName(), airlineDetail.getName(), airlineDetail.getCallsign()));
|
||||||
}
|
|
||||||
|
|
||||||
String prompt2 = "I want to code name Astraeus as STARBOUND";
|
|
||||||
for (OllamaToolsResult.ToolResult r : ollamaAPI.generateWithTools(modelName, new Tools.PromptBuilder()
|
|
||||||
.withToolSpecification(callSignUpdaterToolSpec)
|
|
||||||
.withPrompt(prompt2)
|
|
||||||
.build(), new OptionsBuilder().build()).getToolResults()) {
|
|
||||||
Boolean updated = (Boolean) r.getResult();
|
|
||||||
System.out.println(String.format("[Result of tool '%s']: Call-sign is %s! ✈️", r.getFunctionName(), updated ? "updated" : "not updated"));
|
|
||||||
}
|
|
||||||
|
|
||||||
String prompt3 = "What is the call-sign of Astraeus?";
|
|
||||||
for (OllamaToolsResult.ToolResult r : ollamaAPI.generateWithTools(modelName, new Tools.PromptBuilder()
|
|
||||||
.withToolSpecification(callSignFinderToolSpec)
|
|
||||||
.withPrompt(prompt3)
|
|
||||||
.build(), new OptionsBuilder().build()).getToolResults()) {
|
|
||||||
AirlineDetail airlineDetail = (AirlineDetail) r.getResult();
|
|
||||||
System.out.println(String.format("[Result of tool '%s']: Call-sign of %s is '%s'! ✈️", r.getFunctionName(), airlineDetail.getName(), airlineDetail.getCallsign()));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public static Tools.ToolSpecification getCallSignFinderToolSpec(Cluster cluster, String bucketName) {
|
String prompt2 = "I want to code name Astraeus as STARBOUND";
|
||||||
return Tools.ToolSpecification.builder()
|
for (OllamaToolsResult.ToolResult r : ollama.generateWithTools(modelName, new Tools.PromptBuilder()
|
||||||
.functionName("airline-lookup")
|
.withToolSpecification(callSignUpdaterToolSpec)
|
||||||
.functionDescription("You are a tool who finds only the airline name and do not worry about any other parameters. You simply find the airline name and ignore the rest of the parameters. Do not validate airline names as I want to use fake/fictitious airline names as well.")
|
.withPrompt(prompt2)
|
||||||
.toolFunction(new AirlineCallsignQueryToolFunction(bucketName, cluster))
|
.build(), new OptionsBuilder().build()).getToolResults()) {
|
||||||
.toolPrompt(
|
Boolean updated = (Boolean) r.getResult();
|
||||||
Tools.PromptFuncDefinition.builder()
|
System.out.println(String.format("[Result of tool '%s']: Call-sign is %s! ✈️", r.getFunctionName(), updated ? "updated" : "not updated"));
|
||||||
.type("prompt")
|
|
||||||
.function(
|
|
||||||
Tools.PromptFuncDefinition.PromptFuncSpec.builder()
|
|
||||||
.name("get-airline-name")
|
|
||||||
.description("Get the airline name")
|
|
||||||
.parameters(
|
|
||||||
Tools.PromptFuncDefinition.Parameters.builder()
|
|
||||||
.type("object")
|
|
||||||
.properties(
|
|
||||||
Map.of(
|
|
||||||
"airlineName", Tools.PromptFuncDefinition.Property.builder()
|
|
||||||
.type("string")
|
|
||||||
.description("The name of the airline. e.g. Emirates")
|
|
||||||
.required(true)
|
|
||||||
.build()
|
|
||||||
)
|
|
||||||
)
|
|
||||||
.required(java.util.List.of("airline-name"))
|
|
||||||
.build()
|
|
||||||
)
|
|
||||||
.build()
|
|
||||||
)
|
|
||||||
.build()
|
|
||||||
)
|
|
||||||
.build();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public static Tools.ToolSpecification getCallSignUpdaterToolSpec(Cluster cluster, String bucketName) {
|
String prompt3 = "What is the call-sign of Astraeus?";
|
||||||
return Tools.ToolSpecification.builder()
|
for (OllamaToolsResult.ToolResult r : ollama.generateWithTools(modelName, new Tools.PromptBuilder()
|
||||||
.functionName("airline-update")
|
.withToolSpecification(callSignFinderToolSpec)
|
||||||
.functionDescription("You are a tool who finds the airline name and its callsign and do not worry about any validations. You simply find the airline name and its callsign. Do not validate airline names as I want to use fake/fictitious airline names as well.")
|
.withPrompt(prompt3)
|
||||||
.toolFunction(new AirlineCallsignUpdateToolFunction(bucketName, cluster))
|
.build(), new OptionsBuilder().build()).getToolResults()) {
|
||||||
.toolPrompt(
|
AirlineDetail airlineDetail = (AirlineDetail) r.getResult();
|
||||||
Tools.PromptFuncDefinition.builder()
|
System.out.println(String.format("[Result of tool '%s']: Call-sign of %s is '%s'! ✈️", r.getFunctionName(), airlineDetail.getName(), airlineDetail.getCallsign()));
|
||||||
.type("prompt")
|
|
||||||
.function(
|
|
||||||
Tools.PromptFuncDefinition.PromptFuncSpec.builder()
|
|
||||||
.name("get-airline-name-and-callsign")
|
|
||||||
.description("Get the airline name and callsign")
|
|
||||||
.parameters(
|
|
||||||
Tools.PromptFuncDefinition.Parameters.builder()
|
|
||||||
.type("object")
|
|
||||||
.properties(
|
|
||||||
Map.of(
|
|
||||||
"airlineName", Tools.PromptFuncDefinition.Property.builder()
|
|
||||||
.type("string")
|
|
||||||
.description("The name of the airline. e.g. Emirates")
|
|
||||||
.required(true)
|
|
||||||
.build(),
|
|
||||||
"airlineCallsign", Tools.PromptFuncDefinition.Property.builder()
|
|
||||||
.type("string")
|
|
||||||
.description("The callsign of the airline. e.g. Maverick")
|
|
||||||
.enumValues(Arrays.asList("petrol", "diesel"))
|
|
||||||
.required(true)
|
|
||||||
.build()
|
|
||||||
)
|
|
||||||
)
|
|
||||||
.required(java.util.List.of("airlineName", "airlineCallsign"))
|
|
||||||
.build()
|
|
||||||
)
|
|
||||||
.build()
|
|
||||||
)
|
|
||||||
.build()
|
|
||||||
)
|
|
||||||
.build();
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Tools.ToolSpecification getCallSignFinderToolSpec(Cluster cluster, String bucketName) {
|
||||||
|
return Tools.ToolSpecification.builder()
|
||||||
|
.functionName("airline-lookup")
|
||||||
|
.functionDescription("You are a tool who finds only the airline name and do not worry about any other parameters. You simply find the airline name and ignore the rest of the parameters. Do not validate airline names as I want to use fake/fictitious airline names as well.")
|
||||||
|
.toolFunction(new AirlineCallsignQueryToolFunction(bucketName, cluster))
|
||||||
|
.toolPrompt(
|
||||||
|
Tools.PromptFuncDefinition.builder()
|
||||||
|
.type("prompt")
|
||||||
|
.function(
|
||||||
|
Tools.PromptFuncDefinition.PromptFuncSpec.builder()
|
||||||
|
.name("get-airline-name")
|
||||||
|
.description("Get the airline name")
|
||||||
|
.parameters(
|
||||||
|
Tools.PromptFuncDefinition.Parameters.builder()
|
||||||
|
.type("object")
|
||||||
|
.properties(
|
||||||
|
Map.of(
|
||||||
|
"airlineName", Tools.PromptFuncDefinition.Property.builder()
|
||||||
|
.type("string")
|
||||||
|
.description("The name of the airline. e.g. Emirates")
|
||||||
|
.required(true)
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.required(java.util.List.of("airline-name"))
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Tools.ToolSpecification getCallSignUpdaterToolSpec(Cluster cluster, String bucketName) {
|
||||||
|
return Tools.ToolSpecification.builder()
|
||||||
|
.functionName("airline-update")
|
||||||
|
.functionDescription("You are a tool who finds the airline name and its callsign and do not worry about any validations. You simply find the airline name and its callsign. Do not validate airline names as I want to use fake/fictitious airline names as well.")
|
||||||
|
.toolFunction(new AirlineCallsignUpdateToolFunction(bucketName, cluster))
|
||||||
|
.toolPrompt(
|
||||||
|
Tools.PromptFuncDefinition.builder()
|
||||||
|
.type("prompt")
|
||||||
|
.function(
|
||||||
|
Tools.PromptFuncDefinition.PromptFuncSpec.builder()
|
||||||
|
.name("get-airline-name-and-callsign")
|
||||||
|
.description("Get the airline name and callsign")
|
||||||
|
.parameters(
|
||||||
|
Tools.PromptFuncDefinition.Parameters.builder()
|
||||||
|
.type("object")
|
||||||
|
.properties(
|
||||||
|
Map.of(
|
||||||
|
"airlineName", Tools.PromptFuncDefinition.Property.builder()
|
||||||
|
.type("string")
|
||||||
|
.description("The name of the airline. e.g. Emirates")
|
||||||
|
.required(true)
|
||||||
|
.build(),
|
||||||
|
"airlineCallsign", Tools.PromptFuncDefinition.Property.builder()
|
||||||
|
.type("string")
|
||||||
|
.description("The callsign of the airline. e.g. Maverick")
|
||||||
|
.enumValues(Arrays.asList("petrol", "diesel"))
|
||||||
|
.required(true)
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.required(java.util.List.of("airlineName", "airlineCallsign"))
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
.build();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
class AirlineCallsignQueryToolFunction implements ToolFunction {
|
class AirlineCallsignQueryToolFunction implements ToolFunction {
|
||||||
private final String bucketName;
|
private final String bucketName;
|
||||||
private final Cluster cluster;
|
private final Cluster cluster;
|
||||||
|
|
||||||
public AirlineCallsignQueryToolFunction(String bucketName, Cluster cluster) {
|
public AirlineCallsignQueryToolFunction(String bucketName, Cluster cluster) {
|
||||||
this.bucketName = bucketName;
|
this.bucketName = bucketName;
|
||||||
this.cluster = cluster;
|
this.cluster = cluster;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public AirlineDetail apply(Map<String, Object> arguments) {
|
public AirlineDetail apply(Map<String, Object> arguments) {
|
||||||
String airlineName = arguments.get("airlineName").toString();
|
String airlineName = arguments.get("airlineName").toString();
|
||||||
|
|
||||||
Bucket bucket = cluster.bucket(bucketName);
|
Bucket bucket = cluster.bucket(bucketName);
|
||||||
bucket.waitUntilReady(Duration.ofSeconds(10));
|
bucket.waitUntilReady(Duration.ofSeconds(10));
|
||||||
|
|
||||||
Scope inventoryScope = bucket.scope("inventory");
|
Scope inventoryScope = bucket.scope("inventory");
|
||||||
QueryResult result = inventoryScope.query(String.format("SELECT * FROM airline WHERE name = '%s';", airlineName));
|
QueryResult result = inventoryScope.query(String.format("SELECT * FROM airline WHERE name = '%s';", airlineName));
|
||||||
|
|
||||||
JsonObject row = (JsonObject) result.rowsAsObject().get(0).get("airline");
|
JsonObject row = (JsonObject) result.rowsAsObject().get(0).get("airline");
|
||||||
return new AirlineDetail(row.getString("callsign"), row.getString("name"), row.getString("country"));
|
return new AirlineDetail(row.getString("callsign"), row.getString("name"), row.getString("country"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
class AirlineCallsignUpdateToolFunction implements ToolFunction {
|
class AirlineCallsignUpdateToolFunction implements ToolFunction {
|
||||||
private final String bucketName;
|
private final String bucketName;
|
||||||
private final Cluster cluster;
|
private final Cluster cluster;
|
||||||
|
|
||||||
public AirlineCallsignUpdateToolFunction(String bucketName, Cluster cluster) {
|
public AirlineCallsignUpdateToolFunction(String bucketName, Cluster cluster) {
|
||||||
this.bucketName = bucketName;
|
this.bucketName = bucketName;
|
||||||
this.cluster = cluster;
|
this.cluster = cluster;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Boolean apply(Map<String, Object> arguments) {
|
||||||
|
String airlineName = arguments.get("airlineName").toString();
|
||||||
|
String airlineNewCallsign = arguments.get("airlineCallsign").toString();
|
||||||
|
|
||||||
|
Bucket bucket = cluster.bucket(bucketName);
|
||||||
|
bucket.waitUntilReady(Duration.ofSeconds(10));
|
||||||
|
|
||||||
|
Scope inventoryScope = bucket.scope("inventory");
|
||||||
|
String query = String.format("SELECT * FROM airline WHERE name = '%s';", airlineName);
|
||||||
|
|
||||||
|
QueryResult result;
|
||||||
|
try {
|
||||||
|
result = inventoryScope.query(query);
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new RuntimeException("Error executing query", e);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (result.rowsAsObject().isEmpty()) {
|
||||||
@Override
|
throw new RuntimeException("Airline not found with name: " + airlineName);
|
||||||
public Boolean apply(Map<String, Object> arguments) {
|
|
||||||
String airlineName = arguments.get("airlineName").toString();
|
|
||||||
String airlineNewCallsign = arguments.get("airlineCallsign").toString();
|
|
||||||
|
|
||||||
Bucket bucket = cluster.bucket(bucketName);
|
|
||||||
bucket.waitUntilReady(Duration.ofSeconds(10));
|
|
||||||
|
|
||||||
Scope inventoryScope = bucket.scope("inventory");
|
|
||||||
String query = String.format("SELECT * FROM airline WHERE name = '%s';", airlineName);
|
|
||||||
|
|
||||||
QueryResult result;
|
|
||||||
try {
|
|
||||||
result = inventoryScope.query(query);
|
|
||||||
} catch (Exception e) {
|
|
||||||
throw new RuntimeException("Error executing query", e);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (result.rowsAsObject().isEmpty()) {
|
|
||||||
throw new RuntimeException("Airline not found with name: " + airlineName);
|
|
||||||
}
|
|
||||||
|
|
||||||
JsonObject row = (JsonObject) result.rowsAsObject().get(0).get("airline");
|
|
||||||
|
|
||||||
if (row == null) {
|
|
||||||
throw new RuntimeException("Airline data is missing or corrupted.");
|
|
||||||
}
|
|
||||||
|
|
||||||
String currentCallsign = row.getString("callsign");
|
|
||||||
|
|
||||||
if (!airlineNewCallsign.equals(currentCallsign)) {
|
|
||||||
JsonObject updateQuery = JsonObject.create()
|
|
||||||
.put("callsign", airlineNewCallsign);
|
|
||||||
|
|
||||||
inventoryScope.query(String.format(
|
|
||||||
"UPDATE airline SET callsign = '%s' WHERE name = '%s';",
|
|
||||||
airlineNewCallsign, airlineName
|
|
||||||
));
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
JsonObject row = (JsonObject) result.rowsAsObject().get(0).get("airline");
|
||||||
|
|
||||||
|
if (row == null) {
|
||||||
|
throw new RuntimeException("Airline data is missing or corrupted.");
|
||||||
|
}
|
||||||
|
|
||||||
|
String currentCallsign = row.getString("callsign");
|
||||||
|
|
||||||
|
if (!airlineNewCallsign.equals(currentCallsign)) {
|
||||||
|
JsonObject updateQuery = JsonObject.create()
|
||||||
|
.put("callsign", airlineNewCallsign);
|
||||||
|
|
||||||
|
inventoryScope.query(String.format(
|
||||||
|
"UPDATE airline SET callsign = '%s' WHERE name = '%s';",
|
||||||
|
airlineNewCallsign, airlineName
|
||||||
|
));
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@SuppressWarnings("ALL")
|
@SuppressWarnings("ALL")
|
||||||
@ -567,9 +568,9 @@ class AirlineCallsignUpdateToolFunction implements ToolFunction {
|
|||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
class AirlineDetail {
|
class AirlineDetail {
|
||||||
private String callsign;
|
private String callsign;
|
||||||
private String name;
|
private String name;
|
||||||
private String country;
|
private String country;
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
@ -578,9 +579,9 @@ class AirlineDetail {
|
|||||||
#### 1. Ollama API Client Setup
|
#### 1. Ollama API Client Setup
|
||||||
|
|
||||||
```javascript
|
```javascript
|
||||||
OllamaAPI ollamaAPI = new OllamaAPI(host);
|
OllamaAPI ollama = new OllamaAPI(host);
|
||||||
|
|
||||||
ollamaAPI.setRequestTimeoutSeconds(60);
|
ollama.setRequestTimeoutSeconds(60);
|
||||||
```
|
```
|
||||||
|
|
||||||
Here, we initialize the Ollama API client and configure it with the host of the Ollama server, where the model is hosted
|
Here, we initialize the Ollama API client and configure it with the host of the Ollama server, where the model is hosted
|
||||||
@ -595,7 +596,7 @@ queries the database for airline details based on the airline name.
|
|||||||
```javascript
|
```javascript
|
||||||
Tools.ToolSpecification callSignFinderToolSpec = getCallSignFinderToolSpec(cluster, bucketName);
|
Tools.ToolSpecification callSignFinderToolSpec = getCallSignFinderToolSpec(cluster, bucketName);
|
||||||
|
|
||||||
ollamaAPI.registerTool(callSignFinderToolSpec);
|
ollama.registerTool(callSignFinderToolSpec);
|
||||||
```
|
```
|
||||||
|
|
||||||
This step registers custom tools with Ollama that allows the tool-calling model to invoke database queries.
|
This step registers custom tools with Ollama that allows the tool-calling model to invoke database queries.
|
||||||
@ -619,7 +620,7 @@ String prompt = "What is the call-sign of Astraeus?";
|
|||||||
#### 5. Generating Results with Tools
|
#### 5. Generating Results with Tools
|
||||||
|
|
||||||
```javascript
|
```javascript
|
||||||
for (OllamaToolsResult.ToolResult r : ollamaAPI.generateWithTools(modelName, new Tools.PromptBuilder()
|
for (OllamaToolsResult.ToolResult r : ollama.generateWithTools(modelName, new Tools.PromptBuilder()
|
||||||
.withToolSpecification(callSignFinderToolSpec)
|
.withToolSpecification(callSignFinderToolSpec)
|
||||||
.withPrompt(prompt)
|
.withPrompt(prompt)
|
||||||
.build(), new OptionsBuilder().build()).getToolResults()) {
|
.build(), new OptionsBuilder().build()).getToolResults()) {
|
||||||
@ -649,7 +650,7 @@ then update the airline’s callsign.
|
|||||||
```javascript
|
```javascript
|
||||||
Tools.ToolSpecification callSignUpdaterToolSpec = getCallSignUpdaterToolSpec(cluster, bucketName);
|
Tools.ToolSpecification callSignUpdaterToolSpec = getCallSignUpdaterToolSpec(cluster, bucketName);
|
||||||
|
|
||||||
ollamaAPI.registerTool(callSignUpdaterToolSpec);
|
ollama.registerTool(callSignUpdaterToolSpec);
|
||||||
```
|
```
|
||||||
|
|
||||||
The tool will execute a Couchbase N1QL query to update the airline’s callsign.
|
The tool will execute a Couchbase N1QL query to update the airline’s callsign.
|
||||||
@ -671,7 +672,7 @@ And then we invoke the model with the new prompt.
|
|||||||
|
|
||||||
```javascript
|
```javascript
|
||||||
String prompt = "I want to code name Astraeus as STARBOUND";
|
String prompt = "I want to code name Astraeus as STARBOUND";
|
||||||
for (OllamaToolsResult.ToolResult r : ollamaAPI.generateWithTools(modelName, new Tools.PromptBuilder()
|
for (OllamaToolsResult.ToolResult r : ollama.generateWithTools(modelName, new Tools.PromptBuilder()
|
||||||
.withToolSpecification(callSignUpdaterToolSpec)
|
.withToolSpecification(callSignUpdaterToolSpec)
|
||||||
.withPrompt(prompt)
|
.withPrompt(prompt)
|
||||||
.build(), new OptionsBuilder().build()).getToolResults()) {
|
.build(), new OptionsBuilder().build()).getToolResults()) {
|
||||||
|
@ -10,7 +10,7 @@ Ollama server would be setup behind a gateway/reverse proxy with basic auth.
|
|||||||
After configuring basic authentication, all subsequent requests will include the Basic Auth header.
|
After configuring basic authentication, all subsequent requests will include the Basic Auth header.
|
||||||
|
|
||||||
```java
|
```java
|
||||||
import io.github.ollama4j.OllamaAPI;
|
import io.github.ollama4j.Ollama;
|
||||||
|
|
||||||
public class Main {
|
public class Main {
|
||||||
|
|
||||||
@ -18,9 +18,9 @@ public class Main {
|
|||||||
|
|
||||||
String host = "http://localhost:11434/";
|
String host = "http://localhost:11434/";
|
||||||
|
|
||||||
OllamaAPI ollamaAPI = new OllamaAPI(host);
|
Ollama ollama = new Ollama(host);
|
||||||
|
|
||||||
ollamaAPI.setBasicAuth("username", "password");
|
ollama.setBasicAuth("username", "password");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
@ -10,7 +10,7 @@ Ollama server would be setup behind a gateway/reverse proxy with bearer auth.
|
|||||||
After configuring bearer authentication, all subsequent requests will include the Bearer Auth header.
|
After configuring bearer authentication, all subsequent requests will include the Bearer Auth header.
|
||||||
|
|
||||||
```java
|
```java
|
||||||
import io.github.ollama4j.OllamaAPI;
|
import io.github.ollama4j.Ollama;
|
||||||
|
|
||||||
public class Main {
|
public class Main {
|
||||||
|
|
||||||
@ -18,9 +18,9 @@ public class Main {
|
|||||||
|
|
||||||
String host = "http://localhost:11434/";
|
String host = "http://localhost:11434/";
|
||||||
|
|
||||||
OllamaAPI ollamaAPI = new OllamaAPI(host);
|
Ollama ollama = new Ollama(host);
|
||||||
|
|
||||||
ollamaAPI.setBearerAuth("YOUR-TOKEN");
|
ollama.setBearerAuth("YOUR-TOKEN");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
@ -36,7 +36,7 @@ from [javadoc](https://ollama4j.github.io/ollama4j/apidocs/io/github/ollama4j/ol
|
|||||||
## Build an empty `Options` object
|
## Build an empty `Options` object
|
||||||
|
|
||||||
```java
|
```java
|
||||||
import io.github.ollama4j.OllamaAPI;
|
import io.github.ollama4j.Ollama;
|
||||||
import io.github.ollama4j.utils.Options;
|
import io.github.ollama4j.utils.Options;
|
||||||
import io.github.ollama4j.utils.OptionsBuilder;
|
import io.github.ollama4j.utils.OptionsBuilder;
|
||||||
|
|
||||||
@ -46,7 +46,7 @@ public class Main {
|
|||||||
|
|
||||||
String host = "http://localhost:11434/";
|
String host = "http://localhost:11434/";
|
||||||
|
|
||||||
OllamaAPI ollamaAPI = new OllamaAPI(host);
|
Ollama ollama = new Ollama(host);
|
||||||
|
|
||||||
Options options = new OptionsBuilder().build();
|
Options options = new OptionsBuilder().build();
|
||||||
}
|
}
|
||||||
@ -65,7 +65,7 @@ public class Main {
|
|||||||
|
|
||||||
String host = "http://localhost:11434/";
|
String host = "http://localhost:11434/";
|
||||||
|
|
||||||
OllamaAPI ollamaAPI = new OllamaAPI(host);
|
OllamaAPI ollama = new OllamaAPI(host);
|
||||||
|
|
||||||
Options options =
|
Options options =
|
||||||
new OptionsBuilder()
|
new OptionsBuilder()
|
||||||
|
@ -7,6 +7,7 @@ sidebar_position: 5
|
|||||||
This API lets you check the reachability of Ollama server.
|
This API lets you check the reachability of Ollama server.
|
||||||
|
|
||||||
```java
|
```java
|
||||||
|
import io.github.ollama4j.Ollama;
|
||||||
import io.github.ollama4j.OllamaAPI;
|
import io.github.ollama4j.OllamaAPI;
|
||||||
|
|
||||||
public class Main {
|
public class Main {
|
||||||
@ -14,9 +15,9 @@ public class Main {
|
|||||||
public static void main(String[] args) {
|
public static void main(String[] args) {
|
||||||
String host = "http://localhost:11434/";
|
String host = "http://localhost:11434/";
|
||||||
|
|
||||||
OllamaAPI ollamaAPI = new OllamaAPI(host);
|
Ollama ollama = new Ollama(host);
|
||||||
|
|
||||||
ollamaAPI.ping();
|
ollama.ping();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
@ -8,6 +8,7 @@ This is designed for prompt engineering. It allows you to easily build the promp
|
|||||||
inferences.
|
inferences.
|
||||||
|
|
||||||
```java
|
```java
|
||||||
|
import io.github.ollama4j.Ollama;
|
||||||
import io.github.ollama4j.OllamaAPI;
|
import io.github.ollama4j.OllamaAPI;
|
||||||
import io.github.ollama4j.models.response.OllamaResult;
|
import io.github.ollama4j.models.response.OllamaResult;
|
||||||
import io.github.ollama4j.types.OllamaModelType;
|
import io.github.ollama4j.types.OllamaModelType;
|
||||||
@ -18,8 +19,8 @@ public class Main {
|
|||||||
public static void main(String[] args) throws Exception {
|
public static void main(String[] args) throws Exception {
|
||||||
|
|
||||||
String host = "http://localhost:11434/";
|
String host = "http://localhost:11434/";
|
||||||
OllamaAPI ollamaAPI = new OllamaAPI(host);
|
Ollama ollama = new Ollama(host);
|
||||||
ollamaAPI.setRequestTimeoutSeconds(10);
|
ollama.setRequestTimeoutSeconds(10);
|
||||||
|
|
||||||
String model = OllamaModelType.PHI;
|
String model = OllamaModelType.PHI;
|
||||||
|
|
||||||
@ -43,7 +44,7 @@ public class Main {
|
|||||||
.add("How do I read a file in Go and print its contents to stdout?");
|
.add("How do I read a file in Go and print its contents to stdout?");
|
||||||
|
|
||||||
boolean raw = false;
|
boolean raw = false;
|
||||||
OllamaResult response = ollamaAPI.generate(model, promptBuilder.build(), raw, new OptionsBuilder().build());
|
OllamaResult response = ollama.generate(model, promptBuilder.build(), raw, new OptionsBuilder().build());
|
||||||
System.out.println(response.getResponse());
|
System.out.println(response.getResponse());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -11,15 +11,15 @@ This API corresponds to the [PS](https://github.com/ollama/ollama/blob/main/docs
|
|||||||
```java
|
```java
|
||||||
package io.github.ollama4j.localtests;
|
package io.github.ollama4j.localtests;
|
||||||
|
|
||||||
import io.github.ollama4j.OllamaAPI;
|
import io.github.ollama4j.Ollama;
|
||||||
import io.github.ollama4j.models.ps.ModelProcessesResult;
|
import io.github.ollama4j.models.ps.ModelProcessesResult;
|
||||||
|
|
||||||
public class Main {
|
public class Main {
|
||||||
public static void main(String[] args) {
|
public static void main(String[] args) {
|
||||||
|
|
||||||
OllamaAPI ollamaAPI = new OllamaAPI("http://localhost:11434");
|
Ollama ollama = new Ollama("http://localhost:11434");
|
||||||
|
|
||||||
ModelProcessesResult response = ollamaAPI.ps();
|
ModelProcessesResult response = ollama.ps();
|
||||||
|
|
||||||
System.out.println(response);
|
System.out.println(response);
|
||||||
}
|
}
|
||||||
|
@ -9,17 +9,18 @@ sidebar_position: 2
|
|||||||
This API lets you set the request timeout for the Ollama client.
|
This API lets you set the request timeout for the Ollama client.
|
||||||
|
|
||||||
```java
|
```java
|
||||||
|
import io.github.ollama4j.Ollama;
|
||||||
import io.github.ollama4j.OllamaAPI;
|
import io.github.ollama4j.OllamaAPI;
|
||||||
|
|
||||||
public class Main {
|
public class Main {
|
||||||
|
|
||||||
public static void main(String[] args) {
|
public static void main(String[] args) {
|
||||||
|
|
||||||
String host = "http://localhost:11434/";
|
String host = "http://localhost:11434/";
|
||||||
|
|
||||||
OllamaAPI ollamaAPI = new OllamaAPI(host);
|
Ollama ollama = new Ollama(host);
|
||||||
|
|
||||||
ollamaAPI.setRequestTimeoutSeconds(10);
|
ollama.setRequestTimeoutSeconds(10);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
@ -16,6 +16,7 @@ _Base roles are `SYSTEM`, `USER`, `ASSISTANT`, `TOOL`._
|
|||||||
#### Add new role
|
#### Add new role
|
||||||
|
|
||||||
```java
|
```java
|
||||||
|
import io.github.ollama4j.Ollama;
|
||||||
import io.github.ollama4j.OllamaAPI;
|
import io.github.ollama4j.OllamaAPI;
|
||||||
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
|
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
|
||||||
|
|
||||||
@ -23,9 +24,9 @@ public class Main {
|
|||||||
|
|
||||||
public static void main(String[] args) {
|
public static void main(String[] args) {
|
||||||
String host = "http://localhost:11434/";
|
String host = "http://localhost:11434/";
|
||||||
OllamaAPI ollamaAPI = new OllamaAPI(host);
|
Ollama ollama = new Ollama(host);
|
||||||
|
|
||||||
OllamaChatMessageRole customRole = ollamaAPI.addCustomRole("custom-role");
|
OllamaChatMessageRole customRole = ollama.addCustomRole("custom-role");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
@ -33,16 +34,16 @@ public class Main {
|
|||||||
#### List roles
|
#### List roles
|
||||||
|
|
||||||
```java
|
```java
|
||||||
import io.github.ollama4j.OllamaAPI;
|
import io.github.ollama4j.Ollama;
|
||||||
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
|
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
|
||||||
|
|
||||||
public class Main {
|
public class Main {
|
||||||
|
|
||||||
public static void main(String[] args) {
|
public static void main(String[] args) {
|
||||||
String host = "http://localhost:11434/";
|
String host = "http://localhost:11434/";
|
||||||
OllamaAPI ollamaAPI = new OllamaAPI(host);
|
Ollama ollama = new Ollama(host);
|
||||||
|
|
||||||
List<OllamaChatMessageRole> roles = ollamaAPI.listRoles();
|
List<OllamaChatMessageRole> roles = ollama.listRoles();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
@ -50,6 +51,7 @@ public class Main {
|
|||||||
#### Get role
|
#### Get role
|
||||||
|
|
||||||
```java
|
```java
|
||||||
|
import io.github.ollama4j.Ollama;
|
||||||
import io.github.ollama4j.OllamaAPI;
|
import io.github.ollama4j.OllamaAPI;
|
||||||
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
|
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
|
||||||
|
|
||||||
@ -57,9 +59,9 @@ public class Main {
|
|||||||
|
|
||||||
public static void main(String[] args) {
|
public static void main(String[] args) {
|
||||||
String host = "http://localhost:11434/";
|
String host = "http://localhost:11434/";
|
||||||
OllamaAPI ollamaAPI = new OllamaAPI(host);
|
Ollama ollama = new Ollama(host);
|
||||||
|
|
||||||
List<OllamaChatMessageRole> roles = ollamaAPI.getRole("custom-role");
|
List<OllamaChatMessageRole> roles = ollama.getRole("custom-role");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
@ -112,14 +112,15 @@ or use other suitable implementations.
|
|||||||
Create a new Java class in your project and add this code.
|
Create a new Java class in your project and add this code.
|
||||||
|
|
||||||
```java
|
```java
|
||||||
|
import io.github.ollama4j.Ollama;
|
||||||
import io.github.ollama4j.OllamaAPI;
|
import io.github.ollama4j.OllamaAPI;
|
||||||
|
|
||||||
public class OllamaAPITest {
|
public class OllamaAPITest {
|
||||||
|
|
||||||
public static void main(String[] args) {
|
public static void main(String[] args) {
|
||||||
OllamaAPI ollamaAPI = new OllamaAPI();
|
Ollama ollama = new Ollama();
|
||||||
|
|
||||||
boolean isOllamaServerReachable = ollamaAPI.ping();
|
boolean isOllamaServerReachable = ollama.ping();
|
||||||
|
|
||||||
System.out.println("Is Ollama server running: " + isOllamaServerReachable);
|
System.out.println("Is Ollama server running: " + isOllamaServerReachable);
|
||||||
}
|
}
|
||||||
@ -130,6 +131,7 @@ This uses the default Ollama host as `http://localhost:11434`.
|
|||||||
Specify a different Ollama host that you want to connect to.
|
Specify a different Ollama host that you want to connect to.
|
||||||
|
|
||||||
```java
|
```java
|
||||||
|
import io.github.ollama4j.Ollama;
|
||||||
import io.github.ollama4j.OllamaAPI;
|
import io.github.ollama4j.OllamaAPI;
|
||||||
|
|
||||||
public class OllamaAPITest {
|
public class OllamaAPITest {
|
||||||
@ -137,9 +139,9 @@ public class OllamaAPITest {
|
|||||||
public static void main(String[] args) {
|
public static void main(String[] args) {
|
||||||
String host = "http://localhost:11434/";
|
String host = "http://localhost:11434/";
|
||||||
|
|
||||||
OllamaAPI ollamaAPI = new OllamaAPI(host);
|
Ollama ollama = new Ollama(host);
|
||||||
|
|
||||||
boolean isOllamaServerReachable = ollamaAPI.ping();
|
boolean isOllamaServerReachable = ollama.ping();
|
||||||
|
|
||||||
System.out.println("Is Ollama server running: " + isOllamaServerReachable);
|
System.out.println("Is Ollama server running: " + isOllamaServerReachable);
|
||||||
}
|
}
|
||||||
|
@ -53,9 +53,9 @@ import org.slf4j.LoggerFactory;
|
|||||||
* <p>This class provides methods for model management, chat, embeddings, tool registration, and more.
|
* <p>This class provides methods for model management, chat, embeddings, tool registration, and more.
|
||||||
*/
|
*/
|
||||||
@SuppressWarnings({"DuplicatedCode", "resource", "SpellCheckingInspection"})
|
@SuppressWarnings({"DuplicatedCode", "resource", "SpellCheckingInspection"})
|
||||||
public class OllamaAPI {
|
public class Ollama {
|
||||||
|
|
||||||
private static final Logger LOG = LoggerFactory.getLogger(OllamaAPI.class);
|
private static final Logger LOG = LoggerFactory.getLogger(Ollama.class);
|
||||||
|
|
||||||
private final String host;
|
private final String host;
|
||||||
private Auth auth;
|
private Auth auth;
|
||||||
@ -107,7 +107,7 @@ public class OllamaAPI {
|
|||||||
/**
|
/**
|
||||||
* Instantiates the Ollama API with the default Ollama host: {@code http://localhost:11434}
|
* Instantiates the Ollama API with the default Ollama host: {@code http://localhost:11434}
|
||||||
*/
|
*/
|
||||||
public OllamaAPI() {
|
public Ollama() {
|
||||||
this.host = "http://localhost:11434";
|
this.host = "http://localhost:11434";
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -116,7 +116,7 @@ public class OllamaAPI {
|
|||||||
*
|
*
|
||||||
* @param host the host address of the Ollama server
|
* @param host the host address of the Ollama server
|
||||||
*/
|
*/
|
||||||
public OllamaAPI(String host) {
|
public Ollama(String host) {
|
||||||
if (host.endsWith("/")) {
|
if (host.endsWith("/")) {
|
||||||
this.host = host.substring(0, host.length() - 1);
|
this.host = host.substring(0, host.length() - 1);
|
||||||
} else {
|
} else {
|
@ -8,7 +8,7 @@
|
|||||||
*/
|
*/
|
||||||
package io.github.ollama4j.tools.annotations;
|
package io.github.ollama4j.tools.annotations;
|
||||||
|
|
||||||
import io.github.ollama4j.OllamaAPI;
|
import io.github.ollama4j.Ollama;
|
||||||
import java.lang.annotation.ElementType;
|
import java.lang.annotation.ElementType;
|
||||||
import java.lang.annotation.Retention;
|
import java.lang.annotation.Retention;
|
||||||
import java.lang.annotation.RetentionPolicy;
|
import java.lang.annotation.RetentionPolicy;
|
||||||
@ -18,7 +18,7 @@ import java.lang.annotation.Target;
|
|||||||
* Annotation to mark a class as an Ollama tool service.
|
* Annotation to mark a class as an Ollama tool service.
|
||||||
* <p>
|
* <p>
|
||||||
* When a class is annotated with {@code @OllamaToolService}, the method
|
* When a class is annotated with {@code @OllamaToolService}, the method
|
||||||
* {@link OllamaAPI#registerAnnotatedTools()} can be used to automatically register all tool provider
|
* {@link Ollama#registerAnnotatedTools()} can be used to automatically register all tool provider
|
||||||
* classes specified in the {@link #providers()} array. All methods in those provider classes that are
|
* classes specified in the {@link #providers()} array. All methods in those provider classes that are
|
||||||
* annotated with {@link ToolSpec} will be registered as tools.
|
* annotated with {@link ToolSpec} will be registered as tools.
|
||||||
* </p>
|
* </p>
|
||||||
|
@ -8,7 +8,7 @@
|
|||||||
*/
|
*/
|
||||||
package io.github.ollama4j.tools.annotations;
|
package io.github.ollama4j.tools.annotations;
|
||||||
|
|
||||||
import io.github.ollama4j.OllamaAPI;
|
import io.github.ollama4j.Ollama;
|
||||||
import java.lang.annotation.ElementType;
|
import java.lang.annotation.ElementType;
|
||||||
import java.lang.annotation.Retention;
|
import java.lang.annotation.Retention;
|
||||||
import java.lang.annotation.RetentionPolicy;
|
import java.lang.annotation.RetentionPolicy;
|
||||||
@ -16,7 +16,7 @@ import java.lang.annotation.Target;
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* Annotation to mark a method as a tool that can be registered automatically by
|
* Annotation to mark a method as a tool that can be registered automatically by
|
||||||
* {@link OllamaAPI#registerAnnotatedTools()}.
|
* {@link Ollama#registerAnnotatedTools()}.
|
||||||
* <p>
|
* <p>
|
||||||
* Methods annotated with {@code @ToolSpec} will be discovered and registered as tools
|
* Methods annotated with {@code @ToolSpec} will be discovered and registered as tools
|
||||||
* when the containing class is specified as a provider in {@link OllamaToolService}.
|
* when the containing class is specified as a provider in {@link OllamaToolService}.
|
||||||
|
@ -10,7 +10,7 @@ package io.github.ollama4j.integrationtests;
|
|||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
import io.github.ollama4j.OllamaAPI;
|
import io.github.ollama4j.Ollama;
|
||||||
import io.github.ollama4j.exceptions.OllamaException;
|
import io.github.ollama4j.exceptions.OllamaException;
|
||||||
import io.github.ollama4j.impl.ConsoleOutputChatTokenHandler;
|
import io.github.ollama4j.impl.ConsoleOutputChatTokenHandler;
|
||||||
import io.github.ollama4j.impl.ConsoleOutputGenerateTokenHandler;
|
import io.github.ollama4j.impl.ConsoleOutputGenerateTokenHandler;
|
||||||
@ -44,11 +44,11 @@ import org.testcontainers.ollama.OllamaContainer;
|
|||||||
@OllamaToolService(providers = {AnnotatedTool.class})
|
@OllamaToolService(providers = {AnnotatedTool.class})
|
||||||
@TestMethodOrder(OrderAnnotation.class)
|
@TestMethodOrder(OrderAnnotation.class)
|
||||||
@SuppressWarnings({"HttpUrlsUsage", "SpellCheckingInspection", "FieldCanBeLocal", "ConstantValue"})
|
@SuppressWarnings({"HttpUrlsUsage", "SpellCheckingInspection", "FieldCanBeLocal", "ConstantValue"})
|
||||||
class OllamaAPIIntegrationTest {
|
class OllamaIntegrationTest {
|
||||||
private static final Logger LOG = LoggerFactory.getLogger(OllamaAPIIntegrationTest.class);
|
private static final Logger LOG = LoggerFactory.getLogger(OllamaIntegrationTest.class);
|
||||||
|
|
||||||
private static OllamaContainer ollama;
|
private static OllamaContainer ollama;
|
||||||
private static OllamaAPI api;
|
private static Ollama api;
|
||||||
|
|
||||||
private static final String EMBEDDING_MODEL = "all-minilm";
|
private static final String EMBEDDING_MODEL = "all-minilm";
|
||||||
private static final String VISION_MODEL = "moondream:1.8b";
|
private static final String VISION_MODEL = "moondream:1.8b";
|
||||||
@ -81,7 +81,7 @@ class OllamaAPIIntegrationTest {
|
|||||||
Properties props = new Properties();
|
Properties props = new Properties();
|
||||||
try {
|
try {
|
||||||
props.load(
|
props.load(
|
||||||
OllamaAPIIntegrationTest.class
|
OllamaIntegrationTest.class
|
||||||
.getClassLoader()
|
.getClassLoader()
|
||||||
.getResourceAsStream("test-config.properties"));
|
.getResourceAsStream("test-config.properties"));
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
@ -103,7 +103,7 @@ class OllamaAPIIntegrationTest {
|
|||||||
|
|
||||||
if (useExternalOllamaHost) {
|
if (useExternalOllamaHost) {
|
||||||
LOG.info("Using external Ollama host: {}", ollamaHost);
|
LOG.info("Using external Ollama host: {}", ollamaHost);
|
||||||
api = new OllamaAPI(ollamaHost);
|
api = new Ollama(ollamaHost);
|
||||||
} else {
|
} else {
|
||||||
throw new RuntimeException(
|
throw new RuntimeException(
|
||||||
"USE_EXTERNAL_OLLAMA_HOST is not set so, we will be using Testcontainers"
|
"USE_EXTERNAL_OLLAMA_HOST is not set so, we will be using Testcontainers"
|
||||||
@ -124,7 +124,7 @@ class OllamaAPIIntegrationTest {
|
|||||||
ollama.start();
|
ollama.start();
|
||||||
LOG.info("Using Testcontainer Ollama host...");
|
LOG.info("Using Testcontainer Ollama host...");
|
||||||
api =
|
api =
|
||||||
new OllamaAPI(
|
new Ollama(
|
||||||
"http://"
|
"http://"
|
||||||
+ ollama.getHost()
|
+ ollama.getHost()
|
||||||
+ ":"
|
+ ":"
|
||||||
@ -143,8 +143,8 @@ class OllamaAPIIntegrationTest {
|
|||||||
@Test
|
@Test
|
||||||
@Order(1)
|
@Order(1)
|
||||||
void shouldThrowConnectExceptionForWrongEndpoint() {
|
void shouldThrowConnectExceptionForWrongEndpoint() {
|
||||||
OllamaAPI ollamaAPI = new OllamaAPI("http://wrong-host:11434");
|
Ollama ollama = new Ollama("http://wrong-host:11434");
|
||||||
assertThrows(OllamaException.class, ollamaAPI::listModels);
|
assertThrows(OllamaException.class, ollama::listModels);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -778,7 +778,7 @@ class OllamaAPIIntegrationTest {
|
|||||||
Collections.emptyList(),
|
Collections.emptyList(),
|
||||||
"https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg")
|
"https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg")
|
||||||
.build();
|
.build();
|
||||||
api.registerAnnotatedTools(new OllamaAPIIntegrationTest());
|
api.registerAnnotatedTools(new OllamaIntegrationTest());
|
||||||
|
|
||||||
OllamaChatResult chatResult = api.chat(requestModel, null);
|
OllamaChatResult chatResult = api.chat(requestModel, null);
|
||||||
assertNotNull(chatResult);
|
assertNotNull(chatResult);
|
@ -10,7 +10,7 @@ package io.github.ollama4j.integrationtests;
|
|||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
import io.github.ollama4j.OllamaAPI;
|
import io.github.ollama4j.Ollama;
|
||||||
import io.github.ollama4j.exceptions.OllamaException;
|
import io.github.ollama4j.exceptions.OllamaException;
|
||||||
import io.github.ollama4j.models.generate.OllamaGenerateRequest;
|
import io.github.ollama4j.models.generate.OllamaGenerateRequest;
|
||||||
import io.github.ollama4j.models.generate.OllamaGenerateRequestBuilder;
|
import io.github.ollama4j.models.generate.OllamaGenerateRequestBuilder;
|
||||||
@ -62,7 +62,7 @@ public class WithAuth {
|
|||||||
|
|
||||||
private static OllamaContainer ollama;
|
private static OllamaContainer ollama;
|
||||||
private static GenericContainer<?> nginx;
|
private static GenericContainer<?> nginx;
|
||||||
private static OllamaAPI api;
|
private static Ollama api;
|
||||||
|
|
||||||
@BeforeAll
|
@BeforeAll
|
||||||
static void setUp() {
|
static void setUp() {
|
||||||
@ -74,7 +74,7 @@ public class WithAuth {
|
|||||||
|
|
||||||
LOG.info("Using Testcontainer Ollama host...");
|
LOG.info("Using Testcontainer Ollama host...");
|
||||||
|
|
||||||
api = new OllamaAPI("http://" + nginx.getHost() + ":" + nginx.getMappedPort(NGINX_PORT));
|
api = new Ollama("http://" + nginx.getHost() + ":" + nginx.getMappedPort(NGINX_PORT));
|
||||||
api.setRequestTimeoutSeconds(120);
|
api.setRequestTimeoutSeconds(120);
|
||||||
api.setNumberOfRetriesForModelPull(3);
|
api.setNumberOfRetriesForModelPull(3);
|
||||||
|
|
||||||
|
@ -12,7 +12,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
|
|||||||
import static org.junit.jupiter.api.Assertions.fail;
|
import static org.junit.jupiter.api.Assertions.fail;
|
||||||
import static org.mockito.Mockito.*;
|
import static org.mockito.Mockito.*;
|
||||||
|
|
||||||
import io.github.ollama4j.OllamaAPI;
|
import io.github.ollama4j.Ollama;
|
||||||
import io.github.ollama4j.exceptions.OllamaException;
|
import io.github.ollama4j.exceptions.OllamaException;
|
||||||
import io.github.ollama4j.exceptions.RoleNotFoundException;
|
import io.github.ollama4j.exceptions.RoleNotFoundException;
|
||||||
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
|
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
|
||||||
@ -36,12 +36,12 @@ import org.mockito.Mockito;
|
|||||||
class TestMockedAPIs {
|
class TestMockedAPIs {
|
||||||
@Test
|
@Test
|
||||||
void testPullModel() {
|
void testPullModel() {
|
||||||
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
|
Ollama ollama = Mockito.mock(Ollama.class);
|
||||||
String model = "llama2";
|
String model = "llama2";
|
||||||
try {
|
try {
|
||||||
doNothing().when(ollamaAPI).pullModel(model);
|
doNothing().when(ollama).pullModel(model);
|
||||||
ollamaAPI.pullModel(model);
|
ollama.pullModel(model);
|
||||||
verify(ollamaAPI, times(1)).pullModel(model);
|
verify(ollama, times(1)).pullModel(model);
|
||||||
} catch (OllamaException e) {
|
} catch (OllamaException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
@ -49,11 +49,11 @@ class TestMockedAPIs {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
void testListModels() {
|
void testListModels() {
|
||||||
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
|
Ollama ollama = Mockito.mock(Ollama.class);
|
||||||
try {
|
try {
|
||||||
when(ollamaAPI.listModels()).thenReturn(new ArrayList<>());
|
when(ollama.listModels()).thenReturn(new ArrayList<>());
|
||||||
ollamaAPI.listModels();
|
ollama.listModels();
|
||||||
verify(ollamaAPI, times(1)).listModels();
|
verify(ollama, times(1)).listModels();
|
||||||
} catch (OllamaException e) {
|
} catch (OllamaException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
@ -61,7 +61,7 @@ class TestMockedAPIs {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
void testCreateModel() {
|
void testCreateModel() {
|
||||||
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
|
Ollama ollama = Mockito.mock(Ollama.class);
|
||||||
CustomModelRequest customModelRequest =
|
CustomModelRequest customModelRequest =
|
||||||
CustomModelRequest.builder()
|
CustomModelRequest.builder()
|
||||||
.model("mario")
|
.model("mario")
|
||||||
@ -69,9 +69,9 @@ class TestMockedAPIs {
|
|||||||
.system("You are Mario from Super Mario Bros.")
|
.system("You are Mario from Super Mario Bros.")
|
||||||
.build();
|
.build();
|
||||||
try {
|
try {
|
||||||
doNothing().when(ollamaAPI).createModel(customModelRequest);
|
doNothing().when(ollama).createModel(customModelRequest);
|
||||||
ollamaAPI.createModel(customModelRequest);
|
ollama.createModel(customModelRequest);
|
||||||
verify(ollamaAPI, times(1)).createModel(customModelRequest);
|
verify(ollama, times(1)).createModel(customModelRequest);
|
||||||
} catch (OllamaException e) {
|
} catch (OllamaException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
@ -79,12 +79,12 @@ class TestMockedAPIs {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
void testDeleteModel() {
|
void testDeleteModel() {
|
||||||
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
|
Ollama ollama = Mockito.mock(Ollama.class);
|
||||||
String model = "llama2";
|
String model = "llama2";
|
||||||
try {
|
try {
|
||||||
doNothing().when(ollamaAPI).deleteModel(model, true);
|
doNothing().when(ollama).deleteModel(model, true);
|
||||||
ollamaAPI.deleteModel(model, true);
|
ollama.deleteModel(model, true);
|
||||||
verify(ollamaAPI, times(1)).deleteModel(model, true);
|
verify(ollama, times(1)).deleteModel(model, true);
|
||||||
} catch (OllamaException e) {
|
} catch (OllamaException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
@ -92,12 +92,12 @@ class TestMockedAPIs {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
void testGetModelDetails() {
|
void testGetModelDetails() {
|
||||||
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
|
Ollama ollama = Mockito.mock(Ollama.class);
|
||||||
String model = "llama2";
|
String model = "llama2";
|
||||||
try {
|
try {
|
||||||
when(ollamaAPI.getModelDetails(model)).thenReturn(new ModelDetail());
|
when(ollama.getModelDetails(model)).thenReturn(new ModelDetail());
|
||||||
ollamaAPI.getModelDetails(model);
|
ollama.getModelDetails(model);
|
||||||
verify(ollamaAPI, times(1)).getModelDetails(model);
|
verify(ollama, times(1)).getModelDetails(model);
|
||||||
} catch (OllamaException e) {
|
} catch (OllamaException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
@ -105,16 +105,16 @@ class TestMockedAPIs {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
void testGenerateEmbeddings() {
|
void testGenerateEmbeddings() {
|
||||||
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
|
Ollama ollama = Mockito.mock(Ollama.class);
|
||||||
String model = "llama2";
|
String model = "llama2";
|
||||||
String prompt = "some prompt text";
|
String prompt = "some prompt text";
|
||||||
try {
|
try {
|
||||||
OllamaEmbedRequest m = new OllamaEmbedRequest();
|
OllamaEmbedRequest m = new OllamaEmbedRequest();
|
||||||
m.setModel(model);
|
m.setModel(model);
|
||||||
m.setInput(List.of(prompt));
|
m.setInput(List.of(prompt));
|
||||||
when(ollamaAPI.embed(m)).thenReturn(new OllamaEmbedResult());
|
when(ollama.embed(m)).thenReturn(new OllamaEmbedResult());
|
||||||
ollamaAPI.embed(m);
|
ollama.embed(m);
|
||||||
verify(ollamaAPI, times(1)).embed(m);
|
verify(ollama, times(1)).embed(m);
|
||||||
} catch (OllamaException e) {
|
} catch (OllamaException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
@ -122,14 +122,14 @@ class TestMockedAPIs {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
void testEmbed() {
|
void testEmbed() {
|
||||||
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
|
Ollama ollama = Mockito.mock(Ollama.class);
|
||||||
String model = "llama2";
|
String model = "llama2";
|
||||||
List<String> inputs = List.of("some prompt text");
|
List<String> inputs = List.of("some prompt text");
|
||||||
try {
|
try {
|
||||||
OllamaEmbedRequest m = new OllamaEmbedRequest(model, inputs);
|
OllamaEmbedRequest m = new OllamaEmbedRequest(model, inputs);
|
||||||
when(ollamaAPI.embed(m)).thenReturn(new OllamaEmbedResult());
|
when(ollama.embed(m)).thenReturn(new OllamaEmbedResult());
|
||||||
ollamaAPI.embed(m);
|
ollama.embed(m);
|
||||||
verify(ollamaAPI, times(1)).embed(m);
|
verify(ollama, times(1)).embed(m);
|
||||||
} catch (OllamaException e) {
|
} catch (OllamaException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
@ -137,14 +137,14 @@ class TestMockedAPIs {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
void testEmbedWithEmbedRequestModel() {
|
void testEmbedWithEmbedRequestModel() {
|
||||||
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
|
Ollama ollama = Mockito.mock(Ollama.class);
|
||||||
String model = "llama2";
|
String model = "llama2";
|
||||||
List<String> inputs = List.of("some prompt text");
|
List<String> inputs = List.of("some prompt text");
|
||||||
try {
|
try {
|
||||||
when(ollamaAPI.embed(new OllamaEmbedRequest(model, inputs)))
|
when(ollama.embed(new OllamaEmbedRequest(model, inputs)))
|
||||||
.thenReturn(new OllamaEmbedResult());
|
.thenReturn(new OllamaEmbedResult());
|
||||||
ollamaAPI.embed(new OllamaEmbedRequest(model, inputs));
|
ollama.embed(new OllamaEmbedRequest(model, inputs));
|
||||||
verify(ollamaAPI, times(1)).embed(new OllamaEmbedRequest(model, inputs));
|
verify(ollama, times(1)).embed(new OllamaEmbedRequest(model, inputs));
|
||||||
} catch (OllamaException e) {
|
} catch (OllamaException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
@ -152,7 +152,7 @@ class TestMockedAPIs {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
void testAsk() {
|
void testAsk() {
|
||||||
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
|
Ollama ollama = Mockito.mock(Ollama.class);
|
||||||
String model = "llama2";
|
String model = "llama2";
|
||||||
String prompt = "some prompt text";
|
String prompt = "some prompt text";
|
||||||
OllamaGenerateStreamObserver observer = new OllamaGenerateStreamObserver(null, null);
|
OllamaGenerateStreamObserver observer = new OllamaGenerateStreamObserver(null, null);
|
||||||
@ -165,10 +165,9 @@ class TestMockedAPIs {
|
|||||||
.withThink(false)
|
.withThink(false)
|
||||||
.withStreaming(false)
|
.withStreaming(false)
|
||||||
.build();
|
.build();
|
||||||
when(ollamaAPI.generate(request, observer))
|
when(ollama.generate(request, observer)).thenReturn(new OllamaResult("", "", 0, 200));
|
||||||
.thenReturn(new OllamaResult("", "", 0, 200));
|
ollama.generate(request, observer);
|
||||||
ollamaAPI.generate(request, observer);
|
verify(ollama, times(1)).generate(request, observer);
|
||||||
verify(ollamaAPI, times(1)).generate(request, observer);
|
|
||||||
} catch (OllamaException e) {
|
} catch (OllamaException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
@ -176,7 +175,7 @@ class TestMockedAPIs {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
void testAskWithImageFiles() {
|
void testAskWithImageFiles() {
|
||||||
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
|
Ollama ollama = Mockito.mock(Ollama.class);
|
||||||
String model = "llama2";
|
String model = "llama2";
|
||||||
String prompt = "some prompt text";
|
String prompt = "some prompt text";
|
||||||
try {
|
try {
|
||||||
@ -192,9 +191,9 @@ class TestMockedAPIs {
|
|||||||
.withFormat(null)
|
.withFormat(null)
|
||||||
.build();
|
.build();
|
||||||
OllamaGenerateStreamObserver handler = null;
|
OllamaGenerateStreamObserver handler = null;
|
||||||
when(ollamaAPI.generate(request, handler)).thenReturn(new OllamaResult("", "", 0, 200));
|
when(ollama.generate(request, handler)).thenReturn(new OllamaResult("", "", 0, 200));
|
||||||
ollamaAPI.generate(request, handler);
|
ollama.generate(request, handler);
|
||||||
verify(ollamaAPI, times(1)).generate(request, handler);
|
verify(ollama, times(1)).generate(request, handler);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
@ -202,7 +201,7 @@ class TestMockedAPIs {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
void testAskWithImageURLs() {
|
void testAskWithImageURLs() {
|
||||||
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
|
Ollama ollama = Mockito.mock(Ollama.class);
|
||||||
String model = "llama2";
|
String model = "llama2";
|
||||||
String prompt = "some prompt text";
|
String prompt = "some prompt text";
|
||||||
try {
|
try {
|
||||||
@ -218,9 +217,9 @@ class TestMockedAPIs {
|
|||||||
.withFormat(null)
|
.withFormat(null)
|
||||||
.build();
|
.build();
|
||||||
OllamaGenerateStreamObserver handler = null;
|
OllamaGenerateStreamObserver handler = null;
|
||||||
when(ollamaAPI.generate(request, handler)).thenReturn(new OllamaResult("", "", 0, 200));
|
when(ollama.generate(request, handler)).thenReturn(new OllamaResult("", "", 0, 200));
|
||||||
ollamaAPI.generate(request, handler);
|
ollama.generate(request, handler);
|
||||||
verify(ollamaAPI, times(1)).generate(request, handler);
|
verify(ollama, times(1)).generate(request, handler);
|
||||||
} catch (OllamaException e) {
|
} catch (OllamaException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
@ -230,56 +229,55 @@ class TestMockedAPIs {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
void testAskAsync() throws OllamaException {
|
void testAskAsync() throws OllamaException {
|
||||||
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
|
Ollama ollama = Mockito.mock(Ollama.class);
|
||||||
String model = "llama2";
|
String model = "llama2";
|
||||||
String prompt = "some prompt text";
|
String prompt = "some prompt text";
|
||||||
when(ollamaAPI.generateAsync(model, prompt, false, false))
|
when(ollama.generateAsync(model, prompt, false, false))
|
||||||
.thenReturn(new OllamaAsyncResultStreamer(null, null, 3));
|
.thenReturn(new OllamaAsyncResultStreamer(null, null, 3));
|
||||||
ollamaAPI.generateAsync(model, prompt, false, false);
|
ollama.generateAsync(model, prompt, false, false);
|
||||||
verify(ollamaAPI, times(1)).generateAsync(model, prompt, false, false);
|
verify(ollama, times(1)).generateAsync(model, prompt, false, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void testAddCustomRole() {
|
void testAddCustomRole() {
|
||||||
OllamaAPI ollamaAPI = mock(OllamaAPI.class);
|
Ollama ollama = mock(Ollama.class);
|
||||||
String roleName = "custom-role";
|
String roleName = "custom-role";
|
||||||
OllamaChatMessageRole expectedRole = OllamaChatMessageRole.newCustomRole(roleName);
|
OllamaChatMessageRole expectedRole = OllamaChatMessageRole.newCustomRole(roleName);
|
||||||
when(ollamaAPI.addCustomRole(roleName)).thenReturn(expectedRole);
|
when(ollama.addCustomRole(roleName)).thenReturn(expectedRole);
|
||||||
OllamaChatMessageRole customRole = ollamaAPI.addCustomRole(roleName);
|
OllamaChatMessageRole customRole = ollama.addCustomRole(roleName);
|
||||||
assertEquals(expectedRole, customRole);
|
assertEquals(expectedRole, customRole);
|
||||||
verify(ollamaAPI, times(1)).addCustomRole(roleName);
|
verify(ollama, times(1)).addCustomRole(roleName);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void testListRoles() {
|
void testListRoles() {
|
||||||
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
|
Ollama ollama = Mockito.mock(Ollama.class);
|
||||||
OllamaChatMessageRole role1 = OllamaChatMessageRole.newCustomRole("role1");
|
OllamaChatMessageRole role1 = OllamaChatMessageRole.newCustomRole("role1");
|
||||||
OllamaChatMessageRole role2 = OllamaChatMessageRole.newCustomRole("role2");
|
OllamaChatMessageRole role2 = OllamaChatMessageRole.newCustomRole("role2");
|
||||||
List<OllamaChatMessageRole> expectedRoles = List.of(role1, role2);
|
List<OllamaChatMessageRole> expectedRoles = List.of(role1, role2);
|
||||||
when(ollamaAPI.listRoles()).thenReturn(expectedRoles);
|
when(ollama.listRoles()).thenReturn(expectedRoles);
|
||||||
List<OllamaChatMessageRole> actualRoles = ollamaAPI.listRoles();
|
List<OllamaChatMessageRole> actualRoles = ollama.listRoles();
|
||||||
assertEquals(expectedRoles, actualRoles);
|
assertEquals(expectedRoles, actualRoles);
|
||||||
verify(ollamaAPI, times(1)).listRoles();
|
verify(ollama, times(1)).listRoles();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void testGetRoleNotFound() {
|
void testGetRoleNotFound() {
|
||||||
OllamaAPI ollamaAPI = mock(OllamaAPI.class);
|
Ollama ollama = mock(Ollama.class);
|
||||||
String roleName = "non-existing-role";
|
String roleName = "non-existing-role";
|
||||||
try {
|
try {
|
||||||
when(ollamaAPI.getRole(roleName))
|
when(ollama.getRole(roleName)).thenThrow(new RoleNotFoundException("Role not found"));
|
||||||
.thenThrow(new RoleNotFoundException("Role not found"));
|
|
||||||
} catch (RoleNotFoundException exception) {
|
} catch (RoleNotFoundException exception) {
|
||||||
throw new RuntimeException("Failed to run test: testGetRoleNotFound");
|
throw new RuntimeException("Failed to run test: testGetRoleNotFound");
|
||||||
}
|
}
|
||||||
try {
|
try {
|
||||||
ollamaAPI.getRole(roleName);
|
ollama.getRole(roleName);
|
||||||
fail("Expected RoleNotFoundException not thrown");
|
fail("Expected RoleNotFoundException not thrown");
|
||||||
} catch (RoleNotFoundException exception) {
|
} catch (RoleNotFoundException exception) {
|
||||||
assertEquals("Role not found", exception.getMessage());
|
assertEquals("Role not found", exception.getMessage());
|
||||||
}
|
}
|
||||||
try {
|
try {
|
||||||
verify(ollamaAPI, times(1)).getRole(roleName);
|
verify(ollama, times(1)).getRole(roleName);
|
||||||
} catch (RoleNotFoundException exception) {
|
} catch (RoleNotFoundException exception) {
|
||||||
throw new RuntimeException("Failed to run test: testGetRoleNotFound");
|
throw new RuntimeException("Failed to run test: testGetRoleNotFound");
|
||||||
}
|
}
|
||||||
@ -287,18 +285,18 @@ class TestMockedAPIs {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
void testGetRoleFound() {
|
void testGetRoleFound() {
|
||||||
OllamaAPI ollamaAPI = mock(OllamaAPI.class);
|
Ollama ollama = mock(Ollama.class);
|
||||||
String roleName = "existing-role";
|
String roleName = "existing-role";
|
||||||
OllamaChatMessageRole expectedRole = OllamaChatMessageRole.newCustomRole(roleName);
|
OllamaChatMessageRole expectedRole = OllamaChatMessageRole.newCustomRole(roleName);
|
||||||
try {
|
try {
|
||||||
when(ollamaAPI.getRole(roleName)).thenReturn(expectedRole);
|
when(ollama.getRole(roleName)).thenReturn(expectedRole);
|
||||||
} catch (RoleNotFoundException exception) {
|
} catch (RoleNotFoundException exception) {
|
||||||
throw new RuntimeException("Failed to run test: testGetRoleFound");
|
throw new RuntimeException("Failed to run test: testGetRoleFound");
|
||||||
}
|
}
|
||||||
try {
|
try {
|
||||||
OllamaChatMessageRole actualRole = ollamaAPI.getRole(roleName);
|
OllamaChatMessageRole actualRole = ollama.getRole(roleName);
|
||||||
assertEquals(expectedRole, actualRole);
|
assertEquals(expectedRole, actualRole);
|
||||||
verify(ollamaAPI, times(1)).getRole(roleName);
|
verify(ollama, times(1)).getRole(roleName);
|
||||||
} catch (RoleNotFoundException exception) {
|
} catch (RoleNotFoundException exception) {
|
||||||
throw new RuntimeException("Failed to run test: testGetRoleFound");
|
throw new RuntimeException("Failed to run test: testGetRoleFound");
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user