forked from Mirror/ollama4j
Added support for tools/function calling - specifically for Mistral's latest model.
This commit is contained in:
parent
8ef6fac28e
commit
91ee6cb4c1
@@ -1,5 +1,5 @@
 ---
-sidebar_position: 2
+sidebar_position: 3
 ---

 # Generate - Async
@@ -1,5 +1,5 @@
 ---
-sidebar_position: 3
+sidebar_position: 4
 ---

 # Generate - With Image Files
@@ -1,5 +1,5 @@
 ---
-sidebar_position: 4
+sidebar_position: 5
 ---

 # Generate - With Image URLs
docs/docs/apis-generate/generate-with-tools.md (new file, 271 lines)

@@ -0,0 +1,271 @@
---
sidebar_position: 2
---

# Generate - With Tools

This API lets you perform [function calling](https://docs.mistral.ai/capabilities/function_calling/) using LLMs in a
synchronous way.
This API corresponds to
the [generate](https://github.com/ollama/ollama/blob/main/docs/api.md#request-raw-mode) API with `raw` mode.

:::note

This is only an experimental implementation and has a very basic design.

Currently, it is built and tested for [Mistral's latest model](https://ollama.com/library/mistral) only. We could
redesign this in the future if tooling is supported for more models with a generic interaction standard from Ollama.

:::

### Function Calling/Tools

Assume you want to call a method in your code based on the response generated from the model.
For instance, say that based on a user's question, you'd want to identify a transaction, fetch its details from your
database, and respond to the user with those details.

You can do that easily with the `function calling` capabilities of the models by registering your `tools`.

### Create Functions

This function takes the arguments `location` and `fuelType`, performs an operation with them, and returns a value.

```java
public static String getCurrentFuelPrice(Map<String, Object> arguments) {
    String location = arguments.get("location").toString();
    String fuelType = arguments.get("fuelType").toString();
    return "Current price of " + fuelType + " in " + location + " is Rs.103/L";
}
```
This function takes the argument `city`, performs an operation with it, and returns a value.

```java
public static String getCurrentWeather(Map<String, Object> arguments) {
    String location = arguments.get("city").toString();
    return "Currently " + location + "'s weather is nice.";
}
```

### Define Tool Specifications

Let's define a sample tool specification called **Fuel Price Tool** for getting the current fuel price.

- Specify the function `name`, `description`, and `required` properties (`location` and `fuelType`).
- Associate the `getCurrentFuelPrice` function you defined earlier with `SampleTools::getCurrentFuelPrice`.

```java
MistralTools.ToolSpecification fuelPriceToolSpecification = MistralTools.ToolSpecification.builder()
        .functionName("current-fuel-price")
        .functionDesc("Get current fuel price")
        .props(
                new MistralTools.PropsBuilder()
                        .withProperty("location", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The city, e.g. New Delhi, India").required(true).build())
                        .withProperty("fuelType", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The fuel type.").enumValues(Arrays.asList("petrol", "diesel")).required(true).build())
                        .build()
        )
        .toolDefinition(SampleTools::getCurrentFuelPrice)
        .build();
```

Let's also define a sample tool specification called **Weather Tool** for getting the current weather.

- Specify the function `name`, `description`, and `required` property (`city`).
- Associate the `getCurrentWeather` function you defined earlier with `SampleTools::getCurrentWeather`.

```java
MistralTools.ToolSpecification weatherToolSpecification = MistralTools.ToolSpecification.builder()
        .functionName("current-weather")
        .functionDesc("Get current weather")
        .props(
                new MistralTools.PropsBuilder()
                        .withProperty("city", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The city, e.g. New Delhi, India").required(true).build())
                        .build()
        )
        .toolDefinition(SampleTools::getCurrentWeather)
        .build();
```

### Register the Tools

Register the defined tools (`fuel price` and `weather`) with the OllamaAPI.

```java
ollamaAPI.registerTool(fuelPriceToolSpecification);
ollamaAPI.registerTool(weatherToolSpecification);
```
### Create Prompts with Tools

`Prompt 1`: Create a prompt asking for the petrol price in Bengaluru using the defined fuel price and weather tools.

```java
String prompt1 = new MistralTools.PromptBuilder()
        .withToolSpecification(fuelPriceToolSpecification)
        .withToolSpecification(weatherToolSpecification)
        .withPrompt("What is the petrol price in Bengaluru?")
        .build();
OllamaToolsResult toolsResult = ollamaAPI.generateWithTools(model, prompt1, false, new OptionsBuilder().build());
for (Map.Entry<ToolDef, Object> r : toolsResult.getToolResults().entrySet()) {
    System.out.printf("[Response from tool '%s']: %s%n", r.getKey().getName(), r.getValue().toString());
}
```

Now, fire away your question to the model.

You will get a response similar to:

::::tip[LLM Response]

[Response from tool 'current-fuel-price']: Current price of petrol in Bengaluru is Rs.103/L

::::

`Prompt 2`: Create a prompt asking for the current weather in Bengaluru using the same tools.

```java
String prompt2 = new MistralTools.PromptBuilder()
        .withToolSpecification(fuelPriceToolSpecification)
        .withToolSpecification(weatherToolSpecification)
        .withPrompt("What is the current weather in Bengaluru?")
        .build();
OllamaToolsResult toolsResult = ollamaAPI.generateWithTools(model, prompt2, false, new OptionsBuilder().build());
for (Map.Entry<ToolDef, Object> r : toolsResult.getToolResults().entrySet()) {
    System.out.printf("[Response from tool '%s']: %s%n", r.getKey().getName(), r.getValue().toString());
}
```

Again, fire away your question to the model.

You will get a response similar to:

::::tip[LLM Response]

[Response from tool 'current-weather']: Currently Bengaluru's weather is nice.

::::
|
### Full Example
|
||||||
|
|
||||||
|
```java
|
||||||
|
|
||||||
|
import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
|
||||||
|
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
|
||||||
|
import io.github.amithkoujalgi.ollama4j.core.tools.ToolDef;
|
||||||
|
import io.github.amithkoujalgi.ollama4j.core.tools.MistralTools;
|
||||||
|
import io.github.amithkoujalgi.ollama4j.core.tools.OllamaToolsResult;
|
||||||
|
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
public class FunctionCallingWithMistral {
|
||||||
|
public static void main(String[] args) throws Exception {
|
||||||
|
String host = "http://localhost:11434/";
|
||||||
|
OllamaAPI ollamaAPI = new OllamaAPI(host);
|
||||||
|
ollamaAPI.setRequestTimeoutSeconds(60);
|
||||||
|
|
||||||
|
String model = "mistral";
|
||||||
|
|
||||||
|
|
||||||
|
MistralTools.ToolSpecification fuelPriceToolSpecification = MistralTools.ToolSpecification.builder()
|
||||||
|
.functionName("current-fuel-price")
|
||||||
|
.functionDesc("Get current fuel price")
|
||||||
|
.props(
|
||||||
|
new MistralTools.PropsBuilder()
|
||||||
|
.withProperty("location", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The city, e.g. New Delhi, India").required(true).build())
|
||||||
|
.withProperty("fuelType", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The fuel type.").enumValues(Arrays.asList("petrol", "diesel")).required(true).build())
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
.toolDefinition(SampleTools::getCurrentFuelPrice)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
MistralTools.ToolSpecification weatherToolSpecification = MistralTools.ToolSpecification.builder()
|
||||||
|
.functionName("current-weather")
|
||||||
|
.functionDesc("Get current weather")
|
||||||
|
.props(
|
||||||
|
new MistralTools.PropsBuilder()
|
||||||
|
.withProperty("city", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The city, e.g. New Delhi, India").required(true).build())
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
.toolDefinition(SampleTools::getCurrentWeather)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
ollamaAPI.registerTool(fuelPriceToolSpecification);
|
||||||
|
ollamaAPI.registerTool(weatherToolSpecification);
|
||||||
|
|
||||||
|
String prompt1 = new MistralTools.PromptBuilder()
|
||||||
|
.withToolSpecification(fuelPriceToolSpecification)
|
||||||
|
.withToolSpecification(weatherToolSpecification)
|
||||||
|
.withPrompt("What is the petrol price in Bengaluru?")
|
||||||
|
.build();
|
||||||
|
String prompt2 = new MistralTools.PromptBuilder()
|
||||||
|
.withToolSpecification(fuelPriceToolSpecification)
|
||||||
|
.withToolSpecification(weatherToolSpecification)
|
||||||
|
.withPrompt("What is the current weather in Bengaluru?")
|
||||||
|
.build();
|
||||||
|
|
||||||
|
ask(ollamaAPI, model, prompt1);
|
||||||
|
ask(ollamaAPI, model, prompt2);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void ask(OllamaAPI ollamaAPI, String model, String prompt) throws OllamaBaseException, IOException, InterruptedException {
|
||||||
|
OllamaToolsResult toolsResult = ollamaAPI.generateWithTools(model, prompt, false, new OptionsBuilder().build());
|
||||||
|
for (Map.Entry<ToolDef, Object> r : toolsResult.getToolResults().entrySet()) {
|
||||||
|
System.out.printf("[Response from tool '%s']: %s%n", r.getKey().getName(), r.getValue().toString());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
class SampleTools {
|
||||||
|
public static String getCurrentFuelPrice(Map<String, Object> arguments) {
|
||||||
|
String location = arguments.get("location").toString();
|
||||||
|
String fuelType = arguments.get("fuelType").toString();
|
||||||
|
return "Current price of " + fuelType + " in " + location + " is Rs.103/L";
|
||||||
|
}
|
||||||
|
|
||||||
|
public static String getCurrentWeather(Map<String, Object> arguments) {
|
||||||
|
String location = arguments.get("city").toString();
|
||||||
|
return "Currently " + location + "'s weather is nice.";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
Run this full example and you will get a response similar to:
|
||||||
|
|
||||||
|
::::tip[LLM Response]
|
||||||
|
|
||||||
|
[Response from tool 'current-fuel-price']: Current price of petrol in Bengaluru is Rs.103/L
|
||||||
|
|
||||||
|
[Response from tool 'current-weather']: Currently Bengaluru's weather is nice
|
||||||
|
::::
|
||||||
### Room for improvement

Instead of explicitly registering tools with `ollamaAPI.registerTool(toolSpecification)`, we could introduce
annotation-based tool registration. For example:

```java
@ToolSpec(name = "current-fuel-price", desc = "Get current fuel price")
public String getCurrentFuelPrice(Map<String, Object> arguments) {
    String location = arguments.get("location").toString();
    String fuelType = arguments.get("fuelType").toString();
    return "Current price of " + fuelType + " in " + location + " is Rs.103/L";
}
```

Instead of passing a map of arguments (`Map<String, Object> arguments`) to the tool functions, we could support
passing specific arguments separately, each with its own data type. For example:

```java
public String getCurrentFuelPrice(String location, String fuelType) {
    return "Current price of " + fuelType + " in " + location + " is Rs.103/L";
}
```

Another planned improvement is updating the async/chat APIs to support tool-based generation.
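The annotation-based registration idea above could be prototyped with plain reflection. The sketch below is illustrative only: the `@ToolSpec` annotation and the scanning logic are hypothetical, not part of ollama4j.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;

public class AnnotationScanSketch {

    // Hypothetical annotation mirroring the proposed @ToolSpec
    @Retention(RetentionPolicy.RUNTIME)
    @Target(ElementType.METHOD)
    @interface ToolSpec {
        String name();
        String desc();
    }

    static class Tools {
        @ToolSpec(name = "current-fuel-price", desc = "Get current fuel price")
        public static String getCurrentFuelPrice(Map<String, Object> arguments) {
            return "Current price of " + arguments.get("fuelType")
                    + " in " + arguments.get("location") + " is Rs.103/L";
        }
    }

    // Scan a class for @ToolSpec methods and index them by tool name
    static Map<String, Method> scan(Class<?> clazz) {
        Map<String, Method> registry = new HashMap<>();
        for (Method m : clazz.getDeclaredMethods()) {
            ToolSpec spec = m.getAnnotation(ToolSpec.class);
            if (spec != null) {
                registry.put(spec.name(), m);
            }
        }
        return registry;
    }

    public static void main(String[] args) throws Exception {
        Map<String, Method> registry = scan(Tools.class);
        Object out = registry.get("current-fuel-price")
                .invoke(null, Map.of("location", "Bengaluru", "fuelType", "petrol"));
        System.out.println(out);
    }
}
```

A registry built this way could then back the same `invokeTool`-style dispatch the library already performs by name.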
@@ -1,5 +1,5 @@
 ---
-sidebar_position: 5
+sidebar_position: 6
 ---

 # Prompt Builder
pom.xml (3 changed lines)

@@ -1,5 +1,6 @@
 <?xml version="1.0" encoding="UTF-8"?>
-<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
     <modelVersion>4.0.0</modelVersion>

     <groupId>io.github.amithkoujalgi</groupId>
@@ -10,6 +10,7 @@ import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingRe
 import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestModel;
 import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel;
 import io.github.amithkoujalgi.ollama4j.core.models.request.*;
+import io.github.amithkoujalgi.ollama4j.core.tools.*;
 import io.github.amithkoujalgi.ollama4j.core.utils.Options;
 import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
 import org.slf4j.Logger;
@@ -25,9 +26,7 @@ import java.net.http.HttpResponse;
 import java.nio.charset.StandardCharsets;
 import java.nio.file.Files;
 import java.time.Duration;
-import java.util.ArrayList;
-import java.util.Base64;
-import java.util.List;
+import java.util.*;

 /**
  * The base Ollama API class.
@@ -339,6 +338,7 @@ public class OllamaAPI {
         }
     }

     /**
      * Generate response for a question to a model running on Ollama server. This is a sync/blocking
      * call.
@@ -351,9 +351,10 @@ public class OllamaAPI {
      * @param streamHandler optional callback consumer that will be applied every time a streamed response is received. If not set, the stream parameter of the request is set to false.
      * @return OllamaResult that includes response text and time taken for response
      */
-    public OllamaResult generate(String model, String prompt, Options options, OllamaStreamHandler streamHandler)
+    public OllamaResult generate(String model, String prompt, boolean raw, Options options, OllamaStreamHandler streamHandler)
             throws OllamaBaseException, IOException, InterruptedException {
         OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt);
+        ollamaRequestModel.setRaw(raw);
         ollamaRequestModel.setOptions(options.getOptionsMap());
         return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler);
     }
@@ -361,13 +362,37 @@ public class OllamaAPI {
     /**
      * Convenience method to call Ollama API without streaming responses.
      * <p>
-     * Uses {@link #generate(String, String, Options, OllamaStreamHandler)}
+     * Uses {@link #generate(String, String, boolean, Options, OllamaStreamHandler)}
+     *
+     * @param model   Model to use
+     * @param prompt  Prompt text
+     * @param raw     In some cases, you may wish to bypass the templating system and provide a full prompt. In this case, you can use the raw parameter to disable templating. Also note that raw mode will not return a context.
+     * @param options Additional Options
+     * @return OllamaResult
      */
-    public OllamaResult generate(String model, String prompt, Options options)
+    public OllamaResult generate(String model, String prompt, boolean raw, Options options)
             throws OllamaBaseException, IOException, InterruptedException {
-        return generate(model, prompt, options, null);
+        return generate(model, prompt, raw, options, null);
     }

+    public OllamaToolsResult generateWithTools(String model, String prompt, boolean raw, Options options)
+            throws OllamaBaseException, IOException, InterruptedException {
+        OllamaToolsResult toolResult = new OllamaToolsResult();
+        Map<ToolDef, Object> toolResults = new HashMap<>();
+
+        OllamaResult result = generate(model, prompt, raw, options, null);
+        toolResult.setModelResult(result);
+
+        List<ToolDef> toolDefs = Utils.getObjectMapper().readValue(result.getResponse(), Utils.getObjectMapper().getTypeFactory().constructCollectionType(List.class, ToolDef.class));
+        for (ToolDef toolDef : toolDefs) {
+            toolResults.put(toolDef, invokeTool(toolDef));
+        }
+        toolResult.setToolResults(toolResults);
+        return toolResult;
+    }
+
     /**
      * Generate response for a question to a model running on Ollama server and get a callback handle
      * that can be used to check for status and get the response from the model later. This would be
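The flow of `generateWithTools` above — parse the raw-mode response into tool calls, then dispatch each call to its registered function by name — can be sketched standalone. The registry map and the pre-parsed tool calls below are stand-ins for illustration, not the library's actual `ToolRegistry` or its Jackson deserialization.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

public class ToolsFlowSketch {

    // Dispatch each tool call (name + arguments) to its registered function
    static Map<String, Object> runTools(
            Map<String, Function<Map<String, Object>, Object>> registry,
            List<Map.Entry<String, Map<String, Object>>> toolCalls) {
        Map<String, Object> results = new HashMap<>();
        for (Map.Entry<String, Map<String, Object>> call : toolCalls) {
            Function<Map<String, Object>, Object> fn = registry.get(call.getKey());
            results.put(call.getKey(), fn == null
                    ? "Error calling tool: no such tool " + call.getKey()
                    : fn.apply(call.getValue()));
        }
        return results;
    }

    public static void main(String[] args) {
        Map<String, Function<Map<String, Object>, Object>> registry = new HashMap<>();
        registry.put("current-fuel-price",
                a -> "Current price of " + a.get("fuelType") + " in " + a.get("location") + " is Rs.103/L");

        // Stand-in for the tool calls parsed out of the model's raw-mode JSON response
        List<Map.Entry<String, Map<String, Object>>> toolCalls = List.of(
                Map.entry("current-fuel-price", Map.of("location", "Bengaluru", "fuelType", "petrol")));

        System.out.println(runTools(registry, toolCalls));
    }
}
```

Unknown tool names produce an error string rather than an exception, mirroring the forgiving error handling in `invokeTool` below.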
@@ -377,9 +402,9 @@ public class OllamaAPI {
      * @param prompt the prompt/question text
      * @return the ollama async result callback handle
      */
-    public OllamaAsyncResultCallback generateAsync(String model, String prompt) {
+    public OllamaAsyncResultCallback generateAsync(String model, String prompt, boolean raw) {
         OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt);
+        ollamaRequestModel.setRaw(raw);
         URI uri = URI.create(this.host + "/api/generate");
         OllamaAsyncResultCallback ollamaAsyncResultCallback =
                 new OllamaAsyncResultCallback(
@@ -576,4 +601,24 @@ public class OllamaAPI {
     private boolean isBasicAuthCredentialsSet() {
         return basicAuth != null;
     }

+    public void registerTool(MistralTools.ToolSpecification toolSpecification) {
+        ToolRegistry.addFunction(toolSpecification.getFunctionName(), toolSpecification.getToolDefinition());
+    }
+
+    private Object invokeTool(ToolDef toolDef) {
+        try {
+            String methodName = toolDef.getName();
+            Map<String, Object> arguments = toolDef.getArguments();
+            DynamicFunction function = ToolRegistry.getFunction(methodName);
+            if (function == null) {
+                throw new IllegalArgumentException("No such tool: " + methodName);
+            }
+            return function.apply(arguments);
+        } catch (Exception e) {
+            e.printStackTrace();
+            return "Error calling tool: " + e.getMessage();
+        }
+    }
 }
@@ -1,9 +1,5 @@
 package io.github.amithkoujalgi.ollama4j.core.models.request;

-import java.io.IOException;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
 import com.fasterxml.jackson.core.JsonProcessingException;
 import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler;
 import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
@@ -13,15 +9,19 @@ import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRespo
 import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateStreamObserver;
 import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
 import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;

-public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{
+import java.io.IOException;
+
+public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {

     private static final Logger LOG = LoggerFactory.getLogger(OllamaGenerateEndpointCaller.class);

     private OllamaGenerateStreamObserver streamObserver;

     public OllamaGenerateEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) {
         super(host, basicAuth, requestTimeoutSeconds, verbose);
     }

     @Override
@@ -31,24 +31,22 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {

     @Override
     protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) {
         try {
             OllamaGenerateResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class);
             responseBuffer.append(ollamaResponseModel.getResponse());
-            if(streamObserver != null) {
+            if (streamObserver != null) {
                 streamObserver.notify(ollamaResponseModel);
             }
             return ollamaResponseModel.isDone();
         } catch (JsonProcessingException e) {
-            LOG.error("Error parsing the Ollama chat response!",e);
+            LOG.error("Error parsing the Ollama chat response!", e);
             return true;
         }
     }

     public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler)
             throws OllamaBaseException, IOException, InterruptedException {
         streamObserver = new OllamaGenerateStreamObserver(streamHandler);
         return super.callSync(body);
     }

 }
New file: `DynamicFunction.java` (8 lines)

@@ -0,0 +1,8 @@
package io.github.amithkoujalgi.ollama4j.core.tools;

import java.util.Map;

@FunctionalInterface
public interface DynamicFunction {
    Object apply(Map<String, Object> arguments);
}
@ -0,0 +1,139 @@
|
|||||||
|
package io.github.amithkoujalgi.ollama4j.core.tools;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
|
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
public class MistralTools {
|
||||||
|
@Data
|
||||||
|
@Builder
|
||||||
|
public static class ToolSpecification {
|
||||||
|
private String functionName;
|
||||||
|
private String functionDesc;
|
||||||
|
private Map<String, PromptFuncDefinition.Property> props;
|
||||||
|
private DynamicFunction toolDefinition;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@JsonIgnoreProperties(ignoreUnknown = true)
|
||||||
|
public static class PromptFuncDefinition {
|
||||||
|
private String type;
|
||||||
|
private PromptFuncSpec function;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public static class PromptFuncSpec {
|
||||||
|
private String name;
|
||||||
|
private String description;
|
||||||
|
private Parameters parameters;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public static class Parameters {
|
||||||
|
private String type;
|
||||||
|
private Map<String, Property> properties;
|
||||||
|
private List<String> required;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@Builder
|
||||||
|
public static class Property {
|
||||||
|
private String type;
|
||||||
|
private String description;
|
||||||
|
@JsonProperty("enum")
|
||||||
|
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||||
|
private List<String> enumValues;
|
||||||
|
@JsonIgnore
|
||||||
|
private boolean required;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class PropsBuilder {
|
||||||
|
private final Map<String, PromptFuncDefinition.Property> props = new HashMap<>();
|
||||||
|
|
||||||
|
public PropsBuilder withProperty(String key, PromptFuncDefinition.Property property) {
|
||||||
|
props.put(key, property);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Map<String, PromptFuncDefinition.Property> build() {
|
||||||
|
return props;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class PromptBuilder {
|
||||||
|
private final List<PromptFuncDefinition> tools = new ArrayList<>();
|
||||||
|
|
||||||
|
private String promptText;
|
||||||
|
|
||||||
|
public String build() throws JsonProcessingException {
|
||||||
|
return "[AVAILABLE_TOOLS] " + Utils.getObjectMapper().writeValueAsString(tools) + "[/AVAILABLE_TOOLS][INST] " + promptText + " [/INST]";
|
||||||
|
}
|
||||||
|
|
||||||
|
public PromptBuilder withPrompt(String prompt) throws JsonProcessingException {
|
||||||
|
promptText = prompt;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public PromptBuilder withToolSpecification(ToolSpecification spec) {
|
||||||
|
PromptFuncDefinition def = new PromptFuncDefinition();
|
||||||
|
def.setType("function");
|
||||||
|
|
||||||
|
PromptFuncDefinition.PromptFuncSpec functionDetail = new PromptFuncDefinition.PromptFuncSpec();
|
||||||
|
functionDetail.setName(spec.getFunctionName());
|
||||||
|
functionDetail.setDescription(spec.getFunctionDesc());
|
||||||
|
|
||||||
|
PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters();
|
||||||
|
parameters.setType("object");
|
||||||
|
parameters.setProperties(spec.getProps());
|
||||||
|
|
||||||
|
List<String> requiredValues = new ArrayList<>();
|
||||||
|
for (Map.Entry<String, PromptFuncDefinition.Property> p : spec.getProps().entrySet()) {
|
||||||
|
if (p.getValue().isRequired()) {
|
||||||
|
requiredValues.add(p.getKey());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
parameters.setRequired(requiredValues);
|
||||||
|
functionDetail.setParameters(parameters);
|
||||||
|
def.setFunction(functionDetail);
|
||||||
|
|
||||||
|
tools.add(def);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
    }
}
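The builder above flattens the registered tool definitions and the user prompt into Mistral's raw function-calling layout. A minimal sketch of the string `build()` emits, with a hypothetical hand-written tool JSON standing in for the Jackson serialization of the tools list:

```java
public class PromptLayoutSketch {
    public static void main(String[] args) {
        // Hypothetical tool JSON; PromptBuilder serializes the real tools list with Jackson.
        String toolsJson = "[{\"type\":\"function\",\"function\":{\"name\":\"get-order-details\"}}]";
        String promptText = "Give me the details of order ORD-1234";
        // Same concatenation as build():
        String prompt = "[AVAILABLE_TOOLS] " + toolsJson
                + "[/AVAILABLE_TOOLS][INST] " + promptText + " [/INST]";
        System.out.println(prompt);
    }
}
```

Sending a prompt in this shape with `raw` mode enabled is what lets Mistral answer with a JSON tool call instead of free text.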
@@ -0,0 +1,16 @@
package io.github.amithkoujalgi.ollama4j.core.tools;

import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.util.Map;

@Data
@NoArgsConstructor
@AllArgsConstructor
public class OllamaToolsResult {
    private OllamaResult modelResult;
    private Map<ToolDef, Object> toolResults;
}
@@ -0,0 +1,18 @@
package io.github.amithkoujalgi.ollama4j.core.tools;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.util.Map;

@Data
@AllArgsConstructor
@NoArgsConstructor
public class ToolDef {

    private String name;
    private Map<String, Object> arguments;

}
@@ -0,0 +1,17 @@
package io.github.amithkoujalgi.ollama4j.core.tools;

import java.util.HashMap;
import java.util.Map;

public class ToolRegistry {
    private static final Map<String, DynamicFunction> functionMap = new HashMap<>();

    public static DynamicFunction getFunction(String name) {
        return functionMap.get(name);
    }

    public static void addFunction(String name, DynamicFunction function) {
        functionMap.put(name, function);
    }
}
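A registry like this maps a tool name (as it appears in the model's tool call) to an application callback. A self-contained sketch of that dispatch, assuming `DynamicFunction` is a functional interface over the tool's argument map (its actual definition lives elsewhere in this commit):

```java
import java.util.HashMap;
import java.util.Map;

// Assumed shape of DynamicFunction; the real interface is defined elsewhere in the commit.
interface DynamicFunction {
    Object apply(Map<String, Object> arguments);
}

public class ToolRegistrySketch {
    private static final Map<String, DynamicFunction> functionMap = new HashMap<>();

    static void addFunction(String name, DynamicFunction function) {
        functionMap.put(name, function);
    }

    static DynamicFunction getFunction(String name) {
        return functionMap.get(name);
    }

    public static void main(String[] args) {
        // Register a callback, then dispatch to it by the name the model emitted.
        addFunction("get-order-details", arguments -> "order placed by " + arguments.get("customer"));
        Object result = getFunction("get-order-details").apply(Map.of("customer", "Alice"));
        System.out.println(result); // → order placed by Alice
    }
}
```

Keeping the map keyed by name is what allows the library to look up and invoke the right callback after parsing the model's JSON tool call.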
@@ -1,7 +1,5 @@
package io.github.amithkoujalgi.ollama4j.integrationtests;

import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.ModelDetail;
@@ -10,9 +8,16 @@ import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessageRole;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResult;
import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestBuilder;
import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestModel;
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
import lombok.Data;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Order;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
@@ -22,372 +27,369 @@ import java.net.http.HttpConnectTimeoutException;
import java.util.List;
import java.util.Objects;
import java.util.Properties;

import static org.junit.jupiter.api.Assertions.*;

class TestRealAPIs {

    private static final Logger LOG = LoggerFactory.getLogger(TestRealAPIs.class);

    OllamaAPI ollamaAPI;
    Config config;

    private File getImageFileFromClasspath(String fileName) {
        ClassLoader classLoader = getClass().getClassLoader();
        return new File(Objects.requireNonNull(classLoader.getResource(fileName)).getFile());
    }

    @BeforeEach
    void setUp() {
        config = new Config();
        ollamaAPI = new OllamaAPI(config.getOllamaURL());
        ollamaAPI.setRequestTimeoutSeconds(config.getRequestTimeoutSeconds());
    }

    @Test
    @Order(1)
    void testWrongEndpoint() {
        OllamaAPI ollamaAPI = new OllamaAPI("http://wrong-host:11434");
        assertThrows(ConnectException.class, ollamaAPI::listModels);
    }

    @Test
    @Order(1)
    void testEndpointReachability() {
        try {
            assertNotNull(ollamaAPI.listModels());
        } catch (HttpConnectTimeoutException e) {
            fail(e.getMessage());
        } catch (Exception e) {
            fail(e);
        }
    }

    @Test
    @Order(2)
    void testListModels() {
        testEndpointReachability();
        try {
            assertNotNull(ollamaAPI.listModels());
            ollamaAPI.listModels().forEach(System.out::println);
        } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
            fail(e);
        }
    }

    @Test
    @Order(2)
    void testPullModel() {
        testEndpointReachability();
        try {
            ollamaAPI.pullModel(config.getModel());
            boolean found =
                    ollamaAPI.listModels().stream()
                            .anyMatch(model -> model.getModel().equalsIgnoreCase(config.getModel()));
            assertTrue(found);
        } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
            fail(e);
        }
    }

    @Test
    @Order(3)
    void testListDtails() {
        testEndpointReachability();
        try {
            ModelDetail modelDetails = ollamaAPI.getModelDetails(config.getModel());
            assertNotNull(modelDetails);
            System.out.println(modelDetails);
        } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
            fail(e);
        }
    }

    @Test
    @Order(3)
    void testAskModelWithDefaultOptions() {
        testEndpointReachability();
        try {
            OllamaResult result =
                    ollamaAPI.generate(
                            config.getModel(),
                            "What is the capital of France? And what's France's connection with Mona Lisa?",
                            false,
                            new OptionsBuilder().build());
            assertNotNull(result);
            assertNotNull(result.getResponse());
            assertFalse(result.getResponse().isEmpty());
        } catch (IOException | OllamaBaseException | InterruptedException e) {
            fail(e);
        }
    }

    @Test
    @Order(3)
    void testAskModelWithDefaultOptionsStreamed() {
        testEndpointReachability();
        try {
            StringBuffer sb = new StringBuffer("");
            OllamaResult result = ollamaAPI.generate(config.getModel(),
                    "What is the capital of France? And what's France's connection with Mona Lisa?",
                    false,
                    new OptionsBuilder().build(), (s) -> {
                        LOG.info(s);
                        String substring = s.substring(sb.toString().length(), s.length());
                        LOG.info(substring);
                        sb.append(substring);
                    });
            assertNotNull(result);
            assertNotNull(result.getResponse());
            assertFalse(result.getResponse().isEmpty());
            assertEquals(sb.toString().trim(), result.getResponse().trim());
        } catch (IOException | OllamaBaseException | InterruptedException e) {
            fail(e);
        }
    }

    @Test
    @Order(3)
    void testAskModelWithOptions() {
        testEndpointReachability();
        try {
            OllamaResult result =
                    ollamaAPI.generate(
                            config.getModel(),
                            "What is the capital of France? And what's France's connection with Mona Lisa?",
                            true,
                            new OptionsBuilder().setTemperature(0.9f).build());
            assertNotNull(result);
            assertNotNull(result.getResponse());
            assertFalse(result.getResponse().isEmpty());
        } catch (IOException | OllamaBaseException | InterruptedException e) {
            fail(e);
        }
    }

    @Test
    @Order(3)
    void testChat() {
        testEndpointReachability();
        try {
            OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
            OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What is the capital of France?")
                    .withMessage(OllamaChatMessageRole.ASSISTANT, "Should be Paris!")
                    .withMessage(OllamaChatMessageRole.USER, "And what is the second larges city?")
                    .build();

            OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
            assertNotNull(chatResult);
            assertFalse(chatResult.getResponse().isBlank());
            assertEquals(4, chatResult.getChatHistory().size());
        } catch (IOException | OllamaBaseException | InterruptedException e) {
            fail(e);
        }
    }

    @Test
    @Order(3)
    void testChatWithSystemPrompt() {
        testEndpointReachability();
        try {
            OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
            OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM,
                    "You are a silent bot that only says 'NI'. Do not say anything else under any circumstances!")
                    .withMessage(OllamaChatMessageRole.USER,
                            "What is the capital of France? And what's France's connection with Mona Lisa?")
                    .build();

            OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
            assertNotNull(chatResult);
            assertFalse(chatResult.getResponse().isBlank());
            assertTrue(chatResult.getResponse().startsWith("NI"));
            assertEquals(3, chatResult.getChatHistory().size());
        } catch (IOException | OllamaBaseException | InterruptedException e) {
            fail(e);
        }
    }

    @Test
    @Order(3)
    void testChatWithStream() {
        testEndpointReachability();
        try {
            OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
            OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER,
                    "What is the capital of France? And what's France's connection with Mona Lisa?")
                    .build();

            StringBuffer sb = new StringBuffer("");

            OllamaChatResult chatResult = ollamaAPI.chat(requestModel, (s) -> {
                LOG.info(s);
                String substring = s.substring(sb.toString().length(), s.length());
                LOG.info(substring);
                sb.append(substring);
            });
            assertNotNull(chatResult);
            assertEquals(sb.toString().trim(), chatResult.getResponse().trim());
        } catch (IOException | OllamaBaseException | InterruptedException e) {
            fail(e);
        }
    }

    @Test
    @Order(3)
    void testChatWithImageFromFileWithHistoryRecognition() {
        testEndpointReachability();
        try {
            OllamaChatRequestBuilder builder =
                    OllamaChatRequestBuilder.getInstance(config.getImageModel());
            OllamaChatRequestModel requestModel =
                    builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
                            List.of(getImageFileFromClasspath("dog-on-a-boat.jpg"))).build();

            OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
            assertNotNull(chatResult);
            assertNotNull(chatResult.getResponse());

            builder.reset();

            requestModel =
                    builder.withMessages(chatResult.getChatHistory())
                            .withMessage(OllamaChatMessageRole.USER, "What's the dogs breed?").build();

            chatResult = ollamaAPI.chat(requestModel);
            assertNotNull(chatResult);
            assertNotNull(chatResult.getResponse());

        } catch (IOException | OllamaBaseException | InterruptedException e) {
            fail(e);
        }
    }

    @Test
    @Order(3)
    void testChatWithImageFromURL() {
        testEndpointReachability();
        try {
            OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getImageModel());
            OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
                    "https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg")
                    .build();

            OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
            assertNotNull(chatResult);
        } catch (IOException | OllamaBaseException | InterruptedException e) {
            fail(e);
        }
    }

    @Test
    @Order(3)
    void testAskModelWithOptionsAndImageFiles() {
        testEndpointReachability();
        File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg");
        try {
            OllamaResult result =
                    ollamaAPI.generateWithImageFiles(
                            config.getImageModel(),
                            "What is in this image?",
                            List.of(imageFile),
                            new OptionsBuilder().build());
            assertNotNull(result);
            assertNotNull(result.getResponse());
            assertFalse(result.getResponse().isEmpty());
        } catch (IOException | OllamaBaseException | InterruptedException e) {
            fail(e);
        }
    }

    @Test
    @Order(3)
    void testAskModelWithOptionsAndImageFilesStreamed() {
        testEndpointReachability();
        File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg");
        try {
            StringBuffer sb = new StringBuffer("");

            OllamaResult result = ollamaAPI.generateWithImageFiles(config.getImageModel(),
                    "What is in this image?", List.of(imageFile), new OptionsBuilder().build(), (s) -> {
                        LOG.info(s);
                        String substring = s.substring(sb.toString().length(), s.length());
                        LOG.info(substring);
                        sb.append(substring);
                    });
            assertNotNull(result);
            assertNotNull(result.getResponse());
            assertFalse(result.getResponse().isEmpty());
            assertEquals(sb.toString().trim(), result.getResponse().trim());
        } catch (IOException | OllamaBaseException | InterruptedException e) {
            fail(e);
        }
    }

    @Test
    @Order(3)
    void testAskModelWithOptionsAndImageURLs() {
        testEndpointReachability();
        try {
            OllamaResult result =
                    ollamaAPI.generateWithImageURLs(
                            config.getImageModel(),
                            "What is in this image?",
                            List.of(
                                    "https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg"),
                            new OptionsBuilder().build());
            assertNotNull(result);
            assertNotNull(result.getResponse());
            assertFalse(result.getResponse().isEmpty());
        } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
            fail(e);
        }
    }

    @Test
    @Order(3)
    public void testEmbedding() {
        testEndpointReachability();
        try {
            OllamaEmbeddingsRequestModel request = OllamaEmbeddingsRequestBuilder
                    .getInstance(config.getModel(), "What is the capital of France?").build();

            List<Double> embeddings = ollamaAPI.generateEmbeddings(request);

            assertNotNull(embeddings);
            assertFalse(embeddings.isEmpty());
        } catch (IOException | OllamaBaseException | InterruptedException e) {
            fail(e);
        }
    }
}

@Data
class Config {
    private String ollamaURL;
    private String model;
    private String imageModel;
    private int requestTimeoutSeconds;

    public Config() {
        Properties properties = new Properties();
        try (InputStream input =
                getClass().getClassLoader().getResourceAsStream("test-config.properties")) {
            if (input == null) {
                throw new RuntimeException("Sorry, unable to find test-config.properties");
            }
            properties.load(input);
            this.ollamaURL = properties.getProperty("ollama.url");
            this.model = properties.getProperty("ollama.model");
            this.imageModel = properties.getProperty("ollama.model.image");
            this.requestTimeoutSeconds =
                    Integer.parseInt(properties.getProperty("ollama.request-timeout-seconds"));
        } catch (IOException e) {
            throw new RuntimeException("Error loading properties", e);
        }
    }
}
@@ -1,7 +1,5 @@
 package io.github.amithkoujalgi.ollama4j.unittests;
 
-import static org.mockito.Mockito.*;
-
 import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
 import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
 import io.github.amithkoujalgi.ollama4j.core.models.ModelDetail;
@@ -9,155 +7,158 @@ import io.github.amithkoujalgi.ollama4j.core.models.OllamaAsyncResultCallback;
 import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
 import io.github.amithkoujalgi.ollama4j.core.types.OllamaModelType;
 import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
+import org.junit.jupiter.api.Test;
+import org.mockito.Mockito;
 
 import java.io.IOException;
 import java.net.URISyntaxException;
 import java.util.ArrayList;
 import java.util.Collections;
-import org.junit.jupiter.api.Test;
-import org.mockito.Mockito;
+
+import static org.mockito.Mockito.*;
 
 class TestMockedAPIs {
   @Test
   void testPullModel() {
     OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
     String model = OllamaModelType.LLAMA2;
     try {
       doNothing().when(ollamaAPI).pullModel(model);
       ollamaAPI.pullModel(model);
       verify(ollamaAPI, times(1)).pullModel(model);
     } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
       throw new RuntimeException(e);
     }
   }
 
   @Test
   void testListModels() {
     OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
     try {
       when(ollamaAPI.listModels()).thenReturn(new ArrayList<>());
       ollamaAPI.listModels();
       verify(ollamaAPI, times(1)).listModels();
     } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
       throw new RuntimeException(e);
     }
   }
 
   @Test
   void testCreateModel() {
     OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
     String model = OllamaModelType.LLAMA2;
     String modelFilePath = "FROM llama2\nSYSTEM You are mario from Super Mario Bros.";
     try {
       doNothing().when(ollamaAPI).createModelWithModelFileContents(model, modelFilePath);
       ollamaAPI.createModelWithModelFileContents(model, modelFilePath);
       verify(ollamaAPI, times(1)).createModelWithModelFileContents(model, modelFilePath);
     } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
       throw new RuntimeException(e);
     }
   }
 
   @Test
   void testDeleteModel() {
     OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
     String model = OllamaModelType.LLAMA2;
     try {
       doNothing().when(ollamaAPI).deleteModel(model, true);
       ollamaAPI.deleteModel(model, true);
       verify(ollamaAPI, times(1)).deleteModel(model, true);
     } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
       throw new RuntimeException(e);
     }
   }
 
   @Test
   void testGetModelDetails() {
     OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
     String model = OllamaModelType.LLAMA2;
     try {
       when(ollamaAPI.getModelDetails(model)).thenReturn(new ModelDetail());
       ollamaAPI.getModelDetails(model);
       verify(ollamaAPI, times(1)).getModelDetails(model);
     } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
       throw new RuntimeException(e);
     }
   }
 
   @Test
   void testGenerateEmbeddings() {
     OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
     String model = OllamaModelType.LLAMA2;
     String prompt = "some prompt text";
     try {
       when(ollamaAPI.generateEmbeddings(model, prompt)).thenReturn(new ArrayList<>());
       ollamaAPI.generateEmbeddings(model, prompt);
       verify(ollamaAPI, times(1)).generateEmbeddings(model, prompt);
     } catch (IOException | OllamaBaseException | InterruptedException e) {
       throw new RuntimeException(e);
     }
   }
 
   @Test
   void testAsk() {
     OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
     String model = OllamaModelType.LLAMA2;
     String prompt = "some prompt text";
     OptionsBuilder optionsBuilder = new OptionsBuilder();
     try {
-      when(ollamaAPI.generate(model, prompt, optionsBuilder.build()))
+      when(ollamaAPI.generate(model, prompt, false, optionsBuilder.build()))
           .thenReturn(new OllamaResult("", 0, 200));
-      ollamaAPI.generate(model, prompt, optionsBuilder.build());
-      verify(ollamaAPI, times(1)).generate(model, prompt, optionsBuilder.build());
+      ollamaAPI.generate(model, prompt, false, optionsBuilder.build());
+      verify(ollamaAPI, times(1)).generate(model, prompt, false, optionsBuilder.build());
     } catch (IOException | OllamaBaseException | InterruptedException e) {
       throw new RuntimeException(e);
     }
   }
 
   @Test
   void testAskWithImageFiles() {
     OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
     String model = OllamaModelType.LLAMA2;
     String prompt = "some prompt text";
     try {
       when(ollamaAPI.generateWithImageFiles(
               model, prompt, Collections.emptyList(), new OptionsBuilder().build()))
           .thenReturn(new OllamaResult("", 0, 200));
       ollamaAPI.generateWithImageFiles(
           model, prompt, Collections.emptyList(), new OptionsBuilder().build());
       verify(ollamaAPI, times(1))
           .generateWithImageFiles(
               model, prompt, Collections.emptyList(), new OptionsBuilder().build());
     } catch (IOException | OllamaBaseException | InterruptedException e) {
       throw new RuntimeException(e);
     }
   }
 
   @Test
   void testAskWithImageURLs() {
     OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
     String model = OllamaModelType.LLAMA2;
     String prompt = "some prompt text";
     try {
       when(ollamaAPI.generateWithImageURLs(
               model, prompt, Collections.emptyList(), new OptionsBuilder().build()))
           .thenReturn(new OllamaResult("", 0, 200));
       ollamaAPI.generateWithImageURLs(
           model, prompt, Collections.emptyList(), new OptionsBuilder().build());
       verify(ollamaAPI, times(1))
           .generateWithImageURLs(
               model, prompt, Collections.emptyList(), new OptionsBuilder().build());
     } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
       throw new RuntimeException(e);
     }
   }
 
   @Test
   void testAskAsync() {
     OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
     String model = OllamaModelType.LLAMA2;
     String prompt = "some prompt text";
-    when(ollamaAPI.generateAsync(model, prompt))
+    when(ollamaAPI.generateAsync(model, prompt, false))
         .thenReturn(new OllamaAsyncResultCallback(null, null, 3));
-    ollamaAPI.generateAsync(model, prompt);
-    verify(ollamaAPI, times(1)).generateAsync(model, prompt);
+    ollamaAPI.generateAsync(model, prompt, false);
+    verify(ollamaAPI, times(1)).generateAsync(model, prompt, false);
   }
 }
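The substantive change in this test diff is that `generate(...)` and `generateAsync(...)` gain an extra boolean flag (the `raw` mode mentioned in the docs above), so every caller and mock must now pass it explicitly. A minimal, self-contained sketch of that call-site impact — `GenerateApi` and `RawFlagSketch` are hypothetical names for illustration, not real ollama4j types:

```java
// Hypothetical sketch only: GenerateApi stands in for the client whose
// generate() method gained a boolean `raw` parameter in this commit.
interface GenerateApi {
  String generate(String model, String prompt, boolean raw);
}

public class RawFlagSketch {
  public static void main(String[] args) {
    // A stub implementation standing in for the real client.
    GenerateApi api =
        (model, prompt, raw) -> String.format("model=%s raw=%b prompt=%s", model, raw, prompt);
    // After the signature change, callers (and Mockito stubs/verifications in
    // the tests above) must pass the flag explicitly:
    System.out.println(api.generate("mistral", "some prompt text", false));
    // prints: model=mistral raw=false prompt=some prompt text
  }
}
```

Because Mockito matches stubs and verifications on the exact argument list, omitting the new flag in any `when(...)` or `verify(...)` call would no longer compile — which is why every `generate`/`generateAsync` interaction in the tests was updated together.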