Compare commits

..

9 Commits

Author SHA1 Message Date
koujalgi.amith@gmail.com
81689be194 Refactored tools API
Signed-off-by: koujalgi.amith@gmail.com <koujalgi.amith@gmail.com>
2024-07-14 11:23:36 +05:30
koujalgi.amith@gmail.com
fd93036d08 Refactor
Signed-off-by: koujalgi.amith@gmail.com <koujalgi.amith@gmail.com>
2024-07-14 00:07:23 +05:30
koujalgi.amith@gmail.com
c9b05a725b Refactor
Signed-off-by: koujalgi.amith@gmail.com <koujalgi.amith@gmail.com>
2024-07-14 00:05:43 +05:30
koujalgi.amith@gmail.com
a4e1b4afe9 Removed old maven-publish.yml
Signed-off-by: koujalgi.amith@gmail.com <koujalgi.amith@gmail.com>
2024-07-14 00:02:20 +05:30
koujalgi.amith@gmail.com
3d21813abb updated README.md
Signed-off-by: koujalgi.amith@gmail.com <koujalgi.amith@gmail.com>
2024-07-14 00:00:52 +05:30
koujalgi.amith@gmail.com
383d0f56ca Updated generateAsync() API
Signed-off-by: koujalgi.amith@gmail.com <koujalgi.amith@gmail.com>
2024-07-13 23:54:49 +05:30
koujalgi.amith@gmail.com
af1b213a76 updated README.md
Signed-off-by: koujalgi.amith@gmail.com <koujalgi.amith@gmail.com>
2024-07-13 21:50:45 +05:30
koujalgi.amith@gmail.com
fed89a9643 updated README.md
Signed-off-by: koujalgi.amith@gmail.com <koujalgi.amith@gmail.com>
2024-07-13 21:49:26 +05:30
Amith Koujalgi
fd32aa33ff Updated README.md
Signed-off-by: Amith Koujalgi <koujalgi.amith@gmail.com>
2024-07-13 14:30:13 +05:30
26 changed files with 479 additions and 426 deletions

View File

@@ -2,7 +2,10 @@
### Ollama4j
<img src='https://raw.githubusercontent.com/amithkoujalgi/ollama4j/65a9d526150da8fcd98e2af6a164f055572bf722/ollama4j.jpeg' width='100' alt="ollama4j-icon">
<p align="center">
<img src='https://raw.githubusercontent.com/amithkoujalgi/ollama4j/65a9d526150da8fcd98e2af6a164f055572bf722/ollama4j.jpeg' width='100' alt="ollama4j-icon">
</p>
A Java library (wrapper/binding) for [Ollama](https://ollama.ai/) server.
@@ -93,7 +96,7 @@ according to your requirements.
<dependency>
<groupId>io.github.amithkoujalgi</groupId>
<artifactId>ollama4j</artifactId>
<version>1.0.74</version>
<version>1.0.74</version>
</dependency>
```
@@ -193,8 +196,7 @@ make it
#### Releases
Releases (newer artifact versions) are done automatically on pushing the code to the `main` branch through GitHub
Actions CI workflow.
Newer artifacts are published via GitHub Actions CI workflow when a new release is created from `main` branch.
#### Who's using Ollama4j?
@@ -251,19 +253,16 @@ of contribution is much appreciated.
The nomenclature and the icon have been adopted from the incredible [Ollama](https://ollama.ai/)
project.
<div style="text-align: center">
**Thanks to the amazing contributors**
<a href="https://github.com/amithkoujalgi/ollama4j/graphs/contributors">
<img src="https://contrib.rocks/image?repo=amithkoujalgi/ollama4j" />
</a>
<p align="center">
<a href="https://github.com/amithkoujalgi/ollama4j/graphs/contributors">
<img src="https://contrib.rocks/image?repo=amithkoujalgi/ollama4j" />
</a>
</p>
### Appreciate my work?
<a href="https://www.buymeacoffee.com/amithkoujalgi" target="_blank"><img src="https://cdn.buymeacoffee.com/buttons/v2/default-yellow.png" alt="Buy Me A Coffee" style="height: 60px !important;width: 217px !important;" ></a>
</div>
<p align="center">
<a href="https://www.buymeacoffee.com/amithkoujalgi" target="_blank"><img src="https://cdn.buymeacoffee.com/buttons/v2/default-yellow.png" alt="Buy Me A Coffee" style="height: 60px !important;width: 217px !important;" ></a>
</p>

View File

@@ -1,42 +1,46 @@
---
sidebar_position: 3
sidebar_position: 2
---
# Generate - Async
This API lets you ask questions to the LLMs in a asynchronous way.
These APIs correlate to
the [completion](https://github.com/jmorganca/ollama/blob/main/docs/api.md#generate-a-completion) APIs.
This is particularly helpful when you want to issue a generate request to the LLM and collect the response in the
background (such as threads) without blocking your code until the response arrives from the model.
This API corresponds to
the [completion](https://github.com/jmorganca/ollama/blob/main/docs/api.md#generate-a-completion) API.
```java
public class Main {
public static void main(String[] args) {
public static void main(String[] args) throws Exception {
String host = "http://localhost:11434/";
OllamaAPI ollamaAPI = new OllamaAPI(host);
ollamaAPI.setRequestTimeoutSeconds(60);
String prompt = "List all cricket world cup teams of 2019.";
OllamaAsyncResultStreamer streamer = ollamaAPI.generateAsync(OllamaModelType.LLAMA3, prompt, false);
String prompt = "Who are you?";
// Set the poll interval according to your needs.
// Smaller the poll interval, more frequently you receive the tokens.
int pollIntervalMilliseconds = 1000;
OllamaAsyncResultCallback callback = ollamaAPI.generateAsync(OllamaModelType.LLAMA2, prompt);
while (!callback.isComplete() || !callback.getStream().isEmpty()) {
// poll for data from the response stream
String result = callback.getStream().poll();
if (result != null) {
System.out.print(result);
while (true) {
String tokens = streamer.getStream().poll();
System.out.print(tokens);
if (!streamer.isAlive()) {
break;
}
Thread.sleep(100);
Thread.sleep(pollIntervalMilliseconds);
}
System.out.println("\n------------------------");
System.out.println("Complete Response:");
System.out.println("------------------------");
System.out.println(streamer.getResult());
}
}
```
You will get a response similar to:
> I am LLaMA, an AI assistant developed by Meta AI that can understand and respond to human input in a conversational
> manner. I am trained on a massive dataset of text from the internet and can generate human-like responses to a wide
> range of topics and questions. I can be used to create chatbots, virtual assistants, and other applications that
> require
> natural language understanding and generation capabilities.
You will get a steaming response.

View File

@@ -5,8 +5,8 @@ sidebar_position: 4
# Generate - With Image Files
This API lets you ask questions along with the image files to the LLMs.
These APIs correlate to
the [completion](https://github.com/jmorganca/ollama/blob/main/docs/api.md#generate-a-completion) APIs.
This API corresponds to
the [completion](https://github.com/jmorganca/ollama/blob/main/docs/api.md#generate-a-completion) API.
:::note

View File

@@ -5,8 +5,8 @@ sidebar_position: 5
# Generate - With Image URLs
This API lets you ask questions along with the image files to the LLMs.
These APIs correlate to
the [completion](https://github.com/jmorganca/ollama/blob/main/docs/api.md#generate-a-completion) APIs.
This API corresponds to
the [completion](https://github.com/jmorganca/ollama/blob/main/docs/api.md#generate-a-completion) API.
:::note

View File

@@ -1,12 +1,12 @@
---
sidebar_position: 2
sidebar_position: 3
---
# Generate - With Tools
This API lets you perform [function calling](https://docs.mistral.ai/capabilities/function_calling/) using LLMs in a
synchronous way.
This API correlates to
This API corresponds to
the [generate](https://github.com/ollama/ollama/blob/main/docs/api.md#request-raw-mode) API with `raw` mode.
:::note
@@ -29,8 +29,8 @@ You could do that with ease with the `function calling` capabilities of the mode
### Create Functions
This function takes the arguments `location` and `fuelType` and performs an operation with these arguments and returns a
value.
This function takes the arguments `location` and `fuelType` and performs an operation with these arguments and returns
fuel price value.
```java
public static String getCurrentFuelPrice(Map<String, Object> arguments) {
@@ -40,8 +40,8 @@ public static String getCurrentFuelPrice(Map<String, Object> arguments) {
}
```
This function takes the argument `city` and performs an operation with the argument and returns a
value.
This function takes the argument `city` and performs an operation with the argument and returns the weather for a
location.
```java
public static String getCurrentWeather(Map<String, Object> arguments) {
@@ -50,6 +50,19 @@ public static String getCurrentWeather(Map<String, Object> arguments) {
}
```
This function takes the argument `employee-name` and performs an operation with the argument and returns employee
details.
```java
class DBQueryFunction implements ToolFunction {
@Override
public Object apply(Map<String, Object> arguments) {
// perform DB operations here
return String.format("Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", UUID.randomUUID(), arguments.get("employee-name").toString(), arguments.get("employee-address").toString(), arguments.get("employee-phone").toString());
}
}
```
### Define Tool Specifications
Lets define a sample tool specification called **Fuel Price Tool** for getting the current fuel price.
@@ -58,13 +71,13 @@ Lets define a sample tool specification called **Fuel Price Tool** for getting t
- Associate the `getCurrentFuelPrice` function you defined earlier with `SampleTools::getCurrentFuelPrice`.
```java
MistralTools.ToolSpecification fuelPriceToolSpecification = MistralTools.ToolSpecification.builder()
Tools.ToolSpecification fuelPriceToolSpecification = Tools.ToolSpecification.builder()
.functionName("current-fuel-price")
.functionDesc("Get current fuel price")
.props(
new MistralTools.PropsBuilder()
.withProperty("location", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The city, e.g. New Delhi, India").required(true).build())
.withProperty("fuelType", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The fuel type.").enumValues(Arrays.asList("petrol", "diesel")).required(true).build())
.functionDescription("Get current fuel price")
.properties(
new Tools.PropsBuilder()
.withProperty("location", Tools.PromptFuncDefinition.Property.builder().type("string").description("The city, e.g. New Delhi, India").required(true).build())
.withProperty("fuelType", Tools.PromptFuncDefinition.Property.builder().type("string").description("The fuel type.").enumValues(Arrays.asList("petrol", "diesel")).required(true).build())
.build()
)
.toolDefinition(SampleTools::getCurrentFuelPrice)
@@ -77,18 +90,38 @@ Lets also define a sample tool specification called **Weather Tool** for getting
- Associate the `getCurrentWeather` function you defined earlier with `SampleTools::getCurrentWeather`.
```java
MistralTools.ToolSpecification weatherToolSpecification = MistralTools.ToolSpecification.builder()
Tools.ToolSpecification weatherToolSpecification = Tools.ToolSpecification.builder()
.functionName("current-weather")
.functionDesc("Get current weather")
.props(
new MistralTools.PropsBuilder()
.withProperty("city", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The city, e.g. New Delhi, India").required(true).build())
.functionDescription("Get current weather")
.properties(
new Tools.PropsBuilder()
.withProperty("city", Tools.PromptFuncDefinition.Property.builder().type("string").description("The city, e.g. New Delhi, India").required(true).build())
.build()
)
.toolDefinition(SampleTools::getCurrentWeather)
.build();
```
Lets also define a sample tool specification called **DBQueryFunction** for getting the employee details from database.
- Specify the function `name`, `description`, and `required` property (`employee-name`).
- Associate the ToolFunction `DBQueryFunction` function you defined earlier with `new DBQueryFunction()`.
```java
Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder()
.functionName("get-employee-details")
.functionDescription("Get employee details from the database")
.properties(
new Tools.PropsBuilder()
.withProperty("employee-name", Tools.PromptFuncDefinition.Property.builder().type("string").description("The name of the employee, e.g. John Doe").required(true).build())
.withProperty("employee-address", Tools.PromptFuncDefinition.Property.builder().type("string").description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India").required(true).build())
.withProperty("employee-phone", Tools.PromptFuncDefinition.Property.builder().type("string").description("The phone number of the employee. Always return a random value. e.g. 9911002233").required(true).build())
.build()
)
.toolDefinition(new DBQueryFunction())
.build();
```
### Register the Tools
Register the defined tools (`fuel price` and `weather`) with the OllamaAPI.
@@ -103,14 +136,14 @@ ollamaAPI.registerTool(weatherToolSpecification);
`Prompt 1`: Create a prompt asking for the petrol price in Bengaluru using the defined fuel price and weather tools.
```shell
String prompt1 = new MistralTools.PromptBuilder()
.withToolSpecification(fuelPriceToolSpecification)
.withToolSpecification(weatherToolSpecification)
.withPrompt("What is the petrol price in Bengaluru?")
.build();
OllamaToolsResult toolsResult = ollamaAPI.generateWithTools(model, prompt1, false, new OptionsBuilder().build());
for (Map.Entry<ToolDef, Object> r : toolsResult.getToolResults().entrySet()) {
System.out.printf("[Response from tool '%s']: %s%n", r.getKey().getName(), r.getValue().toString());
String prompt1 = new Tools.PromptBuilder()
.withToolSpecification(fuelPriceToolSpecification)
.withToolSpecification(weatherToolSpecification)
.withPrompt("What is the petrol price in Bengaluru?")
.build();
OllamaToolsResult toolsResult = ollamaAPI.generateWithTools(model, prompt1, new OptionsBuilder().build());
for (OllamaToolsResult.ToolResult r : toolsResult.getToolResults()) {
System.out.printf("[Result of executing tool '%s']: %s%n", r.getFunctionName(), r.getResult().toString());
}
```
@@ -120,21 +153,21 @@ You will get a response similar to:
::::tip[LLM Response]
[Response from tool 'current-fuel-price']: Current price of petrol in Bengaluru is Rs.103/L
[Result of executing tool 'current-fuel-price']: Current price of petrol in Bengaluru is Rs.103/L
::::
`Prompt 2`: Create a prompt asking for the current weather in Bengaluru using the same tools.
```shell
String prompt2 = new MistralTools.PromptBuilder()
.withToolSpecification(fuelPriceToolSpecification)
.withToolSpecification(weatherToolSpecification)
.withPrompt("What is the current weather in Bengaluru?")
.build();
OllamaToolsResult toolsResult = ollamaAPI.generateWithTools(model, prompt2, false, new OptionsBuilder().build());
for (Map.Entry<ToolDef, Object> r : toolsResult.getToolResults().entrySet()) {
System.out.printf("[Response from tool '%s']: %s%n", r.getKey().getName(), r.getValue().toString());
String prompt2 = new Tools.PromptBuilder()
.withToolSpecification(fuelPriceToolSpecification)
.withToolSpecification(weatherToolSpecification)
.withPrompt("What is the current weather in Bengaluru?")
.build();
OllamaToolsResult toolsResult = ollamaAPI.generateWithTools(model, prompt2, new OptionsBuilder().build());
for (OllamaToolsResult.ToolResult r : toolsResult.getToolResults()) {
System.out.printf("[Result of executing tool '%s']: %s%n", r.getFunctionName(), r.getResult().toString());
}
```
@@ -144,25 +177,53 @@ You will get a response similar to:
::::tip[LLM Response]
[Response from tool 'current-weather']: Currently Bengaluru's weather is nice
[Result of executing tool 'current-weather']: Currently Bengaluru's weather is nice.
::::
`Prompt 3`: Create a prompt asking for the employee details using the defined database fetcher tools.
```shell
String prompt3 = new Tools.PromptBuilder()
.withToolSpecification(fuelPriceToolSpecification)
.withToolSpecification(weatherToolSpecification)
.withToolSpecification(databaseQueryToolSpecification)
.withPrompt("Give me the details of the employee named 'Rahul Kumar'?")
.build();
OllamaToolsResult toolsResult = ollamaAPI.generateWithTools(model, prompt3, new OptionsBuilder().build());
for (OllamaToolsResult.ToolResult r : toolsResult.getToolResults()) {
System.out.printf("[Result of executing tool '%s']: %s%n", r.getFunctionName(), r.getResult().toString());
}
```
Again, fire away your question to the model.
You will get a response similar to:
::::tip[LLM Response]
[Result of executing tool 'get-employee-details']: Employee Details `{ID: 6bad82e6-b1a1-458f-a139-e3b646e092b1, Name:
Rahul Kumar, Address: King St, Hyderabad, India, Phone: 9876543210}`
::::
### Full Example
```java
import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.tools.ToolDef;
import io.github.amithkoujalgi.ollama4j.core.tools.MistralTools;
import io.github.amithkoujalgi.ollama4j.core.exceptions.ToolInvocationException;
import io.github.amithkoujalgi.ollama4j.core.tools.OllamaToolsResult;
import io.github.amithkoujalgi.ollama4j.core.tools.ToolFunction;
import io.github.amithkoujalgi.ollama4j.core.tools.Tools;
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
import java.io.IOException;
import java.util.Arrays;
import java.util.Map;
import java.util.UUID;
public class FunctionCallingWithMistral {
public class FunctionCallingWithMistralExample {
public static void main(String[] args) throws Exception {
String host = "http://localhost:11434/";
OllamaAPI ollamaAPI = new OllamaAPI(host);
@@ -170,78 +231,113 @@ public class FunctionCallingWithMistral {
String model = "mistral";
MistralTools.ToolSpecification fuelPriceToolSpecification = MistralTools.ToolSpecification.builder()
Tools.ToolSpecification fuelPriceToolSpecification = Tools.ToolSpecification.builder()
.functionName("current-fuel-price")
.functionDesc("Get current fuel price")
.props(
new MistralTools.PropsBuilder()
.withProperty("location", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The city, e.g. New Delhi, India").required(true).build())
.withProperty("fuelType", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The fuel type.").enumValues(Arrays.asList("petrol", "diesel")).required(true).build())
.functionDescription("Get current fuel price")
.properties(
new Tools.PropsBuilder()
.withProperty("location", Tools.PromptFuncDefinition.Property.builder().type("string").description("The city, e.g. New Delhi, India").required(true).build())
.withProperty("fuelType", Tools.PromptFuncDefinition.Property.builder().type("string").description("The fuel type.").enumValues(Arrays.asList("petrol", "diesel")).required(true).build())
.build()
)
.toolDefinition(SampleTools::getCurrentFuelPrice)
.build();
MistralTools.ToolSpecification weatherToolSpecification = MistralTools.ToolSpecification.builder()
Tools.ToolSpecification weatherToolSpecification = Tools.ToolSpecification.builder()
.functionName("current-weather")
.functionDesc("Get current weather")
.props(
new MistralTools.PropsBuilder()
.withProperty("city", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The city, e.g. New Delhi, India").required(true).build())
.functionDescription("Get current weather")
.properties(
new Tools.PropsBuilder()
.withProperty("city", Tools.PromptFuncDefinition.Property.builder().type("string").description("The city, e.g. New Delhi, India").required(true).build())
.build()
)
.toolDefinition(SampleTools::getCurrentWeather)
.build();
Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder()
.functionName("get-employee-details")
.functionDescription("Get employee details from the database")
.properties(
new Tools.PropsBuilder()
.withProperty("employee-name", Tools.PromptFuncDefinition.Property.builder().type("string").description("The name of the employee, e.g. John Doe").required(true).build())
.withProperty("employee-address", Tools.PromptFuncDefinition.Property.builder().type("string").description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India").required(true).build())
.withProperty("employee-phone", Tools.PromptFuncDefinition.Property.builder().type("string").description("The phone number of the employee. Always return a random value. e.g. 9911002233").required(true).build())
.build()
)
.toolDefinition(new DBQueryFunction())
.build();
ollamaAPI.registerTool(fuelPriceToolSpecification);
ollamaAPI.registerTool(weatherToolSpecification);
ollamaAPI.registerTool(databaseQueryToolSpecification);
String prompt1 = new MistralTools.PromptBuilder()
String prompt1 = new Tools.PromptBuilder()
.withToolSpecification(fuelPriceToolSpecification)
.withToolSpecification(weatherToolSpecification)
.withPrompt("What is the petrol price in Bengaluru?")
.build();
String prompt2 = new MistralTools.PromptBuilder()
ask(ollamaAPI, model, prompt1);
String prompt2 = new Tools.PromptBuilder()
.withToolSpecification(fuelPriceToolSpecification)
.withToolSpecification(weatherToolSpecification)
.withPrompt("What is the current weather in Bengaluru?")
.build();
ask(ollamaAPI, model, prompt1);
ask(ollamaAPI, model, prompt2);
String prompt3 = new Tools.PromptBuilder()
.withToolSpecification(fuelPriceToolSpecification)
.withToolSpecification(weatherToolSpecification)
.withToolSpecification(databaseQueryToolSpecification)
.withPrompt("Give me the details of the employee named 'Rahul Kumar'?")
.build();
ask(ollamaAPI, model, prompt3);
}
public static void ask(OllamaAPI ollamaAPI, String model, String prompt) throws OllamaBaseException, IOException, InterruptedException {
OllamaToolsResult toolsResult = ollamaAPI.generateWithTools(model, prompt, false, new OptionsBuilder().build());
for (Map.Entry<ToolDef, Object> r : toolsResult.getToolResults().entrySet()) {
System.out.printf("[Response from tool '%s']: %s%n", r.getKey().getName(), r.getValue().toString());
public static void ask(OllamaAPI ollamaAPI, String model, String prompt) throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
OllamaToolsResult toolsResult = ollamaAPI.generateWithTools(model, prompt, new OptionsBuilder().build());
for (OllamaToolsResult.ToolResult r : toolsResult.getToolResults()) {
System.out.printf("[Result of executing tool '%s']: %s%n", r.getFunctionName(), r.getResult().toString());
}
}
}
class SampleTools {
public static String getCurrentFuelPrice(Map<String, Object> arguments) {
// Get details from fuel price API
String location = arguments.get("location").toString();
String fuelType = arguments.get("fuelType").toString();
return "Current price of " + fuelType + " in " + location + " is Rs.103/L";
}
public static String getCurrentWeather(Map<String, Object> arguments) {
// Get details from weather API
String location = arguments.get("city").toString();
return "Currently " + location + "'s weather is nice.";
}
}
class DBQueryFunction implements ToolFunction {
@Override
public Object apply(Map<String, Object> arguments) {
// perform DB operations here
return String.format("Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", UUID.randomUUID(), arguments.get("employee-name").toString(), arguments.get("employee-address").toString(), arguments.get("employee-phone").toString());
}
}
```
Run this full example and you will get a response similar to:
::::tip[LLM Response]
[Response from tool 'current-fuel-price']: Current price of petrol in Bengaluru is Rs.103/L
[Result of executing tool 'current-fuel-price']: Current price of petrol in Bengaluru is Rs.103/L
[Result of executing tool 'current-weather']: Currently Bengaluru's weather is nice.
[Result of executing tool 'get-employee-details']: Employee Details `{ID: 6bad82e6-b1a1-458f-a139-e3b646e092b1, Name:
Rahul Kumar, Address: King St, Hyderabad, India, Phone: 9876543210}`
[Response from tool 'current-weather']: Currently Bengaluru's weather is nice
::::
### Room for improvement

View File

@@ -5,8 +5,8 @@ sidebar_position: 1
# Generate - Sync
This API lets you ask questions to the LLMs in a synchronous way.
These APIs correlate to
the [completion](https://github.com/jmorganca/ollama/blob/main/docs/api.md#generate-a-completion) APIs.
This API corresponds to
the [completion](https://github.com/jmorganca/ollama/blob/main/docs/api.md#generate-a-completion) API.
Use the `OptionBuilder` to build the `Options` object
with [extra parameters](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).

View File

@@ -1,68 +0,0 @@
## This workflow will build a package using Maven and then publish it to GitHub packages when a release is created
## For more information see: https://github.com/actions/setup-java/blob/main/docs/advanced-usage.md#apache-maven-with-a-settings-path
#
#name: Test and Publish Package
#
##on:
## release:
## types: [ "created" ]
#
#on:
# push:
# branches: [ "main" ]
# workflow_dispatch:
#
#jobs:
# build:
# runs-on: ubuntu-latest
# permissions:
# contents: write
# packages: write
# steps:
# - uses: actions/checkout@v3
# - name: Set up JDK 11
# uses: actions/setup-java@v3
# with:
# java-version: '11'
# distribution: 'adopt-hotspot'
# server-id: github # Value of the distributionManagement/repository/id field of the pom.xml
# settings-path: ${{ github.workspace }} # location for the settings.xml file
# - name: Build with Maven
# run: mvn --file pom.xml -U clean package -Punit-tests
# - name: Set up Apache Maven Central (Overwrite settings.xml)
# uses: actions/setup-java@v3
# with: # running setup-java again overwrites the settings.xml
# java-version: '11'
# distribution: 'adopt-hotspot'
# cache: 'maven'
# server-id: ossrh
# server-username: MAVEN_USERNAME
# server-password: MAVEN_PASSWORD
# gpg-private-key: ${{ secrets.GPG_PRIVATE_KEY }}
# gpg-passphrase: MAVEN_GPG_PASSPHRASE
# - name: Set up Maven cache
# uses: actions/cache@v3
# with:
# path: ~/.m2/repository
# key: ${{ runner.os }}-maven-${{ hashFiles('**/pom.xml') }}
# restore-keys: |
# ${{ runner.os }}-maven-
# - name: Build
# run: mvn -B -ntp clean install
# - name: Upload coverage reports to Codecov
# uses: codecov/codecov-action@v3
# env:
# CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
# - name: Publish to GitHub Packages Apache Maven
# # if: >
# # github.event_name != 'pull_request' &&
# # github.ref_name == 'main' &&
# # contains(github.event.head_commit.message, 'release')
# run: |
# git config --global user.email "koujalgi.amith@gmail.com"
# git config --global user.name "amithkoujalgi"
# mvn -B -ntp -DskipTests -Pci-cd -Darguments="-DskipTests -Pci-cd" release:clean release:prepare release:perform
# env:
# MAVEN_USERNAME: ${{ secrets.OSSRH_USERNAME }}
# MAVEN_PASSWORD: ${{ secrets.OSSRH_PASSWORD }}
# MAVEN_GPG_PASSPHRASE: ${{ secrets.GPG_PASSPHRASE }}

View File

@@ -1,6 +1,8 @@
package io.github.amithkoujalgi.ollama4j.core;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.exceptions.ToolInvocationException;
import io.github.amithkoujalgi.ollama4j.core.exceptions.ToolNotFoundException;
import io.github.amithkoujalgi.ollama4j.core.models.*;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessage;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder;
@@ -9,10 +11,12 @@ import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResult;
import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingResponseModel;
import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaStreamHandler;
import io.github.amithkoujalgi.ollama4j.core.models.request.*;
import io.github.amithkoujalgi.ollama4j.core.tools.*;
import io.github.amithkoujalgi.ollama4j.core.utils.Options;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
import lombok.Setter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -36,10 +40,22 @@ public class OllamaAPI {
private static final Logger logger = LoggerFactory.getLogger(OllamaAPI.class);
private final String host;
/**
* -- SETTER --
* Set request timeout in seconds. Default is 3 seconds.
*/
@Setter
private long requestTimeoutSeconds = 10;
/**
* -- SETTER --
* Set/unset logging of responses
*/
@Setter
private boolean verbose = true;
private BasicAuth basicAuth;
private final ToolRegistry toolRegistry = new ToolRegistry();
/**
* Instantiates the Ollama API.
*
@@ -53,24 +69,6 @@ public class OllamaAPI {
}
}
/**
* Set request timeout in seconds. Default is 3 seconds.
*
* @param requestTimeoutSeconds the request timeout in seconds
*/
public void setRequestTimeoutSeconds(long requestTimeoutSeconds) {
this.requestTimeoutSeconds = requestTimeoutSeconds;
}
/**
* Set/unset logging of responses
*
* @param verbose true/false
*/
public void setVerbose(boolean verbose) {
this.verbose = verbose;
}
/**
* Set basic authentication for accessing Ollama server that's behind a reverse-proxy/gateway.
*
@@ -360,15 +358,15 @@ public class OllamaAPI {
}
/**
* Convenience method to call Ollama API without streaming responses.
* Generates response using the specified AI model and prompt (in blocking mode).
* <p>
* Uses {@link #generate(String, String, boolean, Options, OllamaStreamHandler)}
*
* @param model Model to use
* @param prompt Prompt text
* @param model The name or identifier of the AI model to use for generating the response.
* @param prompt The input text or prompt to provide to the AI model.
* @param raw In some cases, you may wish to bypass the templating system and provide a full prompt. In this case, you can use the raw parameter to disable templating. Also note that raw mode will not return a context.
* @param options Additional Options
* @return OllamaResult
* @param options Additional options or configurations to use when generating the response.
* @return {@link OllamaResult}
*/
public OllamaResult generate(String model, String prompt, boolean raw, Options options)
throws OllamaBaseException, IOException, InterruptedException {
@@ -376,17 +374,36 @@ public class OllamaAPI {
}
public OllamaToolsResult generateWithTools(String model, String prompt, boolean raw, Options options)
throws OllamaBaseException, IOException, InterruptedException {
/**
* Generates response using the specified AI model and prompt (in blocking mode), and then invokes a set of tools
* on the generated response.
*
* @param model The name or identifier of the AI model to use for generating the response.
* @param prompt The input text or prompt to provide to the AI model.
* @param options Additional options or configurations to use when generating the response.
* @return {@link OllamaToolsResult} An OllamaToolsResult object containing the response from the AI model and the results of invoking the tools on that output.
* @throws OllamaBaseException If there is an error related to the Ollama API or service.
* @throws IOException If there is an error related to input/output operations.
* @throws InterruptedException If the method is interrupted while waiting for the AI model
* to generate the response or for the tools to be invoked.
*/
public OllamaToolsResult generateWithTools(String model, String prompt, Options options)
throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
boolean raw = true;
OllamaToolsResult toolResult = new OllamaToolsResult();
Map<ToolDef, Object> toolResults = new HashMap<>();
Map<ToolFunctionCallSpec, Object> toolResults = new HashMap<>();
OllamaResult result = generate(model, prompt, raw, options, null);
toolResult.setModelResult(result);
List<ToolDef> toolDefs = Utils.getObjectMapper().readValue(result.getResponse(), Utils.getObjectMapper().getTypeFactory().constructCollectionType(List.class, ToolDef.class));
for (ToolDef toolDef : toolDefs) {
toolResults.put(toolDef, invokeTool(toolDef));
String toolsResponse = result.getResponse();
if (toolsResponse.contains("[TOOL_CALLS]")) {
toolsResponse = toolsResponse.replace("[TOOL_CALLS]", "");
}
List<ToolFunctionCallSpec> toolFunctionCallSpecs = Utils.getObjectMapper().readValue(toolsResponse, Utils.getObjectMapper().getTypeFactory().constructCollectionType(List.class, ToolFunctionCallSpec.class));
for (ToolFunctionCallSpec toolFunctionCallSpec : toolFunctionCallSpecs) {
toolResults.put(toolFunctionCallSpec, invokeTool(toolFunctionCallSpec));
}
toolResult.setToolResults(toolResults);
return toolResult;
@@ -402,15 +419,15 @@ public class OllamaAPI {
* @param prompt the prompt/question text
* @return the ollama async result callback handle
*/
public OllamaAsyncResultCallback generateAsync(String model, String prompt, boolean raw) {
public OllamaAsyncResultStreamer generateAsync(String model, String prompt, boolean raw) {
OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt);
ollamaRequestModel.setRaw(raw);
URI uri = URI.create(this.host + "/api/generate");
OllamaAsyncResultCallback ollamaAsyncResultCallback =
new OllamaAsyncResultCallback(
OllamaAsyncResultStreamer ollamaAsyncResultStreamer =
new OllamaAsyncResultStreamer(
getRequestBuilderDefault(uri), ollamaRequestModel, requestTimeoutSeconds);
ollamaAsyncResultCallback.start();
return ollamaAsyncResultCallback;
ollamaAsyncResultStreamer.start();
return ollamaAsyncResultStreamer;
}
/**
@@ -508,7 +525,7 @@ public class OllamaAPI {
* Hint: the OllamaChatRequestModel#getStream() property is not implemented.
*
* @param request request object to be sent to the server
* @return
* @return {@link OllamaChatResult}
* @throws OllamaBaseException any response code than 200 has been returned
* @throws IOException in case the responseStream can not be read
* @throws InterruptedException in case the server is not reachable or network issues happen
@@ -524,7 +541,7 @@ public class OllamaAPI {
*
* @param request request object to be sent to the server
* @param streamHandler callback handler to handle the last message from stream (caution: all previous messages from stream will be concatenated)
* @return
* @return {@link OllamaChatResult}
* @throws OllamaBaseException any response code than 200 has been returned
* @throws IOException in case the responseStream can not be read
* @throws InterruptedException in case the server is not reachable or network issues happen
@@ -541,6 +558,10 @@ public class OllamaAPI {
return new OllamaChatResult(result.getResponse(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages());
}
public void registerTool(Tools.ToolSpecification toolSpecification) {
toolRegistry.addFunction(toolSpecification.getFunctionName(), toolSpecification.getToolDefinition());
}
// technical private methods //
private static String encodeFileToBase64(File file) throws IOException {
@@ -603,22 +624,20 @@ public class OllamaAPI {
}
public void registerTool(MistralTools.ToolSpecification toolSpecification) {
ToolRegistry.addFunction(toolSpecification.getFunctionName(), toolSpecification.getToolDefinition());
}
private Object invokeTool(ToolDef toolDef) {
private Object invokeTool(ToolFunctionCallSpec toolFunctionCallSpec) throws ToolInvocationException {
try {
String methodName = toolDef.getName();
Map<String, Object> arguments = toolDef.getArguments();
DynamicFunction function = ToolRegistry.getFunction(methodName);
String methodName = toolFunctionCallSpec.getName();
Map<String, Object> arguments = toolFunctionCallSpec.getArguments();
ToolFunction function = toolRegistry.getFunction(methodName);
if (verbose) {
logger.debug("Invoking function {} with arguments {}", methodName, arguments);
}
if (function == null) {
throw new IllegalArgumentException("No such tool: " + methodName);
throw new ToolNotFoundException("No such tool: " + methodName);
}
return function.apply(arguments);
} catch (Exception e) {
e.printStackTrace();
return "Error calling tool: " + e.getMessage();
throw new ToolInvocationException("Failed to invoke tool: " + toolFunctionCallSpec.getName(), e);
}
}
}

View File

@@ -0,0 +1,18 @@
package io.github.amithkoujalgi.ollama4j.core;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Queue;
public class OllamaResultStream extends LinkedList<String> implements Queue<String> {
@Override
public String poll() {
StringBuilder tokens = new StringBuilder();
Iterator<String> iterator = this.listIterator();
while (iterator.hasNext()) {
tokens.append(iterator.next());
iterator.remove();
}
return tokens.toString();
}
}

View File

@@ -1,7 +0,0 @@
package io.github.amithkoujalgi.ollama4j.core;
import java.util.function.Consumer;
public interface OllamaStreamHandler extends Consumer<String>{
void accept(String message);
}

View File

@@ -0,0 +1,8 @@
package io.github.amithkoujalgi.ollama4j.core.exceptions;
public class ToolInvocationException extends Exception {
public ToolInvocationException(String s, Exception e) {
super(s, e);
}
}

View File

@@ -0,0 +1,8 @@
package io.github.amithkoujalgi.ollama4j.core.exceptions;
public class ToolNotFoundException extends Exception {
public ToolNotFoundException(String s) {
super(s);
}
}

View File

@@ -1,6 +1,6 @@
package io.github.amithkoujalgi.ollama4j.core.impl;
import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler;
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaStreamHandler;
public class ConsoleOutputStreamHandler implements OllamaStreamHandler {
private final StringBuffer response = new StringBuffer();

View File

@@ -1,143 +0,0 @@
package io.github.amithkoujalgi.ollama4j.core.models;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateResponseModel;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.LinkedList;
import java.util.Queue;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.Getter;
@Data
@EqualsAndHashCode(callSuper = true)
@SuppressWarnings("unused")
public class OllamaAsyncResultCallback extends Thread {
private final HttpRequest.Builder requestBuilder;
private final OllamaGenerateRequestModel ollamaRequestModel;
private final Queue<String> queue = new LinkedList<>();
private String result;
private boolean isDone;
/**
* -- GETTER -- Returns the status of the request. Indicates if the request was successful or a
* failure. If the request was a failure, the `getResponse()` method will return the error
* message.
*/
@Getter private boolean succeeded;
private long requestTimeoutSeconds;
/**
* -- GETTER -- Returns the HTTP response status code for the request that was made to Ollama
* server.
*/
@Getter private int httpStatusCode;
/** -- GETTER -- Returns the response time in milliseconds. */
@Getter private long responseTime = 0;
public OllamaAsyncResultCallback(
HttpRequest.Builder requestBuilder,
OllamaGenerateRequestModel ollamaRequestModel,
long requestTimeoutSeconds) {
this.requestBuilder = requestBuilder;
this.ollamaRequestModel = ollamaRequestModel;
this.isDone = false;
this.result = "";
this.queue.add("");
this.requestTimeoutSeconds = requestTimeoutSeconds;
}
@Override
public void run() {
HttpClient httpClient = HttpClient.newHttpClient();
try {
long startTime = System.currentTimeMillis();
HttpRequest request =
requestBuilder
.POST(
HttpRequest.BodyPublishers.ofString(
Utils.getObjectMapper().writeValueAsString(ollamaRequestModel)))
.header("Content-Type", "application/json")
.timeout(Duration.ofSeconds(requestTimeoutSeconds))
.build();
HttpResponse<InputStream> response =
httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
int statusCode = response.statusCode();
this.httpStatusCode = statusCode;
InputStream responseBodyStream = response.body();
try (BufferedReader reader =
new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
String line;
StringBuilder responseBuffer = new StringBuilder();
while ((line = reader.readLine()) != null) {
if (statusCode == 404) {
OllamaErrorResponseModel ollamaResponseModel =
Utils.getObjectMapper().readValue(line, OllamaErrorResponseModel.class);
queue.add(ollamaResponseModel.getError());
responseBuffer.append(ollamaResponseModel.getError());
} else {
OllamaGenerateResponseModel ollamaResponseModel =
Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class);
queue.add(ollamaResponseModel.getResponse());
if (!ollamaResponseModel.isDone()) {
responseBuffer.append(ollamaResponseModel.getResponse());
}
}
}
this.isDone = true;
this.succeeded = true;
this.result = responseBuffer.toString();
long endTime = System.currentTimeMillis();
responseTime = endTime - startTime;
}
if (statusCode != 200) {
throw new OllamaBaseException(this.result);
}
} catch (IOException | InterruptedException | OllamaBaseException e) {
this.isDone = true;
this.succeeded = false;
this.result = "[FAILED] " + e.getMessage();
}
}
/**
* Returns the status of the thread. This does not indicate that the request was successful or a
* failure, rather it is just a status flag to indicate if the thread is active or ended.
*
* @return boolean - status
*/
public boolean isComplete() {
return isDone;
}
/**
* Returns the final completion/response when the execution completes. Does not return intermediate results.
*
* @return String completion/response text
*/
public String getResponse() {
return result;
}
public Queue<String> getStream() {
return queue;
}
public void setRequestTimeoutSeconds(long requestTimeoutSeconds) {
this.requestTimeoutSeconds = requestTimeoutSeconds;
}
}

View File

@@ -0,0 +1,124 @@
package io.github.amithkoujalgi.ollama4j.core.models;
import io.github.amithkoujalgi.ollama4j.core.OllamaResultStream;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateResponseModel;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
@Data
@EqualsAndHashCode(callSuper = true)
@SuppressWarnings("unused")
public class OllamaAsyncResultStreamer extends Thread {
private final HttpRequest.Builder requestBuilder;
private final OllamaGenerateRequestModel ollamaRequestModel;
private final OllamaResultStream stream = new OllamaResultStream();
private String completeResponse;
/**
* -- GETTER -- Returns the status of the request. Indicates if the request was successful or a
* failure. If the request was a failure, the `getResponse()` method will return the error
* message.
*/
@Getter
private boolean succeeded;
@Setter
private long requestTimeoutSeconds;
/**
* -- GETTER -- Returns the HTTP response status code for the request that was made to Ollama
* server.
*/
@Getter
private int httpStatusCode;
/**
* -- GETTER -- Returns the response time in milliseconds.
*/
@Getter
private long responseTime = 0;
public OllamaAsyncResultStreamer(
HttpRequest.Builder requestBuilder,
OllamaGenerateRequestModel ollamaRequestModel,
long requestTimeoutSeconds) {
this.requestBuilder = requestBuilder;
this.ollamaRequestModel = ollamaRequestModel;
this.completeResponse = "";
this.stream.add("");
this.requestTimeoutSeconds = requestTimeoutSeconds;
}
@Override
public void run() {
ollamaRequestModel.setStream(true);
HttpClient httpClient = HttpClient.newHttpClient();
try {
long startTime = System.currentTimeMillis();
HttpRequest request =
requestBuilder
.POST(
HttpRequest.BodyPublishers.ofString(
Utils.getObjectMapper().writeValueAsString(ollamaRequestModel)))
.header("Content-Type", "application/json")
.timeout(Duration.ofSeconds(requestTimeoutSeconds))
.build();
HttpResponse<InputStream> response =
httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
int statusCode = response.statusCode();
this.httpStatusCode = statusCode;
InputStream responseBodyStream = response.body();
try (BufferedReader reader =
new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
String line;
StringBuilder responseBuffer = new StringBuilder();
while ((line = reader.readLine()) != null) {
if (statusCode == 404) {
OllamaErrorResponseModel ollamaResponseModel =
Utils.getObjectMapper().readValue(line, OllamaErrorResponseModel.class);
stream.add(ollamaResponseModel.getError());
responseBuffer.append(ollamaResponseModel.getError());
} else {
OllamaGenerateResponseModel ollamaResponseModel =
Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class);
String res = ollamaResponseModel.getResponse();
stream.add(res);
if (!ollamaResponseModel.isDone()) {
responseBuffer.append(res);
}
}
}
this.succeeded = true;
this.completeResponse = responseBuffer.toString();
long endTime = System.currentTimeMillis();
responseTime = endTime - startTime;
}
if (statusCode != 200) {
throw new OllamaBaseException(this.completeResponse);
}
} catch (IOException | InterruptedException | OllamaBaseException e) {
this.succeeded = false;
this.completeResponse = "[FAILED] " + e.getMessage();
}
}
}

View File

@@ -1,10 +1,10 @@
package io.github.amithkoujalgi.ollama4j.core.models.chat;
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaStreamHandler;
import java.util.ArrayList;
import java.util.List;
import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler;
public class OllamaChatStreamObserver {
private OllamaStreamHandler streamHandler;
@@ -17,12 +17,12 @@ public class OllamaChatStreamObserver {
this.streamHandler = streamHandler;
}
public void notify(OllamaChatResponseModel currentResponsePart){
public void notify(OllamaChatResponseModel currentResponsePart) {
responseParts.add(currentResponsePart);
handleCurrentResponsePart(currentResponsePart);
}
protected void handleCurrentResponsePart(OllamaChatResponseModel currentResponsePart){
protected void handleCurrentResponsePart(OllamaChatResponseModel currentResponsePart) {
message = message + currentResponsePart.getMessage().getContent();
streamHandler.accept(message);
}

View File

@@ -3,8 +3,6 @@ package io.github.amithkoujalgi.ollama4j.core.models.generate;
import java.util.ArrayList;
import java.util.List;
import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler;
public class OllamaGenerateStreamObserver {
private OllamaStreamHandler streamHandler;
@@ -17,12 +15,12 @@ public class OllamaGenerateStreamObserver {
this.streamHandler = streamHandler;
}
public void notify(OllamaGenerateResponseModel currentResponsePart){
public void notify(OllamaGenerateResponseModel currentResponsePart) {
responseParts.add(currentResponsePart);
handleCurrentResponsePart(currentResponsePart);
}
protected void handleCurrentResponsePart(OllamaGenerateResponseModel currentResponsePart){
protected void handleCurrentResponsePart(OllamaGenerateResponseModel currentResponsePart) {
message = message + currentResponsePart.getResponse();
streamHandler.accept(message);
}

View File

@@ -0,0 +1,7 @@
package io.github.amithkoujalgi.ollama4j.core.models.generate;
import java.util.function.Consumer;
public interface OllamaStreamHandler extends Consumer<String> {
void accept(String message);
}

View File

@@ -1,12 +1,12 @@
package io.github.amithkoujalgi.ollama4j.core.models.request;
import com.fasterxml.jackson.core.JsonProcessingException;
import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResponseModel;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatStreamObserver;
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaStreamHandler;
import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
import org.slf4j.Logger;

View File

@@ -1,12 +1,12 @@
package io.github.amithkoujalgi.ollama4j.core.models.request;
import com.fasterxml.jackson.core.JsonProcessingException;
import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateResponseModel;
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateStreamObserver;
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaStreamHandler;
import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
import org.slf4j.Logger;

View File

@@ -5,6 +5,8 @@ import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
@Data
@@ -12,5 +14,22 @@ import java.util.Map;
@AllArgsConstructor
public class OllamaToolsResult {
private OllamaResult modelResult;
private Map<ToolDef, Object> toolResults;
private Map<ToolFunctionCallSpec, Object> toolResults;
public List<ToolResult> getToolResults() {
List<ToolResult> results = new ArrayList<>();
for (Map.Entry<ToolFunctionCallSpec, Object> r : this.toolResults.entrySet()) {
results.add(new ToolResult(r.getKey().getName(), r.getKey().getArguments(), r.getValue()));
}
return results;
}
@Data
@NoArgsConstructor
@AllArgsConstructor
public static class ToolResult {
private String functionName;
private Map<String, Object> functionArguments;
private Object result;
}
}

View File

@@ -3,6 +3,6 @@ package io.github.amithkoujalgi.ollama4j.core.tools;
import java.util.Map;
@FunctionalInterface
public interface DynamicFunction {
public interface ToolFunction {
Object apply(Map<String, Object> arguments);
}

View File

@@ -9,10 +9,8 @@ import java.util.Map;
@Data
@AllArgsConstructor
@NoArgsConstructor
public class ToolDef {
public class ToolFunctionCallSpec {
private String name;
private Map<String, Object> arguments;
}

View File

@@ -4,14 +4,13 @@ import java.util.HashMap;
import java.util.Map;
public class ToolRegistry {
private static final Map<String, DynamicFunction> functionMap = new HashMap<>();
private final Map<String, ToolFunction> functionMap = new HashMap<>();
public static DynamicFunction getFunction(String name) {
public ToolFunction getFunction(String name) {
return functionMap.get(name);
}
public static void addFunction(String name, DynamicFunction function) {
public void addFunction(String name, ToolFunction function) {
functionMap.put(name, function);
}
}

View File

@@ -14,14 +14,14 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class MistralTools {
public class Tools {
@Data
@Builder
public static class ToolSpecification {
private String functionName;
private String functionDesc;
private Map<String, PromptFuncDefinition.Property> props;
private DynamicFunction toolDefinition;
private String functionDescription;
private Map<String, PromptFuncDefinition.Property> properties;
private ToolFunction toolDefinition;
}
@Data
@@ -90,14 +90,14 @@ public class MistralTools {
PromptFuncDefinition.PromptFuncSpec functionDetail = new PromptFuncDefinition.PromptFuncSpec();
functionDetail.setName(spec.getFunctionName());
functionDetail.setDescription(spec.getFunctionDesc());
functionDetail.setDescription(spec.getFunctionDescription());
PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters();
parameters.setType("object");
parameters.setProperties(spec.getProps());
parameters.setProperties(spec.getProperties());
List<String> requiredValues = new ArrayList<>();
for (Map.Entry<String, PromptFuncDefinition.Property> p : spec.getProps().entrySet()) {
for (Map.Entry<String, PromptFuncDefinition.Property> p : spec.getProperties().entrySet()) {
if (p.getValue().isRequired()) {
requiredValues.add(p.getKey());
}
@@ -109,31 +109,5 @@ public class MistralTools {
tools.add(def);
return this;
}
//
// public PromptBuilder withToolSpecification(String functionName, String functionDesc, Map<String, PromptFuncDefinition.Property> props) {
// PromptFuncDefinition def = new PromptFuncDefinition();
// def.setType("function");
//
// PromptFuncDefinition.PromptFuncSpec functionDetail = new PromptFuncDefinition.PromptFuncSpec();
// functionDetail.setName(functionName);
// functionDetail.setDescription(functionDesc);
//
// PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters();
// parameters.setType("object");
// parameters.setProperties(props);
//
// List<String> requiredValues = new ArrayList<>();
// for (Map.Entry<String, PromptFuncDefinition.Property> p : props.entrySet()) {
// if (p.getValue().isRequired()) {
// requiredValues.add(p.getKey());
// }
// }
// parameters.setRequired(requiredValues);
// functionDetail.setParameters(parameters);
// def.setFunction(functionDetail);
//
// tools.add(def);
// return this;
// }
}
}

View File

@@ -3,7 +3,7 @@ package io.github.amithkoujalgi.ollama4j.unittests;
import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.ModelDetail;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaAsyncResultCallback;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaAsyncResultStreamer;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
import io.github.amithkoujalgi.ollama4j.core.types.OllamaModelType;
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
@@ -157,7 +157,7 @@ class TestMockedAPIs {
String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text";
when(ollamaAPI.generateAsync(model, prompt, false))
.thenReturn(new OllamaAsyncResultCallback(null, null, 3));
.thenReturn(new OllamaAsyncResultStreamer(null, null, 3));
ollamaAPI.generateAsync(model, prompt, false);
verify(ollamaAPI, times(1)).generateAsync(model, prompt, false);
}