Compare commits

..

4 Commits

28 changed files with 1698 additions and 2238 deletions

View File

@@ -1,41 +1,68 @@
# This workflow will build a package using Maven and then publish it to GitHub packages when a release is created # This workflow will build a package using Maven and then publish it to GitHub packages when a release is created
# For more information see: https://github.com/actions/setup-java/blob/main/docs/advanced-usage.md#apache-maven-with-a-settings-path # For more information see: https://github.com/actions/setup-java/blob/main/docs/advanced-usage.md#apache-maven-with-a-settings-path
name: Release Artifacts name: Test and Publish Package
#on:
# release:
# types: [ "created" ]
on: on:
release: push:
types: [ created ] branches: [ "main" ]
workflow_dispatch:
jobs: jobs:
build: build:
runs-on: ubuntu-latest runs-on: ubuntu-latest
permissions: permissions:
contents: read contents: write
packages: write packages: write
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- name: Set up JDK 17 - name: Set up JDK 11
uses: actions/setup-java@v3 uses: actions/setup-java@v3
with: with:
java-version: '17' java-version: '11'
distribution: 'temurin' distribution: 'adopt-hotspot'
server-id: github # Value of the distributionManagement/repository/id field of the pom.xml server-id: github # Value of the distributionManagement/repository/id field of the pom.xml
settings-path: ${{ github.workspace }} # location for the settings.xml file settings-path: ${{ github.workspace }} # location for the settings.xml file
- name: Find and Replace
uses: jacobtomlinson/gha-find-replace@v3
with:
find: "ollama4j-revision"
replace: ${{ github.ref_name }}
regex: false
- name: Build with Maven - name: Build with Maven
run: mvn --file pom.xml -U clean package -Punit-tests run: mvn --file pom.xml -U clean package -Punit-tests
- name: Set up Apache Maven Central (Overwrite settings.xml)
- name: Publish to GitHub Packages Apache Maven uses: actions/setup-java@v3
run: mvn deploy -s $GITHUB_WORKSPACE/settings.xml --file pom.xml with: # running setup-java again overwrites the settings.xml
java-version: '11'
distribution: 'adopt-hotspot'
cache: 'maven'
server-id: ossrh
server-username: MAVEN_USERNAME
server-password: MAVEN_PASSWORD
gpg-private-key: ${{ secrets.GPG_PRIVATE_KEY }}
gpg-passphrase: MAVEN_GPG_PASSPHRASE
- name: Set up Maven cache
uses: actions/cache@v3
with:
path: ~/.m2/repository
key: ${{ runner.os }}-maven-${{ hashFiles('**/pom.xml') }}
restore-keys: |
${{ runner.os }}-maven-
- name: Build
run: mvn -B -ntp clean install
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v3
env: env:
GITHUB_TOKEN: ${{ github.token }} CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
- name: Publish to GitHub Packages Apache Maven
# if: >
# github.event_name != 'pull_request' &&
# github.ref_name == 'main' &&
# contains(github.event.head_commit.message, 'release')
run: |
git config --global user.email "koujalgi.amith@gmail.com"
git config --global user.name "amithkoujalgi"
mvn -B -ntp -DskipTests -Pci-cd -Darguments="-DskipTests -Pci-cd" release:clean release:prepare release:perform
env:
MAVEN_USERNAME: ${{ secrets.OSSRH_USERNAME }}
MAVEN_PASSWORD: ${{ secrets.OSSRH_PASSWORD }}
MAVEN_GPG_PASSPHRASE: ${{ secrets.GPG_PASSPHRASE }}

View File

@@ -2,8 +2,9 @@
name: Deploy Docs to GH Pages name: Deploy Docs to GH Pages
on: on:
release: # Runs on pushes targeting the default branch
types: [ created ] push:
branches: [ "main" ]
# Allows you to run this workflow manually from the Actions tab # Allows you to run this workflow manually from the Actions tab
workflow_dispatch: workflow_dispatch:
@@ -46,13 +47,6 @@ jobs:
- run: cd docs && npm ci - run: cd docs && npm ci
- run: cd docs && npm run build - run: cd docs && npm run build
- name: Find and Replace
uses: jacobtomlinson/gha-find-replace@v3
with:
find: "ollama4j-revision"
replace: ${{ github.ref_name }}
regex: false
- name: Build with Maven - name: Build with Maven
run: mvn --file pom.xml -U clean package && cp -r ./target/apidocs/. ./docs/build/apidocs run: mvn --file pom.xml -U clean package && cp -r ./target/apidocs/. ./docs/build/apidocs

52
.github/workflows/publish-javadoc.yml vendored Normal file
View File

@@ -0,0 +1,52 @@
# Simple workflow for deploying static content to GitHub Pages
name: Deploy Javadoc content to Pages
on:
# Runs on pushes targeting the default branch
push:
branches: [ "none" ]
# Allows you to run this workflow manually from the Actions tab
workflow_dispatch:
# Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages
permissions:
contents: read
pages: write
id-token: write
packages: write
# Allow only one concurrent deployment, skipping runs queued between the run in-progress and latest queued.
# However, do NOT cancel in-progress runs as we want to allow these production deployments to complete.
concurrency:
group: "pages"
cancel-in-progress: false
jobs:
# Single deploy job since we're just deploying
deploy:
runs-on: ubuntu-latest
environment:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
steps:
- uses: actions/checkout@v3
- name: Set up JDK 11
uses: actions/setup-java@v3
with:
java-version: '11'
distribution: 'adopt-hotspot'
server-id: github # Value of the distributionManagement/repository/id field of the pom.xml
settings-path: ${{ github.workspace }} # location for the settings.xml file
- name: Build with Maven
run: mvn --file pom.xml -U clean package
- name: Setup Pages
uses: actions/configure-pages@v3
- name: Upload artifact
uses: actions/upload-pages-artifact@v2
with:
# Upload entire repository
path: './target/apidocs/.'
- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v2

121
README.md
View File

@@ -1,5 +1,3 @@
<div style="text-align: center">
### Ollama4j ### Ollama4j
<img src='https://raw.githubusercontent.com/amithkoujalgi/ollama4j/65a9d526150da8fcd98e2af6a164f055572bf722/ollama4j.jpeg' width='100' alt="ollama4j-icon"> <img src='https://raw.githubusercontent.com/amithkoujalgi/ollama4j/65a9d526150da8fcd98e2af6a164f055572bf722/ollama4j.jpeg' width='100' alt="ollama4j-icon">
@@ -11,36 +9,23 @@ Find more details on the [website](https://amithkoujalgi.github.io/ollama4j/).
![GitHub stars](https://img.shields.io/github/stars/amithkoujalgi/ollama4j) ![GitHub stars](https://img.shields.io/github/stars/amithkoujalgi/ollama4j)
![GitHub forks](https://img.shields.io/github/forks/amithkoujalgi/ollama4j) ![GitHub forks](https://img.shields.io/github/forks/amithkoujalgi/ollama4j)
![GitHub watchers](https://img.shields.io/github/watchers/amithkoujalgi/ollama4j) ![GitHub watchers](https://img.shields.io/github/watchers/amithkoujalgi/ollama4j)
![Contributors](https://img.shields.io/github/contributors/amithkoujalgi/ollama4j)
![GitHub License](https://img.shields.io/github/license/amithkoujalgi/ollama4j)
![GitHub repo size](https://img.shields.io/github/repo-size/amithkoujalgi/ollama4j) ![GitHub repo size](https://img.shields.io/github/repo-size/amithkoujalgi/ollama4j)
![GitHub language count](https://img.shields.io/github/languages/count/amithkoujalgi/ollama4j)
![GitHub top language](https://img.shields.io/github/languages/top/amithkoujalgi/ollama4j) ![GitHub top language](https://img.shields.io/github/languages/top/amithkoujalgi/ollama4j)
![GitHub last commit](https://img.shields.io/github/last-commit/amithkoujalgi/ollama4j?color=green) ![GitHub last commit](https://img.shields.io/github/last-commit/amithkoujalgi/ollama4j?color=green)
![Hits](https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fgithub.com%2Famithkoujalgi%2Follama4j&count_bg=%2379C83D&title_bg=%23555555&icon=&icon_color=%23E7E7E7&title=hits&edge_flat=false)
[![codecov](https://codecov.io/gh/amithkoujalgi/ollama4j/graph/badge.svg?token=U0TE7BGP8L)](https://codecov.io/gh/amithkoujalgi/ollama4j) [![codecov](https://codecov.io/gh/amithkoujalgi/ollama4j/graph/badge.svg?token=U0TE7BGP8L)](https://codecov.io/gh/amithkoujalgi/ollama4j)
![GitHub Issues or Pull Requests](https://img.shields.io/github/issues-raw/amithkoujalgi/ollama4j)
![GitHub Issues or Pull Requests](https://img.shields.io/github/issues-closed-raw/amithkoujalgi/ollama4j)
![GitHub Issues or Pull Requests](https://img.shields.io/github/issues-pr-raw/amithkoujalgi/ollama4j)
![GitHub Issues or Pull Requests](https://img.shields.io/github/issues-pr-closed-raw/amithkoujalgi/ollama4j)
![GitHub Discussions](https://img.shields.io/github/discussions/amithkoujalgi/ollama4j)
![Build Status](https://github.com/amithkoujalgi/ollama4j/actions/workflows/maven-publish.yml/badge.svg) ![Build Status](https://github.com/amithkoujalgi/ollama4j/actions/workflows/maven-publish.yml/badge.svg)
</div>
[//]: # (![Hits]&#40;https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fgithub.com%2Famithkoujalgi%2Follama4j&count_bg=%2379C83D&title_bg=%23555555&icon=&icon_color=%23E7E7E7&title=hits&edge_flat=false&#41;)
[//]: # (![GitHub language count]&#40;https://img.shields.io/github/languages/count/amithkoujalgi/ollama4j&#41;)
## Table of Contents ## Table of Contents
- [How does it work?](#how-does-it-work) - [How does it work?](#how-does-it-work)
- [Requirements](#requirements) - [Requirements](#requirements)
- [Installation](#installation) - [Installation](#installation)
- [API Spec](https://amithkoujalgi.github.io/ollama4j/docs/category/apis---model-management) - [API Spec](#api-spec)
- [Javadoc](https://amithkoujalgi.github.io/ollama4j/apidocs/) - [Demo APIs](#try-out-the-apis-with-ollama-server)
- [Development](#development) - [Development](#development)
- [Contributions](#get-involved) - [Contributions](#get-involved)
- [References](#references) - [References](#references)
@@ -63,9 +48,9 @@ Find more details on the [website](https://amithkoujalgi.github.io/ollama4j/).
![Java](https://img.shields.io/badge/Java-11_+-green.svg?style=just-the-message&labelColor=gray) ![Java](https://img.shields.io/badge/Java-11_+-green.svg?style=just-the-message&labelColor=gray)
[![][ollama-shield]][ollama-link] **Or** [![][ollama-docker-shield]][ollama-docker] [![][ollama-shield]][ollama] **Or** [![][ollama-docker-shield]][ollama-docker]
[ollama-link]: https://ollama.ai/ [ollama]: https://ollama.ai/
[ollama-shield]: https://img.shields.io/badge/Ollama-Local_Installation-blue.svg?style=just-the-message&labelColor=gray [ollama-shield]: https://img.shields.io/badge/Ollama-Local_Installation-blue.svg?style=just-the-message&labelColor=gray
@@ -75,73 +60,25 @@ Find more details on the [website](https://amithkoujalgi.github.io/ollama4j/).
#### Installation #### Installation
Check the releases [here](https://github.com/amithkoujalgi/ollama4j/releases) and update the dependency version In your Maven project, add this dependency:
according to your requirements.
[![][ollama4j-releases-shield]][ollama4j-releases-link]
[ollama4j-releases-link]: https://github.com/amithkoujalgi/ollama4j/releases
[ollama4j-releases-shield]: https://img.shields.io/github/v/release/amithkoujalgi/ollama4j?include_prereleases&display_name=release&style=for-the-badge&label=Latest%20Release
##### For Maven
1. In your Maven project, add this dependency:
```xml ```xml
<dependency> <dependency>
<groupId>io.github.amithkoujalgi</groupId> <groupId>io.github.amithkoujalgi</groupId>
<artifactId>ollama4j</artifactId> <artifactId>ollama4j</artifactId>
<version>1.0.74</version> <version>1.0.70</version>
</dependency> </dependency>
``` ```
2. Add repository to your project's pom.xml: or
```xml
<repositories>
<repository>
<id>github</id>
<name>GitHub Apache Maven Packages</name>
<url>https://maven.pkg.github.com/amithkoujalgi/ollama4j</url>
<releases>
<enabled>true</enabled>
</releases>
<snapshots>
<enabled>true</enabled>
</snapshots>
</repository>
</repositories>
```
3. Add GitHub server to settings.xml. (Usually available at ~/.m2/settings.xml)
```xml
<settings xmlns="http://maven.apache.org/SETTINGS/1.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/SETTINGS/1.0.0
http://maven.apache.org/xsd/settings-1.0.0.xsd">
<servers>
<server>
<id>github</id>
<username>YOUR-USERNAME</username>
<password>YOUR-TOKEN</password>
</server>
</servers>
</settings>
```
##### For Gradle
In your Gradle project, add the dependency using the Kotlin DSL or the Groovy DSL: In your Gradle project, add the dependency using the Kotlin DSL or the Groovy DSL:
```kotlin ```kotlin
dependencies { dependencies {
val ollama4jVersion = "1.0.74" val ollama4jVersion = "1.0.70"
implementation("io.github.amithkoujalgi:ollama4j:$ollama4jVersion") implementation("io.github.amithkoujalgi:ollama4j:$ollama4jVersion")
} }
@@ -149,19 +86,15 @@ dependencies {
```groovy ```groovy
dependencies { dependencies {
implementation("io.github.amithkoujalgi:ollama4j:1.0.74") implementation("io.github.amithkoujalgi:ollama4j:1.0.70")
} }
``` ```
[//]: # (Latest release:) Latest release:
[//]: # () ![Maven Central](https://img.shields.io/maven-central/v/io.github.amithkoujalgi/ollama4j)
[//]: # (![Maven Central]&#40;https://img.shields.io/maven-central/v/io.github.amithkoujalgi/ollama4j&#41;) [![][lib-shield]][lib]
[//]: # ()
[//]: # ([![][lib-shield]][lib])
[lib]: https://central.sonatype.com/artifact/io.github.amithkoujalgi/ollama4j [lib]: https://central.sonatype.com/artifact/io.github.amithkoujalgi/ollama4j
@@ -220,9 +153,6 @@ Actions CI workflow.
- [x] Use lombok - [x] Use lombok
- [x] Update request body creation with Java objects - [x] Update request body creation with Java objects
- [ ] Async APIs for images - [ ] Async APIs for images
- [ ] Support for function calling with models like Mistral
- [x] generate in sync mode
- [ ] generate in async mode
- [ ] Add custom headers to requests - [ ] Add custom headers to requests
- [x] Add additional params for `ask` APIs such as: - [x] Add additional params for `ask` APIs such as:
- [x] `options`: additional model parameters for the Modelfile such as `temperature` - - [x] `options`: additional model parameters for the Modelfile such as `temperature` -
@@ -242,28 +172,11 @@ Contributions are most welcome! Whether it's reporting a bug, proposing an enhan
with code - any sort with code - any sort
of contribution is much appreciated. of contribution is much appreciated.
### References
- [Ollama REST APIs](https://github.com/jmorganca/ollama/blob/main/docs/api.md)
### Credits ### Credits
The nomenclature and the icon have been adopted from the incredible [Ollama](https://ollama.ai/) The nomenclature and the icon have been adopted from the incredible [Ollama](https://ollama.ai/)
project. project.
### References
<div style="text-align: center"> - [Ollama REST APIs](https://github.com/jmorganca/ollama/blob/main/docs/api.md)
**Thanks to the amazing contributors**
<a href="https://github.com/amithkoujalgi/ollama4j/graphs/contributors">
<img src="https://contrib.rocks/image?repo=amithkoujalgi/ollama4j" />
</a>
### Appreciate my work?
<a href="https://www.buymeacoffee.com/amithkoujalgi" target="_blank"><img src="https://cdn.buymeacoffee.com/buttons/v2/default-yellow.png" alt="Buy Me A Coffee" style="height: 60px !important;width: 217px !important;" ></a>
</div>

View File

@@ -1,5 +1,5 @@
--- ---
sidebar_position: 3 sidebar_position: 2
--- ---
# Generate - Async # Generate - Async

View File

@@ -1,5 +1,5 @@
--- ---
sidebar_position: 4 sidebar_position: 3
--- ---
# Generate - With Image Files # Generate - With Image Files

View File

@@ -1,5 +1,5 @@
--- ---
sidebar_position: 5 sidebar_position: 4
--- ---
# Generate - With Image URLs # Generate - With Image URLs

View File

@@ -1,271 +0,0 @@
---
sidebar_position: 2
---
# Generate - With Tools
This API lets you perform [function calling](https://docs.mistral.ai/capabilities/function_calling/) using LLMs in a
synchronous way.
This API correlates to
the [generate](https://github.com/ollama/ollama/blob/main/docs/api.md#request-raw-mode) API with `raw` mode.
:::note
This is an only an experimental implementation and has a very basic design.
Currently, built and tested for [Mistral's latest model](https://ollama.com/library/mistral) only. We could redesign
this
in the future if tooling is supported for more models with a generic interaction standard from Ollama.
:::
### Function Calling/Tools
Assume you want to call a method in your code based on the response generated from the model.
For instance, let's say that based on a user's question, you'd want to identify a transaction and get the details of the
transaction from your database and respond to the user with the transaction details.
You could do that with ease with the `function calling` capabilities of the models by registering your `tools`.
### Create Functions
This function takes the arguments `location` and `fuelType` and performs an operation with these arguments and returns a
value.
```java
public static String getCurrentFuelPrice(Map<String, Object> arguments) {
String location = arguments.get("location").toString();
String fuelType = arguments.get("fuelType").toString();
return "Current price of " + fuelType + " in " + location + " is Rs.103/L";
}
```
This function takes the argument `city` and performs an operation with the argument and returns a
value.
```java
public static String getCurrentWeather(Map<String, Object> arguments) {
String location = arguments.get("city").toString();
return "Currently " + location + "'s weather is nice.";
}
```
### Define Tool Specifications
Lets define a sample tool specification called **Fuel Price Tool** for getting the current fuel price.
- Specify the function `name`, `description`, and `required` properties (`location` and `fuelType`).
- Associate the `getCurrentFuelPrice` function you defined earlier with `SampleTools::getCurrentFuelPrice`.
```java
MistralTools.ToolSpecification fuelPriceToolSpecification = MistralTools.ToolSpecification.builder()
.functionName("current-fuel-price")
.functionDesc("Get current fuel price")
.props(
new MistralTools.PropsBuilder()
.withProperty("location", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The city, e.g. New Delhi, India").required(true).build())
.withProperty("fuelType", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The fuel type.").enumValues(Arrays.asList("petrol", "diesel")).required(true).build())
.build()
)
.toolDefinition(SampleTools::getCurrentFuelPrice)
.build();
```
Lets also define a sample tool specification called **Weather Tool** for getting the current weather.
- Specify the function `name`, `description`, and `required` property (`city`).
- Associate the `getCurrentWeather` function you defined earlier with `SampleTools::getCurrentWeather`.
```java
MistralTools.ToolSpecification weatherToolSpecification = MistralTools.ToolSpecification.builder()
.functionName("current-weather")
.functionDesc("Get current weather")
.props(
new MistralTools.PropsBuilder()
.withProperty("city", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The city, e.g. New Delhi, India").required(true).build())
.build()
)
.toolDefinition(SampleTools::getCurrentWeather)
.build();
```
### Register the Tools
Register the defined tools (`fuel price` and `weather`) with the OllamaAPI.
```shell
ollamaAPI.registerTool(fuelPriceToolSpecification);
ollamaAPI.registerTool(weatherToolSpecification);
```
### Create prompt with Tools
`Prompt 1`: Create a prompt asking for the petrol price in Bengaluru using the defined fuel price and weather tools.
```shell
String prompt1 = new MistralTools.PromptBuilder()
.withToolSpecification(fuelPriceToolSpecification)
.withToolSpecification(weatherToolSpecification)
.withPrompt("What is the petrol price in Bengaluru?")
.build();
OllamaToolsResult toolsResult = ollamaAPI.generateWithTools(model, prompt1, false, new OptionsBuilder().build());
for (Map.Entry<ToolDef, Object> r : toolsResult.getToolResults().entrySet()) {
System.out.printf("[Response from tool '%s']: %s%n", r.getKey().getName(), r.getValue().toString());
}
```
Now, fire away your question to the model.
You will get a response similar to:
::::tip[LLM Response]
[Response from tool 'current-fuel-price']: Current price of petrol in Bengaluru is Rs.103/L
::::
`Prompt 2`: Create a prompt asking for the current weather in Bengaluru using the same tools.
```shell
String prompt2 = new MistralTools.PromptBuilder()
.withToolSpecification(fuelPriceToolSpecification)
.withToolSpecification(weatherToolSpecification)
.withPrompt("What is the current weather in Bengaluru?")
.build();
OllamaToolsResult toolsResult = ollamaAPI.generateWithTools(model, prompt2, false, new OptionsBuilder().build());
for (Map.Entry<ToolDef, Object> r : toolsResult.getToolResults().entrySet()) {
System.out.printf("[Response from tool '%s']: %s%n", r.getKey().getName(), r.getValue().toString());
}
```
Again, fire away your question to the model.
You will get a response similar to:
::::tip[LLM Response]
[Response from tool 'current-weather']: Currently Bengaluru's weather is nice
::::
### Full Example
```java
import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.tools.ToolDef;
import io.github.amithkoujalgi.ollama4j.core.tools.MistralTools;
import io.github.amithkoujalgi.ollama4j.core.tools.OllamaToolsResult;
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
import java.io.IOException;
import java.util.Arrays;
import java.util.Map;
public class FunctionCallingWithMistral {
public static void main(String[] args) throws Exception {
String host = "http://localhost:11434/";
OllamaAPI ollamaAPI = new OllamaAPI(host);
ollamaAPI.setRequestTimeoutSeconds(60);
String model = "mistral";
MistralTools.ToolSpecification fuelPriceToolSpecification = MistralTools.ToolSpecification.builder()
.functionName("current-fuel-price")
.functionDesc("Get current fuel price")
.props(
new MistralTools.PropsBuilder()
.withProperty("location", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The city, e.g. New Delhi, India").required(true).build())
.withProperty("fuelType", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The fuel type.").enumValues(Arrays.asList("petrol", "diesel")).required(true).build())
.build()
)
.toolDefinition(SampleTools::getCurrentFuelPrice)
.build();
MistralTools.ToolSpecification weatherToolSpecification = MistralTools.ToolSpecification.builder()
.functionName("current-weather")
.functionDesc("Get current weather")
.props(
new MistralTools.PropsBuilder()
.withProperty("city", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The city, e.g. New Delhi, India").required(true).build())
.build()
)
.toolDefinition(SampleTools::getCurrentWeather)
.build();
ollamaAPI.registerTool(fuelPriceToolSpecification);
ollamaAPI.registerTool(weatherToolSpecification);
String prompt1 = new MistralTools.PromptBuilder()
.withToolSpecification(fuelPriceToolSpecification)
.withToolSpecification(weatherToolSpecification)
.withPrompt("What is the petrol price in Bengaluru?")
.build();
String prompt2 = new MistralTools.PromptBuilder()
.withToolSpecification(fuelPriceToolSpecification)
.withToolSpecification(weatherToolSpecification)
.withPrompt("What is the current weather in Bengaluru?")
.build();
ask(ollamaAPI, model, prompt1);
ask(ollamaAPI, model, prompt2);
}
public static void ask(OllamaAPI ollamaAPI, String model, String prompt) throws OllamaBaseException, IOException, InterruptedException {
OllamaToolsResult toolsResult = ollamaAPI.generateWithTools(model, prompt, false, new OptionsBuilder().build());
for (Map.Entry<ToolDef, Object> r : toolsResult.getToolResults().entrySet()) {
System.out.printf("[Response from tool '%s']: %s%n", r.getKey().getName(), r.getValue().toString());
}
}
}
class SampleTools {
public static String getCurrentFuelPrice(Map<String, Object> arguments) {
String location = arguments.get("location").toString();
String fuelType = arguments.get("fuelType").toString();
return "Current price of " + fuelType + " in " + location + " is Rs.103/L";
}
public static String getCurrentWeather(Map<String, Object> arguments) {
String location = arguments.get("city").toString();
return "Currently " + location + "'s weather is nice.";
}
}
```
Run this full example and you will get a response similar to:
::::tip[LLM Response]
[Response from tool 'current-fuel-price']: Current price of petrol in Bengaluru is Rs.103/L
[Response from tool 'current-weather']: Currently Bengaluru's weather is nice
::::
### Room for improvement
Instead of explicitly registering `ollamaAPI.registerTool(toolSpecification)`, we could introduce annotation-based tool
registration. For example:
```java
@ToolSpec(name = "current-fuel-price", desc = "Get current fuel price")
public String getCurrentFuelPrice(Map<String, Object> arguments) {
String location = arguments.get("location").toString();
String fuelType = arguments.get("fuelType").toString();
return "Current price of " + fuelType + " in " + location + " is Rs.103/L";
}
```
Instead of passing a map of args `Map<String, Object> arguments` to the tool functions, we could support passing
specific args separately with their data types. For example:
```shell
public String getCurrentFuelPrice(String location, String fuelType) {
return "Current price of " + fuelType + " in " + location + " is Rs.103/L";
}
```
Updating async/chat APIs with support for tool-based generation.

View File

@@ -11,7 +11,7 @@ the [completion](https://github.com/jmorganca/ollama/blob/main/docs/api.md#gener
Use the `OptionBuilder` to build the `Options` object Use the `OptionBuilder` to build the `Options` object
with [extra parameters](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values). with [extra parameters](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
Refer Refer
to [this](/apis-extras/options-builder). to [this](/docs/apis-extras/options-builder).
## Try asking a question about the model. ## Try asking a question about the model.
@@ -53,26 +53,25 @@ public class Main {
OllamaAPI ollamaAPI = new OllamaAPI(host); OllamaAPI ollamaAPI = new OllamaAPI(host);
// define a stream handler (Consumer<String>) // define a stream handler (Consumer<String>)
OllamaStreamHandler streamHandler = (s) -> { OllamaStreamHandler streamHandler = (s) -> {
System.out.println(s); System.out.println(s);
}; };
// Should be called using seperate thread to gain non blocking streaming effect. // Should be called using seperate thread to gain non blocking streaming effect.
OllamaResult result = ollamaAPI.generate(config.getModel(), OllamaResult result = ollamaAPI.generate(config.getModel(),
"What is the capital of France? And what's France's connection with Mona Lisa?", "What is the capital of France? And what's France's connection with Mona Lisa?",
new OptionsBuilder().build(), streamHandler); new OptionsBuilder().build(), streamHandler);
System.out.println("Full response: " + result.getResponse()); System.out.println("Full response: " +result.getResponse());
} }
} }
``` ```
You will get a response similar to: You will get a response similar to:
> The > The
> The capital > The capital
> The capital of > The capital of
> The capital of France > The capital of France
> The capital of France is > The capital of France is
> The capital of France is Paris > The capital of France is Paris
> The capital of France is Paris. > The capital of France is Paris.
> Full response: The capital of France is Paris. > Full response: The capital of France is Paris.

View File

@@ -1,5 +1,5 @@
--- ---
sidebar_position: 6 sidebar_position: 5
--- ---
# Prompt Builder # Prompt Builder

View File

@@ -40,8 +40,6 @@ const config = {
/** @type {import('@docusaurus/preset-classic').Options} */ /** @type {import('@docusaurus/preset-classic').Options} */
({ ({
docs: { docs: {
path: 'docs',
routeBasePath: '', // change this to any URL route you'd want. For example: `home` - if you want /home/intro.
sidebarPath: './sidebars.js', sidebarPath: './sidebars.js',
// Please change this to your repo. // Please change this to your repo.
// Remove this to remove the "edit this page" links. // Remove this to remove the "edit this page" links.
@@ -98,7 +96,7 @@ const config = {
items: [ items: [
{ {
label: 'Tutorial', label: 'Tutorial',
to: '/intro', to: '/docs/intro',
}, },
], ],
}, },

1951
docs/package-lock.json generated

File diff suppressed because it is too large Load Diff

View File

@@ -14,9 +14,9 @@
"write-heading-ids": "docusaurus write-heading-ids" "write-heading-ids": "docusaurus write-heading-ids"
}, },
"dependencies": { "dependencies": {
"@docusaurus/core": "^3.4.0", "@docusaurus/core": "3.0.1",
"@docusaurus/preset-classic": "^3.4.0", "@docusaurus/preset-classic": "3.0.1",
"@docusaurus/theme-mermaid": "^3.4.0", "@docusaurus/theme-mermaid": "^3.0.1",
"@mdx-js/react": "^3.0.0", "@mdx-js/react": "^3.0.0",
"clsx": "^2.0.0", "clsx": "^2.0.0",
"prism-react-renderer": "^2.3.0", "prism-react-renderer": "^2.3.0",
@@ -24,8 +24,8 @@
"react-dom": "^18.0.0" "react-dom": "^18.0.0"
}, },
"devDependencies": { "devDependencies": {
"@docusaurus/module-type-aliases": "^3.4.0", "@docusaurus/module-type-aliases": "3.0.1",
"@docusaurus/types": "^3.4.0" "@docusaurus/types": "3.0.1"
}, },
"browserslist": { "browserslist": {
"production": [ "production": [

View File

@@ -19,7 +19,7 @@ function HomepageHeader() {
<div className={styles.buttons}> <div className={styles.buttons}>
<Link <Link
className="button button--secondary button--lg" className="button button--secondary button--lg"
to="/intro"> to="/docs/intro">
Getting Started Getting Started
</Link> </Link>
</div> </div>

View File

@@ -1,68 +0,0 @@
## This workflow will build a package using Maven and then publish it to GitHub packages when a release is created
## For more information see: https://github.com/actions/setup-java/blob/main/docs/advanced-usage.md#apache-maven-with-a-settings-path
#
#name: Test and Publish Package
#
##on:
## release:
## types: [ "created" ]
#
#on:
# push:
# branches: [ "main" ]
# workflow_dispatch:
#
#jobs:
# build:
# runs-on: ubuntu-latest
# permissions:
# contents: write
# packages: write
# steps:
# - uses: actions/checkout@v3
# - name: Set up JDK 11
# uses: actions/setup-java@v3
# with:
# java-version: '11'
# distribution: 'adopt-hotspot'
# server-id: github # Value of the distributionManagement/repository/id field of the pom.xml
# settings-path: ${{ github.workspace }} # location for the settings.xml file
# - name: Build with Maven
# run: mvn --file pom.xml -U clean package -Punit-tests
# - name: Set up Apache Maven Central (Overwrite settings.xml)
# uses: actions/setup-java@v3
# with: # running setup-java again overwrites the settings.xml
# java-version: '11'
# distribution: 'adopt-hotspot'
# cache: 'maven'
# server-id: ossrh
# server-username: MAVEN_USERNAME
# server-password: MAVEN_PASSWORD
# gpg-private-key: ${{ secrets.GPG_PRIVATE_KEY }}
# gpg-passphrase: MAVEN_GPG_PASSPHRASE
# - name: Set up Maven cache
# uses: actions/cache@v3
# with:
# path: ~/.m2/repository
# key: ${{ runner.os }}-maven-${{ hashFiles('**/pom.xml') }}
# restore-keys: |
# ${{ runner.os }}-maven-
# - name: Build
# run: mvn -B -ntp clean install
# - name: Upload coverage reports to Codecov
# uses: codecov/codecov-action@v3
# env:
# CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
# - name: Publish to GitHub Packages Apache Maven
# # if: >
# # github.event_name != 'pull_request' &&
# # github.ref_name == 'main' &&
# # contains(github.event.head_commit.message, 'release')
# run: |
# git config --global user.email "koujalgi.amith@gmail.com"
# git config --global user.name "amithkoujalgi"
# mvn -B -ntp -DskipTests -Pci-cd -Darguments="-DskipTests -Pci-cd" release:clean release:prepare release:perform
# env:
# MAVEN_USERNAME: ${{ secrets.OSSRH_USERNAME }}
# MAVEN_PASSWORD: ${{ secrets.OSSRH_PASSWORD }}
# MAVEN_GPG_PASSPHRASE: ${{ secrets.GPG_PASSPHRASE }}

122
pom.xml
View File

@@ -1,16 +1,14 @@
<?xml version="1.0" encoding="UTF-8"?> <?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion> <modelVersion>4.0.0</modelVersion>
<groupId>io.github.amithkoujalgi</groupId> <groupId>io.github.amithkoujalgi</groupId>
<artifactId>ollama4j</artifactId> <artifactId>ollama4j</artifactId>
<version>ollama4j-revision</version> <version>1.0.72-1</version>
<name>Ollama4j</name> <name>Ollama4j</name>
<description>Java library for interacting with Ollama API.</description> <description>Java library for interacting with Ollama API.</description>
<url>https://github.com/amithkoujalgi/ollama4j</url> <url>https://github.com/amithkoujalgi/ollama4j</url>
<packaging>jar</packaging>
<properties> <properties>
<maven.compiler.source>11</maven.compiler.source> <maven.compiler.source>11</maven.compiler.source>
@@ -129,18 +127,25 @@
</execution> </execution>
</executions> </executions>
</plugin> </plugin>
<!-- <plugin>--> <plugin>
<!-- <groupId>org.apache.maven.plugins</groupId>--> <groupId>org.apache.maven.plugins</groupId>
<!-- <artifactId>maven-release-plugin</artifactId>--> <artifactId>maven-release-plugin</artifactId>
<!-- <version>3.0.1</version>--> <version>3.0.1</version>
<!-- <configuration>--> <configuration>
<!-- &lt;!&ndash; <goals>install</goals>&ndash;&gt;--> <!-- <goals>install</goals>-->
<!-- <tagNameFormat>v@{project.version}</tagNameFormat>--> <tagNameFormat>v@{project.version}</tagNameFormat>
<!-- </configuration>--> </configuration>
<!-- </plugin>--> </plugin>
</plugins> </plugins>
</build> </build>
<repositories>
<repository>
<id>gitea</id>
<url>https://gitea.seeseepuff.be/api/packages/seeseemelk/maven</url>
</repository>
</repositories>
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>org.projectlombok</groupId> <groupId>org.projectlombok</groupId>
@@ -161,7 +166,7 @@
<dependency> <dependency>
<groupId>ch.qos.logback</groupId> <groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId> <artifactId>logback-classic</artifactId>
<version>1.5.6</version> <version>1.4.12</version>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency> <dependency>
@@ -189,23 +194,14 @@
</dependency> </dependency>
</dependencies> </dependencies>
<!-- <distributionManagement>-->
<!-- <snapshotRepository>-->
<!-- <id>ossrh</id>-->
<!-- <url>https://s01.oss.sonatype.org/content/repositories/snapshots</url>-->
<!-- </snapshotRepository>-->
<!-- <repository>-->
<!-- <id>ossrh</id>-->
<!-- <url>https://s01.oss.sonatype.org/service/local/staging/deploy/maven2</url>-->
<!-- </repository>-->
<!-- </distributionManagement>-->
<!-- Replaced publishing packages to GitHub Packages instead of Maven central -->
<distributionManagement> <distributionManagement>
<snapshotRepository>
<id>gitea</id>
<url>https://gitea.seeseepuff.be/api/packages/seeseemelk/maven</url>
</snapshotRepository>
<repository> <repository>
<id>github</id> <id>gitea</id>
<name>GitHub Packages</name> <url>https://gitea.seeseepuff.be/api/packages/seeseemelk/maven</url>
<url>https://maven.pkg.github.com/amithkoujalgi/ollama4j</url>
</repository> </repository>
</distributionManagement> </distributionManagement>
@@ -261,39 +257,39 @@
</properties> </properties>
<build> <build>
<plugins> <plugins>
<!-- <plugin>--> <plugin>
<!-- <groupId>org.apache.maven.plugins</groupId>--> <groupId>org.apache.maven.plugins</groupId>
<!-- <artifactId>maven-gpg-plugin</artifactId>--> <artifactId>maven-gpg-plugin</artifactId>
<!-- <version>3.1.0</version>--> <version>3.1.0</version>
<!-- <executions>--> <executions>
<!-- <execution>--> <execution>
<!-- <id>sign-artifacts</id>--> <id>sign-artifacts</id>
<!-- <phase>verify</phase>--> <phase>verify</phase>
<!-- <goals>--> <goals>
<!-- <goal>sign</goal>--> <goal>sign</goal>
<!-- </goals>--> </goals>
<!-- <configuration>--> <configuration>
<!-- &lt;!&ndash; Prevent gpg from using pinentry programs. Fixes:--> <!-- Prevent gpg from using pinentry programs. Fixes:
<!-- gpg: signing failed: Inappropriate ioctl for device &ndash;&gt;--> gpg: signing failed: Inappropriate ioctl for device -->
<!-- <gpgArguments>--> <gpgArguments>
<!-- <arg>&#45;&#45;pinentry-mode</arg>--> <arg>--pinentry-mode</arg>
<!-- <arg>loopback</arg>--> <arg>loopback</arg>
<!-- </gpgArguments>--> </gpgArguments>
<!-- </configuration>--> </configuration>
<!-- </execution>--> </execution>
<!-- </executions>--> </executions>
<!-- </plugin>--> </plugin>
<!-- <plugin>--> <plugin>
<!-- <groupId>org.sonatype.plugins</groupId>--> <groupId>org.sonatype.plugins</groupId>
<!-- <artifactId>nexus-staging-maven-plugin</artifactId>--> <artifactId>nexus-staging-maven-plugin</artifactId>
<!-- <version>1.6.13</version>--> <version>1.6.13</version>
<!-- <extensions>true</extensions>--> <extensions>true</extensions>
<!-- <configuration>--> <configuration>
<!-- <serverId>ossrh</serverId>--> <serverId>ossrh</serverId>
<!-- <nexusUrl>https://s01.oss.sonatype.org/</nexusUrl>--> <nexusUrl>https://s01.oss.sonatype.org/</nexusUrl>
<!-- <autoReleaseAfterClose>true</autoReleaseAfterClose>--> <autoReleaseAfterClose>true</autoReleaseAfterClose>
<!-- </configuration>--> </configuration>
<!-- </plugin>--> </plugin>
<plugin> <plugin>
<groupId>org.jacoco</groupId> <groupId>org.jacoco</groupId>
@@ -319,4 +315,4 @@
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@@ -10,7 +10,6 @@ import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingRe
import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestModel; import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel; import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.request.*; import io.github.amithkoujalgi.ollama4j.core.models.request.*;
import io.github.amithkoujalgi.ollama4j.core.tools.*;
import io.github.amithkoujalgi.ollama4j.core.utils.Options; import io.github.amithkoujalgi.ollama4j.core.utils.Options;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils; import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
import org.slf4j.Logger; import org.slf4j.Logger;
@@ -26,7 +25,9 @@ import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.nio.file.Files; import java.nio.file.Files;
import java.time.Duration; import java.time.Duration;
import java.util.*; import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
/** /**
* The base Ollama API class. * The base Ollama API class.
@@ -338,7 +339,6 @@ public class OllamaAPI {
} }
} }
/** /**
* Generate response for a question to a model running on Ollama server. This is a sync/blocking * Generate response for a question to a model running on Ollama server. This is a sync/blocking
* call. * call.
@@ -351,10 +351,9 @@ public class OllamaAPI {
* @param streamHandler optional callback consumer that will be applied every time a streamed response is received. If not set, the stream parameter of the request is set to false. * @param streamHandler optional callback consumer that will be applied every time a streamed response is received. If not set, the stream parameter of the request is set to false.
* @return OllamaResult that includes response text and time taken for response * @return OllamaResult that includes response text and time taken for response
*/ */
public OllamaResult generate(String model, String prompt, boolean raw, Options options, OllamaStreamHandler streamHandler) public OllamaResult generate(String model, String prompt, Options options, OllamaStreamHandler streamHandler)
throws OllamaBaseException, IOException, InterruptedException { throws OllamaBaseException, IOException, InterruptedException {
OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt); OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt);
ollamaRequestModel.setRaw(raw);
ollamaRequestModel.setOptions(options.getOptionsMap()); ollamaRequestModel.setOptions(options.getOptionsMap());
return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler); return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler);
} }
@@ -362,37 +361,13 @@ public class OllamaAPI {
/** /**
* Convenience method to call Ollama API without streaming responses. * Convenience method to call Ollama API without streaming responses.
* <p> * <p>
* Uses {@link #generate(String, String, boolean, Options, OllamaStreamHandler)} * Uses {@link #generate(String, String, Options, OllamaStreamHandler)}
*
* @param model Model to use
* @param prompt Prompt text
* @param raw In some cases, you may wish to bypass the templating system and provide a full prompt. In this case, you can use the raw parameter to disable templating. Also note that raw mode will not return a context.
* @param options Additional Options
* @return OllamaResult
*/ */
public OllamaResult generate(String model, String prompt, boolean raw, Options options) public OllamaResult generate(String model, String prompt, Options options)
throws OllamaBaseException, IOException, InterruptedException { throws OllamaBaseException, IOException, InterruptedException {
return generate(model, prompt, raw, options, null); return generate(model, prompt, options, null);
} }
public OllamaToolsResult generateWithTools(String model, String prompt, boolean raw, Options options)
throws OllamaBaseException, IOException, InterruptedException {
OllamaToolsResult toolResult = new OllamaToolsResult();
Map<ToolDef, Object> toolResults = new HashMap<>();
OllamaResult result = generate(model, prompt, raw, options, null);
toolResult.setModelResult(result);
List<ToolDef> toolDefs = Utils.getObjectMapper().readValue(result.getResponse(), Utils.getObjectMapper().getTypeFactory().constructCollectionType(List.class, ToolDef.class));
for (ToolDef toolDef : toolDefs) {
toolResults.put(toolDef, invokeTool(toolDef));
}
toolResult.setToolResults(toolResults);
return toolResult;
}
/** /**
* Generate response for a question to a model running on Ollama server and get a callback handle * Generate response for a question to a model running on Ollama server and get a callback handle
* that can be used to check for status and get the response from the model later. This would be * that can be used to check for status and get the response from the model later. This would be
@@ -402,9 +377,9 @@ public class OllamaAPI {
* @param prompt the prompt/question text * @param prompt the prompt/question text
* @return the ollama async result callback handle * @return the ollama async result callback handle
*/ */
public OllamaAsyncResultCallback generateAsync(String model, String prompt, boolean raw) { public OllamaAsyncResultCallback generateAsync(String model, String prompt) {
OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt); OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt);
ollamaRequestModel.setRaw(raw);
URI uri = URI.create(this.host + "/api/generate"); URI uri = URI.create(this.host + "/api/generate");
OllamaAsyncResultCallback ollamaAsyncResultCallback = OllamaAsyncResultCallback ollamaAsyncResultCallback =
new OllamaAsyncResultCallback( new OllamaAsyncResultCallback(
@@ -601,24 +576,4 @@ public class OllamaAPI {
private boolean isBasicAuthCredentialsSet() { private boolean isBasicAuthCredentialsSet() {
return basicAuth != null; return basicAuth != null;
} }
public void registerTool(MistralTools.ToolSpecification toolSpecification) {
ToolRegistry.addFunction(toolSpecification.getFunctionName(), toolSpecification.getToolDefinition());
}
private Object invokeTool(ToolDef toolDef) {
try {
String methodName = toolDef.getName();
Map<String, Object> arguments = toolDef.getArguments();
DynamicFunction function = ToolRegistry.getFunction(methodName);
if (function == null) {
throw new IllegalArgumentException("No such tool: " + methodName);
}
return function.apply(arguments);
} catch (Exception e) {
e.printStackTrace();
return "Error calling tool: " + e.getMessage();
}
}
} }

View File

@@ -16,16 +16,10 @@ public class OllamaChatResult extends OllamaResult{
List<OllamaChatMessage> chatHistory) { List<OllamaChatMessage> chatHistory) {
super(response, responseTime, httpStatusCode); super(response, responseTime, httpStatusCode);
this.chatHistory = chatHistory; this.chatHistory = chatHistory;
appendAnswerToChatHistory(response);
} }
public List<OllamaChatMessage> getChatHistory() { public List<OllamaChatMessage> getChatHistory() {
return chatHistory; return chatHistory;
}
private void appendAnswerToChatHistory(String answer){
OllamaChatMessage assistantMessage = new OllamaChatMessage(OllamaChatMessageRole.ASSISTANT, answer);
this.chatHistory.add(assistantMessage);
} }

View File

@@ -11,8 +11,6 @@ public class OllamaChatStreamObserver {
private List<OllamaChatResponseModel> responseParts = new ArrayList<>(); private List<OllamaChatResponseModel> responseParts = new ArrayList<>();
private String message = "";
public OllamaChatStreamObserver(OllamaStreamHandler streamHandler) { public OllamaChatStreamObserver(OllamaStreamHandler streamHandler) {
this.streamHandler = streamHandler; this.streamHandler = streamHandler;
} }
@@ -23,8 +21,7 @@ public class OllamaChatStreamObserver {
} }
protected void handleCurrentResponsePart(OllamaChatResponseModel currentResponsePart){ protected void handleCurrentResponsePart(OllamaChatResponseModel currentResponsePart){
message = message + currentResponsePart.getMessage().getContent(); streamHandler.accept(currentResponsePart.getMessage().getContent());
streamHandler.accept(message);
} }

View File

@@ -1,5 +1,9 @@
package io.github.amithkoujalgi.ollama4j.core.models.request; package io.github.amithkoujalgi.ollama4j.core.models.request;
import java.io.IOException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler; import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
@@ -9,19 +13,15 @@ import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRespo
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateStreamObserver; import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateStreamObserver;
import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody; import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils; import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException; public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{
public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
private static final Logger LOG = LoggerFactory.getLogger(OllamaGenerateEndpointCaller.class); private static final Logger LOG = LoggerFactory.getLogger(OllamaGenerateEndpointCaller.class);
private OllamaGenerateStreamObserver streamObserver; private OllamaGenerateStreamObserver streamObserver;
public OllamaGenerateEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) { public OllamaGenerateEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) {
super(host, basicAuth, requestTimeoutSeconds, verbose); super(host, basicAuth, requestTimeoutSeconds, verbose);
} }
@Override @Override
@@ -31,22 +31,24 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
@Override @Override
protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) { protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) {
try { try {
OllamaGenerateResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class); OllamaGenerateResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class);
responseBuffer.append(ollamaResponseModel.getResponse()); responseBuffer.append(ollamaResponseModel.getResponse());
if (streamObserver != null) { if(streamObserver != null) {
streamObserver.notify(ollamaResponseModel); streamObserver.notify(ollamaResponseModel);
} }
return ollamaResponseModel.isDone(); return ollamaResponseModel.isDone();
} catch (JsonProcessingException e) { } catch (JsonProcessingException e) {
LOG.error("Error parsing the Ollama chat response!", e); LOG.error("Error parsing the Ollama chat response!",e);
return true; return true;
} }
} }
public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler) public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler)
throws OllamaBaseException, IOException, InterruptedException { throws OllamaBaseException, IOException, InterruptedException {
streamObserver = new OllamaGenerateStreamObserver(streamHandler); streamObserver = new OllamaGenerateStreamObserver(streamHandler);
return super.callSync(body); return super.callSync(body);
} }
} }

View File

@@ -1,8 +0,0 @@
package io.github.amithkoujalgi.ollama4j.core.tools;
import java.util.Map;
@FunctionalInterface
public interface DynamicFunction {
Object apply(Map<String, Object> arguments);
}

View File

@@ -1,139 +0,0 @@
package io.github.amithkoujalgi.ollama4j.core.tools;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
import lombok.Builder;
import lombok.Data;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class MistralTools {
@Data
@Builder
public static class ToolSpecification {
private String functionName;
private String functionDesc;
private Map<String, PromptFuncDefinition.Property> props;
private DynamicFunction toolDefinition;
}
@Data
@JsonIgnoreProperties(ignoreUnknown = true)
public static class PromptFuncDefinition {
private String type;
private PromptFuncSpec function;
@Data
public static class PromptFuncSpec {
private String name;
private String description;
private Parameters parameters;
}
@Data
public static class Parameters {
private String type;
private Map<String, Property> properties;
private List<String> required;
}
@Data
@Builder
public static class Property {
private String type;
private String description;
@JsonProperty("enum")
@JsonInclude(JsonInclude.Include.NON_NULL)
private List<String> enumValues;
@JsonIgnore
private boolean required;
}
}
public static class PropsBuilder {
private final Map<String, PromptFuncDefinition.Property> props = new HashMap<>();
public PropsBuilder withProperty(String key, PromptFuncDefinition.Property property) {
props.put(key, property);
return this;
}
public Map<String, PromptFuncDefinition.Property> build() {
return props;
}
}
public static class PromptBuilder {
private final List<PromptFuncDefinition> tools = new ArrayList<>();
private String promptText;
public String build() throws JsonProcessingException {
return "[AVAILABLE_TOOLS] " + Utils.getObjectMapper().writeValueAsString(tools) + "[/AVAILABLE_TOOLS][INST] " + promptText + " [/INST]";
}
public PromptBuilder withPrompt(String prompt) throws JsonProcessingException {
promptText = prompt;
return this;
}
public PromptBuilder withToolSpecification(ToolSpecification spec) {
PromptFuncDefinition def = new PromptFuncDefinition();
def.setType("function");
PromptFuncDefinition.PromptFuncSpec functionDetail = new PromptFuncDefinition.PromptFuncSpec();
functionDetail.setName(spec.getFunctionName());
functionDetail.setDescription(spec.getFunctionDesc());
PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters();
parameters.setType("object");
parameters.setProperties(spec.getProps());
List<String> requiredValues = new ArrayList<>();
for (Map.Entry<String, PromptFuncDefinition.Property> p : spec.getProps().entrySet()) {
if (p.getValue().isRequired()) {
requiredValues.add(p.getKey());
}
}
parameters.setRequired(requiredValues);
functionDetail.setParameters(parameters);
def.setFunction(functionDetail);
tools.add(def);
return this;
}
//
// public PromptBuilder withToolSpecification(String functionName, String functionDesc, Map<String, PromptFuncDefinition.Property> props) {
// PromptFuncDefinition def = new PromptFuncDefinition();
// def.setType("function");
//
// PromptFuncDefinition.PromptFuncSpec functionDetail = new PromptFuncDefinition.PromptFuncSpec();
// functionDetail.setName(functionName);
// functionDetail.setDescription(functionDesc);
//
// PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters();
// parameters.setType("object");
// parameters.setProperties(props);
//
// List<String> requiredValues = new ArrayList<>();
// for (Map.Entry<String, PromptFuncDefinition.Property> p : props.entrySet()) {
// if (p.getValue().isRequired()) {
// requiredValues.add(p.getKey());
// }
// }
// parameters.setRequired(requiredValues);
// functionDetail.setParameters(parameters);
// def.setFunction(functionDetail);
//
// tools.add(def);
// return this;
// }
}
}

View File

@@ -1,16 +0,0 @@
package io.github.amithkoujalgi.ollama4j.core.tools;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.Map;
@Data
@NoArgsConstructor
@AllArgsConstructor
public class OllamaToolsResult {
private OllamaResult modelResult;
private Map<ToolDef, Object> toolResults;
}

View File

@@ -1,18 +0,0 @@
package io.github.amithkoujalgi.ollama4j.core.tools;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.Map;
@Data
@AllArgsConstructor
@NoArgsConstructor
public class ToolDef {
private String name;
private Map<String, Object> arguments;
}

View File

@@ -1,17 +0,0 @@
package io.github.amithkoujalgi.ollama4j.core.tools;
import java.util.HashMap;
import java.util.Map;
public class ToolRegistry {
private static final Map<String, DynamicFunction> functionMap = new HashMap<>();
public static DynamicFunction getFunction(String name) {
return functionMap.get(name);
}
public static void addFunction(String name, DynamicFunction function) {
functionMap.put(name, function);
}
}

View File

@@ -9,9 +9,6 @@ package io.github.amithkoujalgi.ollama4j.core.types;
@SuppressWarnings("ALL") @SuppressWarnings("ALL")
public class OllamaModelType { public class OllamaModelType {
public static final String GEMMA = "gemma"; public static final String GEMMA = "gemma";
public static final String GEMMA2 = "gemma2";
public static final String LLAMA2 = "llama2"; public static final String LLAMA2 = "llama2";
public static final String LLAMA3 = "llama3"; public static final String LLAMA3 = "llama3";
public static final String MISTRAL = "mistral"; public static final String MISTRAL = "mistral";
@@ -33,8 +30,6 @@ public class OllamaModelType {
public static final String ZEPHYR = "zephyr"; public static final String ZEPHYR = "zephyr";
public static final String OPENHERMES = "openhermes"; public static final String OPENHERMES = "openhermes";
public static final String QWEN = "qwen"; public static final String QWEN = "qwen";
public static final String QWEN2 = "qwen2";
public static final String WIZARDCODER = "wizardcoder"; public static final String WIZARDCODER = "wizardcoder";
public static final String LLAMA2_CHINESE = "llama2-chinese"; public static final String LLAMA2_CHINESE = "llama2-chinese";
public static final String TINYLLAMA = "tinyllama"; public static final String TINYLLAMA = "tinyllama";
@@ -84,5 +79,4 @@ public class OllamaModelType {
public static final String NOTUS = "notus"; public static final String NOTUS = "notus";
public static final String DUCKDB_NSQL = "duckdb-nsql"; public static final String DUCKDB_NSQL = "duckdb-nsql";
public static final String ALL_MINILM = "all-minilm"; public static final String ALL_MINILM = "all-minilm";
public static final String CODESTRAL = "codestral";
} }

View File

@@ -1,5 +1,7 @@
package io.github.amithkoujalgi.ollama4j.integrationtests; package io.github.amithkoujalgi.ollama4j.integrationtests;
import static org.junit.jupiter.api.Assertions.*;
import io.github.amithkoujalgi.ollama4j.core.OllamaAPI; import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.ModelDetail; import io.github.amithkoujalgi.ollama4j.core.models.ModelDetail;
@@ -8,16 +10,9 @@ import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessageRole;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResult; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResult;
import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestBuilder;
import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestModel; import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestBuilder;
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder; import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
import lombok.Data;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Order;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
@@ -27,369 +22,372 @@ import java.net.http.HttpConnectTimeoutException;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
import java.util.Properties; import java.util.Properties;
import lombok.Data;
import static org.junit.jupiter.api.Assertions.*; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Order;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
class TestRealAPIs { class TestRealAPIs {
private static final Logger LOG = LoggerFactory.getLogger(TestRealAPIs.class); private static final Logger LOG = LoggerFactory.getLogger(TestRealAPIs.class);
OllamaAPI ollamaAPI; OllamaAPI ollamaAPI;
Config config; Config config;
private File getImageFileFromClasspath(String fileName) { private File getImageFileFromClasspath(String fileName) {
ClassLoader classLoader = getClass().getClassLoader(); ClassLoader classLoader = getClass().getClassLoader();
return new File(Objects.requireNonNull(classLoader.getResource(fileName)).getFile()); return new File(Objects.requireNonNull(classLoader.getResource(fileName)).getFile());
}
@BeforeEach
void setUp() {
config = new Config();
ollamaAPI = new OllamaAPI(config.getOllamaURL());
ollamaAPI.setRequestTimeoutSeconds(config.getRequestTimeoutSeconds());
}
@Test
@Order(1)
void testWrongEndpoint() {
OllamaAPI ollamaAPI = new OllamaAPI("http://wrong-host:11434");
assertThrows(ConnectException.class, ollamaAPI::listModels);
}
@Test
@Order(1)
void testEndpointReachability() {
try {
assertNotNull(ollamaAPI.listModels());
} catch (HttpConnectTimeoutException e) {
fail(e.getMessage());
} catch (Exception e) {
fail(e);
} }
}
@BeforeEach @Test
void setUp() { @Order(2)
config = new Config(); void testListModels() {
ollamaAPI = new OllamaAPI(config.getOllamaURL()); testEndpointReachability();
ollamaAPI.setRequestTimeoutSeconds(config.getRequestTimeoutSeconds()); try {
assertNotNull(ollamaAPI.listModels());
ollamaAPI.listModels().forEach(System.out::println);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
fail(e);
} }
}
@Test @Test
@Order(1) @Order(2)
void testWrongEndpoint() { void testPullModel() {
OllamaAPI ollamaAPI = new OllamaAPI("http://wrong-host:11434"); testEndpointReachability();
assertThrows(ConnectException.class, ollamaAPI::listModels); try {
ollamaAPI.pullModel(config.getModel());
boolean found =
ollamaAPI.listModels().stream()
.anyMatch(model -> model.getModel().equalsIgnoreCase(config.getModel()));
assertTrue(found);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
fail(e);
} }
}
@Test @Test
@Order(1) @Order(3)
void testEndpointReachability() { void testListDtails() {
try { testEndpointReachability();
assertNotNull(ollamaAPI.listModels()); try {
} catch (HttpConnectTimeoutException e) { ModelDetail modelDetails = ollamaAPI.getModelDetails(config.getModel());
fail(e.getMessage()); assertNotNull(modelDetails);
} catch (Exception e) { System.out.println(modelDetails);
fail(e); } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
} fail(e);
} }
}
@Test @Test
@Order(2) @Order(3)
void testListModels() { void testAskModelWithDefaultOptions() {
testEndpointReachability(); testEndpointReachability();
try { try {
assertNotNull(ollamaAPI.listModels()); OllamaResult result =
ollamaAPI.listModels().forEach(System.out::println); ollamaAPI.generate(
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { config.getModel(),
fail(e); "What is the capital of France? And what's France's connection with Mona Lisa?",
} new OptionsBuilder().build());
assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
} }
}
@Test @Test
@Order(2) @Order(3)
void testPullModel() { void testAskModelWithDefaultOptionsStreamed() {
testEndpointReachability(); testEndpointReachability();
try { try {
ollamaAPI.pullModel(config.getModel());
boolean found = StringBuffer sb = new StringBuffer("");
ollamaAPI.listModels().stream()
.anyMatch(model -> model.getModel().equalsIgnoreCase(config.getModel())); OllamaResult result = ollamaAPI.generate(config.getModel(),
assertTrue(found); "What is the capital of France? And what's France's connection with Mona Lisa?",
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { new OptionsBuilder().build(), (s) -> {
fail(e); LOG.info(s);
} String substring = s.substring(sb.toString().length(), s.length());
LOG.info(substring);
sb.append(substring);
});
assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
assertEquals(sb.toString().trim(), result.getResponse().trim());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
} }
}
@Test @Test
@Order(3) @Order(3)
void testListDtails() { void testAskModelWithOptions() {
testEndpointReachability(); testEndpointReachability();
try { try {
ModelDetail modelDetails = ollamaAPI.getModelDetails(config.getModel()); OllamaResult result =
assertNotNull(modelDetails); ollamaAPI.generate(
System.out.println(modelDetails); config.getModel(),
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { "What is the capital of France? And what's France's connection with Mona Lisa?",
fail(e); new OptionsBuilder().setTemperature(0.9f).build());
} assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
} }
}
@Test @Test
@Order(3) @Order(3)
void testAskModelWithDefaultOptions() { void testChat() {
testEndpointReachability(); testEndpointReachability();
try { try {
OllamaResult result = OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
ollamaAPI.generate( OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What is the capital of France?")
config.getModel(), .withMessage(OllamaChatMessageRole.ASSISTANT, "Should be Paris!")
"What is the capital of France? And what's France's connection with Mona Lisa?", .withMessage(OllamaChatMessageRole.USER,"And what is the second larges city?")
false, .build();
new OptionsBuilder().build());
assertNotNull(result); OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
assertNotNull(result.getResponse()); assertNotNull(chatResult);
assertFalse(result.getResponse().isEmpty()); assertFalse(chatResult.getResponse().isBlank());
} catch (IOException | OllamaBaseException | InterruptedException e) { assertEquals(4,chatResult.getChatHistory().size());
fail(e); } catch (IOException | OllamaBaseException | InterruptedException e) {
} fail(e);
} }
}
@Test @Test
@Order(3) @Order(3)
void testAskModelWithDefaultOptionsStreamed() { void testChatWithSystemPrompt() {
testEndpointReachability(); testEndpointReachability();
try { try {
StringBuffer sb = new StringBuffer(""); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
OllamaResult result = ollamaAPI.generate(config.getModel(), OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM,
"What is the capital of France? And what's France's connection with Mona Lisa?", "You are a silent bot that only says 'NI'. Do not say anything else under any circumstances!")
false, .withMessage(OllamaChatMessageRole.USER,
new OptionsBuilder().build(), (s) -> { "What is the capital of France? And what's France's connection with Mona Lisa?")
LOG.info(s); .build();
String substring = s.substring(sb.toString().length(), s.length());
LOG.info(substring);
sb.append(substring);
});
assertNotNull(result); OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
assertNotNull(result.getResponse()); assertNotNull(chatResult);
assertFalse(result.getResponse().isEmpty()); assertFalse(chatResult.getResponse().isBlank());
assertEquals(sb.toString().trim(), result.getResponse().trim()); assertTrue(chatResult.getResponse().startsWith("NI"));
} catch (IOException | OllamaBaseException | InterruptedException e) { assertEquals(3, chatResult.getChatHistory().size());
fail(e); } catch (IOException | OllamaBaseException | InterruptedException e) {
} fail(e);
} }
}
@Test @Test
@Order(3) @Order(3)
void testAskModelWithOptions() { void testChatWithStream() {
testEndpointReachability(); testEndpointReachability();
try { try {
OllamaResult result = OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
ollamaAPI.generate( OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER,
config.getModel(), "What is the capital of France? And what's France's connection with Mona Lisa?")
"What is the capital of France? And what's France's connection with Mona Lisa?", .build();
true,
new OptionsBuilder().setTemperature(0.9f).build()); StringBuffer sb = new StringBuffer("");
assertNotNull(result);
assertNotNull(result.getResponse()); OllamaChatResult chatResult = ollamaAPI.chat(requestModel,(s) -> {
assertFalse(result.getResponse().isEmpty()); LOG.info(s);
} catch (IOException | OllamaBaseException | InterruptedException e) { String substring = s.substring(sb.toString().length(), s.length());
fail(e); LOG.info(substring);
} sb.append(substring);
});
assertNotNull(chatResult);
assertEquals(sb.toString().trim(), chatResult.getResponse().trim());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
} }
}
@Test @Test
@Order(3) @Order(3)
void testChat() { void testChatWithImageFromFileWithHistoryRecognition() {
testEndpointReachability(); testEndpointReachability();
try { try {
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel()); OllamaChatRequestBuilder builder =
OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What is the capital of France?") OllamaChatRequestBuilder.getInstance(config.getImageModel());
.withMessage(OllamaChatMessageRole.ASSISTANT, "Should be Paris!") OllamaChatRequestModel requestModel =
.withMessage(OllamaChatMessageRole.USER, "And what is the second larges city?") builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
.build(); List.of(getImageFileFromClasspath("dog-on-a-boat.jpg"))).build();
OllamaChatResult chatResult = ollamaAPI.chat(requestModel); OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
assertNotNull(chatResult); assertNotNull(chatResult);
assertFalse(chatResult.getResponse().isBlank()); assertNotNull(chatResult.getResponse());
assertEquals(4, chatResult.getChatHistory().size());
} catch (IOException | OllamaBaseException | InterruptedException e) { builder.reset();
fail(e);
} requestModel =
builder.withMessages(chatResult.getChatHistory())
.withMessage(OllamaChatMessageRole.USER, "What's the dogs breed?").build();
chatResult = ollamaAPI.chat(requestModel);
assertNotNull(chatResult);
assertNotNull(chatResult.getResponse());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
} }
}
@Test @Test
@Order(3) @Order(3)
void testChatWithSystemPrompt() { void testChatWithImageFromURL() {
testEndpointReachability(); testEndpointReachability();
try { try {
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel()); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getImageModel());
OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
"You are a silent bot that only says 'NI'. Do not say anything else under any circumstances!") "https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg")
.withMessage(OllamaChatMessageRole.USER, .build();
"What is the capital of France? And what's France's connection with Mona Lisa?")
.build();
OllamaChatResult chatResult = ollamaAPI.chat(requestModel); OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
assertNotNull(chatResult); assertNotNull(chatResult);
assertFalse(chatResult.getResponse().isBlank()); } catch (IOException | OllamaBaseException | InterruptedException e) {
assertTrue(chatResult.getResponse().startsWith("NI")); fail(e);
assertEquals(3, chatResult.getChatHistory().size());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
}
} }
}
@Test @Test
@Order(3) @Order(3)
void testChatWithStream() { void testAskModelWithOptionsAndImageFiles() {
testEndpointReachability(); testEndpointReachability();
try { File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg");
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel()); try {
OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER, OllamaResult result =
"What is the capital of France? And what's France's connection with Mona Lisa?") ollamaAPI.generateWithImageFiles(
.build(); config.getImageModel(),
"What is in this image?",
StringBuffer sb = new StringBuffer(""); List.of(imageFile),
new OptionsBuilder().build());
OllamaChatResult chatResult = ollamaAPI.chat(requestModel, (s) -> { assertNotNull(result);
LOG.info(s); assertNotNull(result.getResponse());
String substring = s.substring(sb.toString().length(), s.length()); assertFalse(result.getResponse().isEmpty());
LOG.info(substring); } catch (IOException | OllamaBaseException | InterruptedException e) {
sb.append(substring); fail(e);
});
assertNotNull(chatResult);
assertEquals(sb.toString().trim(), chatResult.getResponse().trim());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
}
} }
}
@Test @Test
@Order(3) @Order(3)
void testChatWithImageFromFileWithHistoryRecognition() { void testAskModelWithOptionsAndImageFilesStreamed() {
testEndpointReachability(); testEndpointReachability();
try { File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg");
OllamaChatRequestBuilder builder = try {
OllamaChatRequestBuilder.getInstance(config.getImageModel()); StringBuffer sb = new StringBuffer("");
OllamaChatRequestModel requestModel =
builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
List.of(getImageFileFromClasspath("dog-on-a-boat.jpg"))).build();
OllamaChatResult chatResult = ollamaAPI.chat(requestModel); OllamaResult result = ollamaAPI.generateWithImageFiles(config.getImageModel(),
assertNotNull(chatResult); "What is in this image?", List.of(imageFile), new OptionsBuilder().build(), (s) -> {
assertNotNull(chatResult.getResponse()); LOG.info(s);
String substring = s.substring(sb.toString().length(), s.length());
builder.reset(); LOG.info(substring);
sb.append(substring);
requestModel = });
builder.withMessages(chatResult.getChatHistory()) assertNotNull(result);
.withMessage(OllamaChatMessageRole.USER, "What's the dogs breed?").build(); assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
chatResult = ollamaAPI.chat(requestModel); assertEquals(sb.toString().trim(), result.getResponse().trim());
assertNotNull(chatResult); } catch (IOException | OllamaBaseException | InterruptedException e) {
assertNotNull(chatResult.getResponse()); fail(e);
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
}
} }
}
@Test @Test
@Order(3) @Order(3)
void testChatWithImageFromURL() { void testAskModelWithOptionsAndImageURLs() {
testEndpointReachability(); testEndpointReachability();
try { try {
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getImageModel()); OllamaResult result =
OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?", ollamaAPI.generateWithImageURLs(
"https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg") config.getImageModel(),
.build(); "What is in this image?",
List.of(
OllamaChatResult chatResult = ollamaAPI.chat(requestModel); "https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg"),
assertNotNull(chatResult); new OptionsBuilder().build());
} catch (IOException | OllamaBaseException | InterruptedException e) { assertNotNull(result);
fail(e); assertNotNull(result.getResponse());
} assertFalse(result.getResponse().isEmpty());
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
fail(e);
} }
}
@Test @Test
@Order(3) @Order(3)
void testAskModelWithOptionsAndImageFiles() { public void testEmbedding() {
testEndpointReachability(); testEndpointReachability();
File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg"); try {
try { OllamaEmbeddingsRequestModel request = OllamaEmbeddingsRequestBuilder
OllamaResult result = .getInstance(config.getModel(), "What is the capital of France?").build();
ollamaAPI.generateWithImageFiles(
config.getImageModel(), List<Double> embeddings = ollamaAPI.generateEmbeddings(request);
"What is in this image?",
List.of(imageFile), assertNotNull(embeddings);
new OptionsBuilder().build()); assertFalse(embeddings.isEmpty());
assertNotNull(result); } catch (IOException | OllamaBaseException | InterruptedException e) {
assertNotNull(result.getResponse()); fail(e);
assertFalse(result.getResponse().isEmpty());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
}
}
@Test
@Order(3)
void testAskModelWithOptionsAndImageFilesStreamed() {
testEndpointReachability();
File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg");
try {
StringBuffer sb = new StringBuffer("");
OllamaResult result = ollamaAPI.generateWithImageFiles(config.getImageModel(),
"What is in this image?", List.of(imageFile), new OptionsBuilder().build(), (s) -> {
LOG.info(s);
String substring = s.substring(sb.toString().length(), s.length());
LOG.info(substring);
sb.append(substring);
});
assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
assertEquals(sb.toString().trim(), result.getResponse().trim());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
}
}
@Test
@Order(3)
void testAskModelWithOptionsAndImageURLs() {
testEndpointReachability();
try {
OllamaResult result =
ollamaAPI.generateWithImageURLs(
config.getImageModel(),
"What is in this image?",
List.of(
"https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg"),
new OptionsBuilder().build());
assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
fail(e);
}
}
@Test
@Order(3)
public void testEmbedding() {
testEndpointReachability();
try {
OllamaEmbeddingsRequestModel request = OllamaEmbeddingsRequestBuilder
.getInstance(config.getModel(), "What is the capital of France?").build();
List<Double> embeddings = ollamaAPI.generateEmbeddings(request);
assertNotNull(embeddings);
assertFalse(embeddings.isEmpty());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
}
} }
}
} }
@Data @Data
class Config { class Config {
private String ollamaURL; private String ollamaURL;
private String model; private String model;
private String imageModel; private String imageModel;
private int requestTimeoutSeconds; private int requestTimeoutSeconds;
public Config() { public Config() {
Properties properties = new Properties(); Properties properties = new Properties();
try (InputStream input = try (InputStream input =
getClass().getClassLoader().getResourceAsStream("test-config.properties")) { getClass().getClassLoader().getResourceAsStream("test-config.properties")) {
if (input == null) { if (input == null) {
throw new RuntimeException("Sorry, unable to find test-config.properties"); throw new RuntimeException("Sorry, unable to find test-config.properties");
} }
properties.load(input); properties.load(input);
this.ollamaURL = properties.getProperty("ollama.url"); this.ollamaURL = properties.getProperty("ollama.url");
this.model = properties.getProperty("ollama.model"); this.model = properties.getProperty("ollama.model");
this.imageModel = properties.getProperty("ollama.model.image"); this.imageModel = properties.getProperty("ollama.model.image");
this.requestTimeoutSeconds = this.requestTimeoutSeconds =
Integer.parseInt(properties.getProperty("ollama.request-timeout-seconds")); Integer.parseInt(properties.getProperty("ollama.request-timeout-seconds"));
} catch (IOException e) { } catch (IOException e) {
throw new RuntimeException("Error loading properties", e); throw new RuntimeException("Error loading properties", e);
}
} }
}
} }

View File

@@ -1,5 +1,7 @@
package io.github.amithkoujalgi.ollama4j.unittests; package io.github.amithkoujalgi.ollama4j.unittests;
import static org.mockito.Mockito.*;
import io.github.amithkoujalgi.ollama4j.core.OllamaAPI; import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.ModelDetail; import io.github.amithkoujalgi.ollama4j.core.models.ModelDetail;
@@ -7,158 +9,155 @@ import io.github.amithkoujalgi.ollama4j.core.models.OllamaAsyncResultCallback;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult; import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
import io.github.amithkoujalgi.ollama4j.core.types.OllamaModelType; import io.github.amithkoujalgi.ollama4j.core.types.OllamaModelType;
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder; import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import java.io.IOException; import java.io.IOException;
import java.net.URISyntaxException; import java.net.URISyntaxException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import org.junit.jupiter.api.Test;
import static org.mockito.Mockito.*; import org.mockito.Mockito;
class TestMockedAPIs { class TestMockedAPIs {
@Test @Test
void testPullModel() { void testPullModel() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2; String model = OllamaModelType.LLAMA2;
try { try {
doNothing().when(ollamaAPI).pullModel(model); doNothing().when(ollamaAPI).pullModel(model);
ollamaAPI.pullModel(model); ollamaAPI.pullModel(model);
verify(ollamaAPI, times(1)).pullModel(model); verify(ollamaAPI, times(1)).pullModel(model);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
}
} }
}
@Test @Test
void testListModels() { void testListModels() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
try { try {
when(ollamaAPI.listModels()).thenReturn(new ArrayList<>()); when(ollamaAPI.listModels()).thenReturn(new ArrayList<>());
ollamaAPI.listModels(); ollamaAPI.listModels();
verify(ollamaAPI, times(1)).listModels(); verify(ollamaAPI, times(1)).listModels();
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
}
} }
}
@Test @Test
void testCreateModel() { void testCreateModel() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2; String model = OllamaModelType.LLAMA2;
String modelFilePath = "FROM llama2\nSYSTEM You are mario from Super Mario Bros."; String modelFilePath = "FROM llama2\nSYSTEM You are mario from Super Mario Bros.";
try { try {
doNothing().when(ollamaAPI).createModelWithModelFileContents(model, modelFilePath); doNothing().when(ollamaAPI).createModelWithModelFileContents(model, modelFilePath);
ollamaAPI.createModelWithModelFileContents(model, modelFilePath); ollamaAPI.createModelWithModelFileContents(model, modelFilePath);
verify(ollamaAPI, times(1)).createModelWithModelFileContents(model, modelFilePath); verify(ollamaAPI, times(1)).createModelWithModelFileContents(model, modelFilePath);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
}
} }
}
@Test @Test
void testDeleteModel() { void testDeleteModel() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2; String model = OllamaModelType.LLAMA2;
try { try {
doNothing().when(ollamaAPI).deleteModel(model, true); doNothing().when(ollamaAPI).deleteModel(model, true);
ollamaAPI.deleteModel(model, true); ollamaAPI.deleteModel(model, true);
verify(ollamaAPI, times(1)).deleteModel(model, true); verify(ollamaAPI, times(1)).deleteModel(model, true);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
}
} }
}
@Test @Test
void testGetModelDetails() { void testGetModelDetails() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2; String model = OllamaModelType.LLAMA2;
try { try {
when(ollamaAPI.getModelDetails(model)).thenReturn(new ModelDetail()); when(ollamaAPI.getModelDetails(model)).thenReturn(new ModelDetail());
ollamaAPI.getModelDetails(model); ollamaAPI.getModelDetails(model);
verify(ollamaAPI, times(1)).getModelDetails(model); verify(ollamaAPI, times(1)).getModelDetails(model);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
}
} }
}
@Test @Test
void testGenerateEmbeddings() { void testGenerateEmbeddings() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2; String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text"; String prompt = "some prompt text";
try { try {
when(ollamaAPI.generateEmbeddings(model, prompt)).thenReturn(new ArrayList<>()); when(ollamaAPI.generateEmbeddings(model, prompt)).thenReturn(new ArrayList<>());
ollamaAPI.generateEmbeddings(model, prompt); ollamaAPI.generateEmbeddings(model, prompt);
verify(ollamaAPI, times(1)).generateEmbeddings(model, prompt); verify(ollamaAPI, times(1)).generateEmbeddings(model, prompt);
} catch (IOException | OllamaBaseException | InterruptedException e) { } catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
}
} }
}
@Test @Test
void testAsk() { void testAsk() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2; String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text"; String prompt = "some prompt text";
OptionsBuilder optionsBuilder = new OptionsBuilder(); OptionsBuilder optionsBuilder = new OptionsBuilder();
try { try {
when(ollamaAPI.generate(model, prompt, false, optionsBuilder.build())) when(ollamaAPI.generate(model, prompt, optionsBuilder.build()))
.thenReturn(new OllamaResult("", 0, 200)); .thenReturn(new OllamaResult("", 0, 200));
ollamaAPI.generate(model, prompt, false, optionsBuilder.build()); ollamaAPI.generate(model, prompt, optionsBuilder.build());
verify(ollamaAPI, times(1)).generate(model, prompt, false, optionsBuilder.build()); verify(ollamaAPI, times(1)).generate(model, prompt, optionsBuilder.build());
} catch (IOException | OllamaBaseException | InterruptedException e) { } catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
}
} }
}
@Test @Test
void testAskWithImageFiles() { void testAskWithImageFiles() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2; String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text"; String prompt = "some prompt text";
try { try {
when(ollamaAPI.generateWithImageFiles( when(ollamaAPI.generateWithImageFiles(
model, prompt, Collections.emptyList(), new OptionsBuilder().build())) model, prompt, Collections.emptyList(), new OptionsBuilder().build()))
.thenReturn(new OllamaResult("", 0, 200)); .thenReturn(new OllamaResult("", 0, 200));
ollamaAPI.generateWithImageFiles( ollamaAPI.generateWithImageFiles(
model, prompt, Collections.emptyList(), new OptionsBuilder().build()); model, prompt, Collections.emptyList(), new OptionsBuilder().build());
verify(ollamaAPI, times(1)) verify(ollamaAPI, times(1))
.generateWithImageFiles( .generateWithImageFiles(
model, prompt, Collections.emptyList(), new OptionsBuilder().build()); model, prompt, Collections.emptyList(), new OptionsBuilder().build());
} catch (IOException | OllamaBaseException | InterruptedException e) { } catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
}
} }
}
@Test @Test
void testAskWithImageURLs() { void testAskWithImageURLs() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2; String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text"; String prompt = "some prompt text";
try { try {
when(ollamaAPI.generateWithImageURLs( when(ollamaAPI.generateWithImageURLs(
model, prompt, Collections.emptyList(), new OptionsBuilder().build())) model, prompt, Collections.emptyList(), new OptionsBuilder().build()))
.thenReturn(new OllamaResult("", 0, 200)); .thenReturn(new OllamaResult("", 0, 200));
ollamaAPI.generateWithImageURLs( ollamaAPI.generateWithImageURLs(
model, prompt, Collections.emptyList(), new OptionsBuilder().build()); model, prompt, Collections.emptyList(), new OptionsBuilder().build());
verify(ollamaAPI, times(1)) verify(ollamaAPI, times(1))
.generateWithImageURLs( .generateWithImageURLs(
model, prompt, Collections.emptyList(), new OptionsBuilder().build()); model, prompt, Collections.emptyList(), new OptionsBuilder().build());
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
}
} }
}
@Test @Test
void testAskAsync() { void testAskAsync() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2; String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text"; String prompt = "some prompt text";
when(ollamaAPI.generateAsync(model, prompt, false)) when(ollamaAPI.generateAsync(model, prompt))
.thenReturn(new OllamaAsyncResultCallback(null, null, 3)); .thenReturn(new OllamaAsyncResultCallback(null, null, 3));
ollamaAPI.generateAsync(model, prompt, false); ollamaAPI.generateAsync(model, prompt);
verify(ollamaAPI, times(1)).generateAsync(model, prompt, false); verify(ollamaAPI, times(1)).generateAsync(model, prompt);
} }
} }