This commit is contained in:
Amith Koujalgi 2024-07-13 17:27:43 +05:30
commit 38f1bda105
35 changed files with 2240 additions and 1427 deletions

View File

@ -1,68 +1,68 @@
# This workflow will build a package using Maven and then publish it to GitHub packages when a release is created
# For more information see: https://github.com/actions/setup-java/blob/main/docs/advanced-usage.md#apache-maven-with-a-settings-path
name: Test and Publish Package
## This workflow will build a package using Maven and then publish it to GitHub packages when a release is created
## For more information see: https://github.com/actions/setup-java/blob/main/docs/advanced-usage.md#apache-maven-with-a-settings-path
#
#name: Test and Publish Package
#
##on:
## release:
## types: [ "created" ]
#
#on:
# release:
# types: [ "created" ]
on:
push:
branches: [ "main" ]
workflow_dispatch:
jobs:
build:
runs-on: ubuntu-latest
permissions:
contents: write
packages: write
steps:
- uses: actions/checkout@v3
- name: Set up JDK 11
uses: actions/setup-java@v3
with:
java-version: '11'
distribution: 'adopt-hotspot'
server-id: github # Value of the distributionManagement/repository/id field of the pom.xml
settings-path: ${{ github.workspace }} # location for the settings.xml file
- name: Build with Maven
run: mvn --file pom.xml -U clean package -Punit-tests
- name: Set up Apache Maven Central (Overwrite settings.xml)
uses: actions/setup-java@v3
with: # running setup-java again overwrites the settings.xml
java-version: '11'
distribution: 'adopt-hotspot'
cache: 'maven'
server-id: ossrh
server-username: MAVEN_USERNAME
server-password: MAVEN_PASSWORD
gpg-private-key: ${{ secrets.GPG_PRIVATE_KEY }}
gpg-passphrase: MAVEN_GPG_PASSPHRASE
- name: Set up Maven cache
uses: actions/cache@v3
with:
path: ~/.m2/repository
key: ${{ runner.os }}-maven-${{ hashFiles('**/pom.xml') }}
restore-keys: |
${{ runner.os }}-maven-
- name: Build
run: mvn -B -ntp clean install
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v3
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
- name: Publish to GitHub Packages Apache Maven
# if: >
# github.event_name != 'pull_request' &&
# github.ref_name == 'main' &&
# contains(github.event.head_commit.message, 'release')
run: |
git config --global user.email "koujalgi.amith@gmail.com"
git config --global user.name "amithkoujalgi"
mvn -B -ntp -DskipTests -Pci-cd -Darguments="-DskipTests -Pci-cd" release:clean release:prepare release:perform
env:
MAVEN_USERNAME: ${{ secrets.OSSRH_USERNAME }}
MAVEN_PASSWORD: ${{ secrets.OSSRH_PASSWORD }}
MAVEN_GPG_PASSPHRASE: ${{ secrets.GPG_PASSPHRASE }}
# push:
# branches: [ "main" ]
# workflow_dispatch:
#
#jobs:
# build:
# runs-on: ubuntu-latest
# permissions:
# contents: write
# packages: write
# steps:
# - uses: actions/checkout@v3
# - name: Set up JDK 11
# uses: actions/setup-java@v3
# with:
# java-version: '11'
# distribution: 'adopt-hotspot'
# server-id: github # Value of the distributionManagement/repository/id field of the pom.xml
# settings-path: ${{ github.workspace }} # location for the settings.xml file
# - name: Build with Maven
# run: mvn --file pom.xml -U clean package -Punit-tests
# - name: Set up Apache Maven Central (Overwrite settings.xml)
# uses: actions/setup-java@v3
# with: # running setup-java again overwrites the settings.xml
# java-version: '11'
# distribution: 'adopt-hotspot'
# cache: 'maven'
# server-id: ossrh
# server-username: MAVEN_USERNAME
# server-password: MAVEN_PASSWORD
# gpg-private-key: ${{ secrets.GPG_PRIVATE_KEY }}
# gpg-passphrase: MAVEN_GPG_PASSPHRASE
# - name: Set up Maven cache
# uses: actions/cache@v3
# with:
# path: ~/.m2/repository
# key: ${{ runner.os }}-maven-${{ hashFiles('**/pom.xml') }}
# restore-keys: |
# ${{ runner.os }}-maven-
# - name: Build
# run: mvn -B -ntp clean install
# - name: Upload coverage reports to Codecov
# uses: codecov/codecov-action@v3
# env:
# CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
# - name: Publish to GitHub Packages Apache Maven
# # if: >
# # github.event_name != 'pull_request' &&
# # github.ref_name == 'main' &&
# # contains(github.event.head_commit.message, 'release')
# run: |
# git config --global user.email "koujalgi.amith@gmail.com"
# git config --global user.name "amithkoujalgi"
# mvn -B -ntp -DskipTests -Pci-cd -Darguments="-DskipTests -Pci-cd" release:clean release:prepare release:perform
# env:
# MAVEN_USERNAME: ${{ secrets.OSSRH_USERNAME }}
# MAVEN_PASSWORD: ${{ secrets.OSSRH_PASSWORD }}
# MAVEN_GPG_PASSPHRASE: ${{ secrets.GPG_PASSPHRASE }}

View File

@ -1,52 +1,52 @@
# Simple workflow for deploying static content to GitHub Pages
name: Deploy Javadoc content to Pages
on:
# Runs on pushes targeting the default branch
push:
branches: [ "none" ]
# Allows you to run this workflow manually from the Actions tab
workflow_dispatch:
# Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages
permissions:
contents: read
pages: write
id-token: write
packages: write
# Allow only one concurrent deployment, skipping runs queued between the run in-progress and latest queued.
# However, do NOT cancel in-progress runs as we want to allow these production deployments to complete.
concurrency:
group: "pages"
cancel-in-progress: false
jobs:
# Single deploy job since we're just deploying
deploy:
runs-on: ubuntu-latest
environment:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
steps:
- uses: actions/checkout@v3
- name: Set up JDK 11
uses: actions/setup-java@v3
with:
java-version: '11'
distribution: 'adopt-hotspot'
server-id: github # Value of the distributionManagement/repository/id field of the pom.xml
settings-path: ${{ github.workspace }} # location for the settings.xml file
- name: Build with Maven
run: mvn --file pom.xml -U clean package
- name: Setup Pages
uses: actions/configure-pages@v3
- name: Upload artifact
uses: actions/upload-pages-artifact@v2
with:
# Upload entire repository
path: './target/apidocs/.'
- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v2
## Simple workflow for deploying static content to GitHub Pages
#name: Deploy Javadoc content to Pages
#
#on:
# # Runs on pushes targeting the default branch
# push:
# branches: [ "none" ]
#
# # Allows you to run this workflow manually from the Actions tab
# workflow_dispatch:
#
## Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages
#permissions:
# contents: read
# pages: write
# id-token: write
# packages: write
## Allow only one concurrent deployment, skipping runs queued between the run in-progress and latest queued.
## However, do NOT cancel in-progress runs as we want to allow these production deployments to complete.
#concurrency:
# group: "pages"
# cancel-in-progress: false
#
#jobs:
# # Single deploy job since we're just deploying
# deploy:
# runs-on: ubuntu-latest
#
# environment:
# name: github-pages
# url: ${{ steps.deployment.outputs.page_url }}
# steps:
# - uses: actions/checkout@v3
# - name: Set up JDK 11
# uses: actions/setup-java@v3
# with:
# java-version: '11'
# distribution: 'adopt-hotspot'
# server-id: github # Value of the distributionManagement/repository/id field of the pom.xml
# settings-path: ${{ github.workspace }} # location for the settings.xml file
# - name: Build with Maven
# run: mvn --file pom.xml -U clean package
# - name: Setup Pages
# uses: actions/configure-pages@v3
# - name: Upload artifact
# uses: actions/upload-pages-artifact@v2
# with:
# # Upload entire repository
# path: './target/apidocs/.'
# - name: Deploy to GitHub Pages
# id: deployment
# uses: actions/deploy-pages@v2

View File

@ -1,3 +1,5 @@
<div style="text-align: center">
### Ollama4j
<img src='https://raw.githubusercontent.com/amithkoujalgi/ollama4j/65a9d526150da8fcd98e2af6a164f055572bf722/ollama4j.jpeg' width='100' alt="ollama4j-icon">
@ -9,16 +11,28 @@ Find more details on the [website](https://amithkoujalgi.github.io/ollama4j/).
![GitHub stars](https://img.shields.io/github/stars/amithkoujalgi/ollama4j)
![GitHub forks](https://img.shields.io/github/forks/amithkoujalgi/ollama4j)
![GitHub watchers](https://img.shields.io/github/watchers/amithkoujalgi/ollama4j)
![Contributors](https://img.shields.io/github/contributors/amithkoujalgi/ollama4j)
![GitHub License](https://img.shields.io/github/license/amithkoujalgi/ollama4j)
![GitHub repo size](https://img.shields.io/github/repo-size/amithkoujalgi/ollama4j)
![GitHub language count](https://img.shields.io/github/languages/count/amithkoujalgi/ollama4j)
![GitHub top language](https://img.shields.io/github/languages/top/amithkoujalgi/ollama4j)
![GitHub last commit](https://img.shields.io/github/last-commit/amithkoujalgi/ollama4j?color=green)
![Hits](https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fgithub.com%2Famithkoujalgi%2Follama4j&count_bg=%2379C83D&title_bg=%23555555&icon=&icon_color=%23E7E7E7&title=hits&edge_flat=false)
[![codecov](https://codecov.io/gh/amithkoujalgi/ollama4j/graph/badge.svg?token=U0TE7BGP8L)](https://codecov.io/gh/amithkoujalgi/ollama4j)
![GitHub Issues or Pull Requests](https://img.shields.io/github/issues-raw/amithkoujalgi/ollama4j)
![GitHub Issues or Pull Requests](https://img.shields.io/github/issues-closed-raw/amithkoujalgi/ollama4j)
![GitHub Issues or Pull Requests](https://img.shields.io/github/issues-pr-raw/amithkoujalgi/ollama4j)
![GitHub Issues or Pull Requests](https://img.shields.io/github/issues-pr-closed-raw/amithkoujalgi/ollama4j)
![GitHub Discussions](https://img.shields.io/github/discussions/amithkoujalgi/ollama4j)
![Build Status](https://github.com/amithkoujalgi/ollama4j/actions/workflows/maven-publish.yml/badge.svg)
</div>
[//]: # (![Hits]&#40;https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fgithub.com%2Famithkoujalgi%2Follama4j&count_bg=%2379C83D&title_bg=%23555555&icon=&icon_color=%23E7E7E7&title=hits&edge_flat=false&#41;)
[//]: # (![GitHub language count]&#40;https://img.shields.io/github/languages/count/amithkoujalgi/ollama4j&#41;)
## Table of Contents
- [How does it work?](#how-does-it-work)
@ -67,10 +81,29 @@ In your Maven project, add this dependency:
<dependency>
<groupId>io.github.amithkoujalgi</groupId>
<artifactId>ollama4j</artifactId>
<version>1.0.47</version>
<version>1.0.70</version>
</dependency>
```
or
In your Gradle project, add the dependency using the Kotlin DSL or the Groovy DSL:
```kotlin
dependencies {
val ollama4jVersion = "1.0.70"
implementation("io.github.amithkoujalgi:ollama4j:$ollama4jVersion")
}
```
```groovy
dependencies {
implementation("io.github.amithkoujalgi:ollama4j:1.0.70")
}
```
Latest release:
![Maven Central](https://img.shields.io/maven-central/v/io.github.amithkoujalgi/ollama4j)
@ -110,6 +143,16 @@ make it
Releases (newer artifact versions) are done automatically on pushing the code to the `main` branch through GitHub
Actions CI workflow.
#### Who's using Ollama4j?
- `Datafaker`: a library to generate fake data
- https://github.com/datafaker-net/datafaker-experimental/tree/main/ollama-api
- `Vaadin Web UI`: UI-Tester for Interactions with Ollama via ollama4j
- https://github.com/TEAMPB/ollama4j-vaadin-ui
- `ollama-translator`: Minecraft 1.20.6 spigot plugin allows to easily break language barriers by using ollama on the
server to translate all messages into a specfic target language.
- https://github.com/liebki/ollama-translator
#### Traction
[![Star History Chart](https://api.star-history.com/svg?repos=amithkoujalgi/ollama4j&type=Date)](https://star-history.com/#amithkoujalgi/ollama4j&Date)
@ -125,15 +168,15 @@ Actions CI workflow.
- [x] Update request body creation with Java objects
- [ ] Async APIs for images
- [ ] Add custom headers to requests
- [ ] Add additional params for `ask` APIs such as:
- [x] Add additional params for `ask` APIs such as:
- [x] `options`: additional model parameters for the Modelfile such as `temperature` -
Supported [params](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
- [ ] `system`: system prompt to (overrides what is defined in the Modelfile)
- [ ] `template`: the full prompt or prompt template (overrides what is defined in the Modelfile)
- [ ] `context`: the context parameter returned from a previous request, which can be used to keep a
- [x] `system`: system prompt to (overrides what is defined in the Modelfile)
- [x] `template`: the full prompt or prompt template (overrides what is defined in the Modelfile)
- [x] `context`: the context parameter returned from a previous request, which can be used to keep a
short
conversational memory
- [ ] `stream`: Add support for streaming responses from the model
- [x] `stream`: Add support for streaming responses from the model
- [ ] Add test cases
- [ ] Handle exceptions better (maybe throw more appropriate exceptions)
@ -143,11 +186,28 @@ Contributions are most welcome! Whether it's reporting a bug, proposing an enhan
with code - any sort
of contribution is much appreciated.
### References
- [Ollama REST APIs](https://github.com/jmorganca/ollama/blob/main/docs/api.md)
### Credits
The nomenclature and the icon have been adopted from the incredible [Ollama](https://ollama.ai/)
project.
### References
- [Ollama REST APIs](https://github.com/jmorganca/ollama/blob/main/docs/api.md)
<div style="text-align: center">
**Thanks to the amazing contributors**
<a href="https://github.com/amithkoujalgi/ollama4j/graphs/contributors">
<img src="https://contrib.rocks/image?repo=amithkoujalgi/ollama4j" />
</a>
### Appreciate my work?
<a href="https://www.buymeacoffee.com/amithkoujalgi" target="_blank"><img src="https://cdn.buymeacoffee.com/buttons/v2/default-yellow.png" alt="Buy Me A Coffee" style="height: 60px !important;width: 217px !important;" ></a>
</div>

View File

@ -20,8 +20,8 @@ public class Main {
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(OllamaModelType.LLAMA2);
// create first user question
OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER,"What is the capital of France?")
.build();
OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What is the capital of France?")
.build();
// start conversation with model
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
@ -29,7 +29,7 @@ public class Main {
System.out.println("First answer: " + chatResult.getResponse());
// create next userQuestion
requestModel = builder.withMessages(chatResult.getChatHistory()).withMessage(OllamaChatMessageRole.USER,"And what is the second largest city?").build();
requestModel = builder.withMessages(chatResult.getChatHistory()).withMessage(OllamaChatMessageRole.USER, "And what is the second largest city?").build();
// "continue" conversation with model
chatResult = ollamaAPI.chat(requestModel);
@ -41,6 +41,7 @@ public class Main {
}
```
You will get a response similar to:
> First answer: Should be Paris!
@ -50,23 +51,28 @@ You will get a response similar to:
> Chat History:
```json
[ {
"role" : "user",
"content" : "What is the capital of France?",
"images" : [ ]
}, {
"role" : "assistant",
"content" : "Should be Paris!",
"images" : [ ]
}, {
"role" : "user",
"content" : "And what is the second largest city?",
"images" : [ ]
}, {
"role" : "assistant",
"content" : "Marseille.",
"images" : [ ]
} ]
[
{
"role": "user",
"content": "What is the capital of France?",
"images": []
},
{
"role": "assistant",
"content": "Should be Paris!",
"images": []
},
{
"role": "user",
"content": "And what is the second largest city?",
"images": []
},
{
"role": "assistant",
"content": "Marseille.",
"images": []
}
]
```
## Create a conversation where the answer is streamed
@ -81,18 +87,19 @@ public class Main {
OllamaAPI ollamaAPI = new OllamaAPI(host);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER,
"What is the capital of France? And what's France's connection with Mona Lisa?")
.build();
"What is the capital of France? And what's France's connection with Mona Lisa?")
.build();
// define a handler (Consumer<String>)
OllamaStreamHandler streamHandler = (s) -> {
System.out.println(s);
System.out.println(s);
};
OllamaChatResult chatResult = ollamaAPI.chat(requestModel,streamHandler);
OllamaChatResult chatResult = ollamaAPI.chat(requestModel, streamHandler);
}
}
```
You will get a response similar to:
> The
@ -103,8 +110,27 @@ You will get a response similar to:
> The capital of France is Paris
> The capital of France is Paris.
## Use a simple Console Output Stream Handler
```java
import io.github.amithkoujalgi.ollama4j.core.impl.ConsoleOutputStreamHandler;
public class Main {
public static void main(String[] args) throws Exception {
String host = "http://localhost:11434/";
OllamaAPI ollamaAPI = new OllamaAPI(host);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(OllamaModelType.LLAMA2);
OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER, "List all cricket world cup teams of 2019. Name the teams!")
.build();
OllamaStreamHandler streamHandler = new ConsoleOutputStreamHandler();
ollamaAPI.chat(requestModel, streamHandler);
}
}
```
## Create a new conversation with individual system prompt
```java
public class Main {
@ -117,8 +143,8 @@ public class Main {
// create request with system-prompt (overriding the model defaults) and user question
OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, "You are a silent bot that only says 'NI'. Do not say anything else under any circumstances!")
.withMessage(OllamaChatMessageRole.USER,"What is the capital of France? And what's France's connection with Mona Lisa?")
.build();
.withMessage(OllamaChatMessageRole.USER, "What is the capital of France? And what's France's connection with Mona Lisa?")
.build();
// start conversation with model
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
@ -128,6 +154,7 @@ public class Main {
}
```
You will get a response similar to:
> NI.
@ -139,34 +166,40 @@ public class Main {
public static void main(String[] args) {
String host = "http://localhost:11434/";
String host = "http://localhost:11434/";
OllamaAPI ollamaAPI = new OllamaAPI(host);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(OllamaModelType.LLAVA);
OllamaAPI ollamaAPI = new OllamaAPI(host);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(OllamaModelType.LLAVA);
// Load Image from File and attach to user message (alternatively images could also be added via URL)
OllamaChatRequestModel requestModel =
builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
List.of(getImageFileFromClasspath("dog-on-a-boat.jpg"))).build();
// Load Image from File and attach to user message (alternatively images could also be added via URL)
OllamaChatRequestModel requestModel =
builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
List.of(getImageFileFromClasspath("dog-on-a-boat.jpg"))).build();
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
System.out.println("First answer: " + chatResult.getResponse());
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
System.out.println("First answer: " + chatResult.getResponse());
builder.reset();
builder.reset();
// Use history to ask further questions about the image or assistant answer
requestModel =
builder.withMessages(chatResult.getChatHistory())
.withMessage(OllamaChatMessageRole.USER, "What's the dogs breed?").build();
// Use history to ask further questions about the image or assistant answer
requestModel =
builder.withMessages(chatResult.getChatHistory())
.withMessage(OllamaChatMessageRole.USER, "What's the dogs breed?").build();
chatResult = ollamaAPI.chat(requestModel);
System.out.println("Second answer: " + chatResult.getResponse());
chatResult = ollamaAPI.chat(requestModel);
System.out.println("Second answer: " + chatResult.getResponse());
}
}
```
You will get a response similar to:
> First Answer: The image shows a dog sitting on the bow of a boat that is docked in calm water. The boat has two levels, with the lower level containing seating and what appears to be an engine cover. The dog seems relaxed and comfortable on the boat, looking out over the water. The background suggests it might be late afternoon or early evening, given the warm lighting and the low position of the sun in the sky.
> First Answer: The image shows a dog sitting on the bow of a boat that is docked in calm water. The boat has two
> levels, with the lower level containing seating and what appears to be an engine cover. The dog seems relaxed and
> comfortable on the boat, looking out over the water. The background suggests it might be late afternoon or early
> evening, given the warm lighting and the low position of the sun in the sky.
>
> Second Answer: Based on the image, it's difficult to definitively determine the breed of the dog. However, the dog appears to be medium-sized with a short coat and a brown coloration, which might suggest that it is a Golden Retriever or a similar breed. Without more details like ear shape and tail length, it's not possible to identify the exact breed confidently.
> Second Answer: Based on the image, it's difficult to definitively determine the breed of the dog. However, the dog
> appears to be medium-sized with a short coat and a brown coloration, which might suggest that it is a Golden Retriever
> or a similar breed. Without more details like ear shape and tail length, it's not possible to identify the exact breed
> confidently.

View File

@ -1,5 +1,5 @@
---
sidebar_position: 2
sidebar_position: 3
---
# Generate - Async

View File

@ -1,5 +1,5 @@
---
sidebar_position: 3
sidebar_position: 4
---
# Generate - With Image Files

View File

@ -1,5 +1,5 @@
---
sidebar_position: 4
sidebar_position: 5
---
# Generate - With Image URLs

View File

@ -0,0 +1,271 @@
---
sidebar_position: 2
---
# Generate - With Tools
This API lets you perform [function calling](https://docs.mistral.ai/capabilities/function_calling/) using LLMs in a
synchronous way.
This API correlates to
the [generate](https://github.com/ollama/ollama/blob/main/docs/api.md#request-raw-mode) API with `raw` mode.
:::note
This is an only an experimental implementation and has a very basic design.
Currently, built and tested for [Mistral's latest model](https://ollama.com/library/mistral) only. We could redesign
this
in the future if tooling is supported for more models with a generic interaction standard from Ollama.
:::
### Function Calling/Tools
Assume you want to call a method in your code based on the response generated from the model.
For instance, let's say that based on a user's question, you'd want to identify a transaction and get the details of the
transaction from your database and respond to the user with the transaction details.
You could do that with ease with the `function calling` capabilities of the models by registering your `tools`.
### Create Functions
This function takes the arguments `location` and `fuelType` and performs an operation with these arguments and returns a
value.
```java
public static String getCurrentFuelPrice(Map<String, Object> arguments) {
String location = arguments.get("location").toString();
String fuelType = arguments.get("fuelType").toString();
return "Current price of " + fuelType + " in " + location + " is Rs.103/L";
}
```
This function takes the argument `city` and performs an operation with the argument and returns a
value.
```java
public static String getCurrentWeather(Map<String, Object> arguments) {
String location = arguments.get("city").toString();
return "Currently " + location + "'s weather is nice.";
}
```
### Define Tool Specifications
Lets define a sample tool specification called **Fuel Price Tool** for getting the current fuel price.
- Specify the function `name`, `description`, and `required` properties (`location` and `fuelType`).
- Associate the `getCurrentFuelPrice` function you defined earlier with `SampleTools::getCurrentFuelPrice`.
```java
MistralTools.ToolSpecification fuelPriceToolSpecification = MistralTools.ToolSpecification.builder()
.functionName("current-fuel-price")
.functionDesc("Get current fuel price")
.props(
new MistralTools.PropsBuilder()
.withProperty("location", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The city, e.g. New Delhi, India").required(true).build())
.withProperty("fuelType", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The fuel type.").enumValues(Arrays.asList("petrol", "diesel")).required(true).build())
.build()
)
.toolDefinition(SampleTools::getCurrentFuelPrice)
.build();
```
Lets also define a sample tool specification called **Weather Tool** for getting the current weather.
- Specify the function `name`, `description`, and `required` property (`city`).
- Associate the `getCurrentWeather` function you defined earlier with `SampleTools::getCurrentWeather`.
```java
MistralTools.ToolSpecification weatherToolSpecification = MistralTools.ToolSpecification.builder()
.functionName("current-weather")
.functionDesc("Get current weather")
.props(
new MistralTools.PropsBuilder()
.withProperty("city", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The city, e.g. New Delhi, India").required(true).build())
.build()
)
.toolDefinition(SampleTools::getCurrentWeather)
.build();
```
### Register the Tools
Register the defined tools (`fuel price` and `weather`) with the OllamaAPI.
```shell
ollamaAPI.registerTool(fuelPriceToolSpecification);
ollamaAPI.registerTool(weatherToolSpecification);
```
### Create prompt with Tools
`Prompt 1`: Create a prompt asking for the petrol price in Bengaluru using the defined fuel price and weather tools.
```shell
String prompt1 = new MistralTools.PromptBuilder()
.withToolSpecification(fuelPriceToolSpecification)
.withToolSpecification(weatherToolSpecification)
.withPrompt("What is the petrol price in Bengaluru?")
.build();
OllamaToolsResult toolsResult = ollamaAPI.generateWithTools(model, prompt1, false, new OptionsBuilder().build());
for (Map.Entry<ToolDef, Object> r : toolsResult.getToolResults().entrySet()) {
System.out.printf("[Response from tool '%s']: %s%n", r.getKey().getName(), r.getValue().toString());
}
```
Now, fire away your question to the model.
You will get a response similar to:
::::tip[LLM Response]
[Response from tool 'current-fuel-price']: Current price of petrol in Bengaluru is Rs.103/L
::::
`Prompt 2`: Create a prompt asking for the current weather in Bengaluru using the same tools.
```shell
String prompt2 = new MistralTools.PromptBuilder()
.withToolSpecification(fuelPriceToolSpecification)
.withToolSpecification(weatherToolSpecification)
.withPrompt("What is the current weather in Bengaluru?")
.build();
OllamaToolsResult toolsResult = ollamaAPI.generateWithTools(model, prompt2, false, new OptionsBuilder().build());
for (Map.Entry<ToolDef, Object> r : toolsResult.getToolResults().entrySet()) {
System.out.printf("[Response from tool '%s']: %s%n", r.getKey().getName(), r.getValue().toString());
}
```
Again, fire away your question to the model.
You will get a response similar to:
::::tip[LLM Response]
[Response from tool 'current-weather']: Currently Bengaluru's weather is nice
::::
### Full Example
```java
import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.tools.ToolDef;
import io.github.amithkoujalgi.ollama4j.core.tools.MistralTools;
import io.github.amithkoujalgi.ollama4j.core.tools.OllamaToolsResult;
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
import java.io.IOException;
import java.util.Arrays;
import java.util.Map;
public class FunctionCallingWithMistral {
public static void main(String[] args) throws Exception {
String host = "http://localhost:11434/";
OllamaAPI ollamaAPI = new OllamaAPI(host);
ollamaAPI.setRequestTimeoutSeconds(60);
String model = "mistral";
MistralTools.ToolSpecification fuelPriceToolSpecification = MistralTools.ToolSpecification.builder()
.functionName("current-fuel-price")
.functionDesc("Get current fuel price")
.props(
new MistralTools.PropsBuilder()
.withProperty("location", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The city, e.g. New Delhi, India").required(true).build())
.withProperty("fuelType", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The fuel type.").enumValues(Arrays.asList("petrol", "diesel")).required(true).build())
.build()
)
.toolDefinition(SampleTools::getCurrentFuelPrice)
.build();
MistralTools.ToolSpecification weatherToolSpecification = MistralTools.ToolSpecification.builder()
.functionName("current-weather")
.functionDesc("Get current weather")
.props(
new MistralTools.PropsBuilder()
.withProperty("city", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The city, e.g. New Delhi, India").required(true).build())
.build()
)
.toolDefinition(SampleTools::getCurrentWeather)
.build();
ollamaAPI.registerTool(fuelPriceToolSpecification);
ollamaAPI.registerTool(weatherToolSpecification);
String prompt1 = new MistralTools.PromptBuilder()
.withToolSpecification(fuelPriceToolSpecification)
.withToolSpecification(weatherToolSpecification)
.withPrompt("What is the petrol price in Bengaluru?")
.build();
String prompt2 = new MistralTools.PromptBuilder()
.withToolSpecification(fuelPriceToolSpecification)
.withToolSpecification(weatherToolSpecification)
.withPrompt("What is the current weather in Bengaluru?")
.build();
ask(ollamaAPI, model, prompt1);
ask(ollamaAPI, model, prompt2);
}
public static void ask(OllamaAPI ollamaAPI, String model, String prompt) throws OllamaBaseException, IOException, InterruptedException {
OllamaToolsResult toolsResult = ollamaAPI.generateWithTools(model, prompt, false, new OptionsBuilder().build());
for (Map.Entry<ToolDef, Object> r : toolsResult.getToolResults().entrySet()) {
System.out.printf("[Response from tool '%s']: %s%n", r.getKey().getName(), r.getValue().toString());
}
}
}
class SampleTools {
public static String getCurrentFuelPrice(Map<String, Object> arguments) {
String location = arguments.get("location").toString();
String fuelType = arguments.get("fuelType").toString();
return "Current price of " + fuelType + " in " + location + " is Rs.103/L";
}
public static String getCurrentWeather(Map<String, Object> arguments) {
String location = arguments.get("city").toString();
return "Currently " + location + "'s weather is nice.";
}
}
```
Run this full example and you will get a response similar to:
::::tip[LLM Response]
[Response from tool 'current-fuel-price']: Current price of petrol in Bengaluru is Rs.103/L
[Response from tool 'current-weather']: Currently Bengaluru's weather is nice
::::
### Room for improvement
Instead of explicitly registering `ollamaAPI.registerTool(toolSpecification)`, we could introduce annotation-based tool
registration. For example:
```java
@ToolSpec(name = "current-fuel-price", desc = "Get current fuel price")
public String getCurrentFuelPrice(Map<String, Object> arguments) {
String location = arguments.get("location").toString();
String fuelType = arguments.get("fuelType").toString();
return "Current price of " + fuelType + " in " + location + " is Rs.103/L";
}
```
Instead of passing a map of args `Map<String, Object> arguments` to the tool functions, we could support passing
specific args separately with their data types. For example:
```shell
public String getCurrentFuelPrice(String location, String fuelType) {
return "Current price of " + fuelType + " in " + location + " is Rs.103/L";
}
```
Updating async/chat APIs with support for tool-based generation.

View File

@ -1,5 +1,5 @@
---
sidebar_position: 5
sidebar_position: 6
---
# Prompt Builder
@ -42,7 +42,7 @@ public class AskPhi {
.addSeparator()
.add("How do I read a file in Go and print its contents to stdout?");
OllamaResult response = ollamaAPI.generate(model, promptBuilder.build());
OllamaResult response = ollamaAPI.generate(model, promptBuilder.build(), new OptionsBuilder().build());
System.out.println(response.getResponse());
}
}

11
pom.xml
View File

@ -4,7 +4,7 @@
<groupId>io.github.amithkoujalgi</groupId>
<artifactId>ollama4j</artifactId>
<version>1.0.57-SNAPSHOT</version>
<version>1.0.78-SNAPSHOT</version>
<name>Ollama4j</name>
<description>Java library for interacting with Ollama API.</description>
@ -99,7 +99,7 @@
<configuration>
<skipTests>${skipUnitTests}</skipTests>
<includes>
<include>**/unittests/*.java</include>
<include>**/unittests/**/*.java</include>
</includes>
</configuration>
</plugin>
@ -149,7 +149,12 @@
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>2.15.3</version>
<version>2.17.1</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.datatype</groupId>
<artifactId>jackson-datatype-jsr310</artifactId>
<version>2.17.1</version>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>

View File

@ -0,0 +1,14 @@
package io.github.amithkoujalgi.ollama4j.core.impl;
import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler;
public class ConsoleOutputStreamHandler implements OllamaStreamHandler {
private final StringBuffer response = new StringBuffer();
@Override
public void accept(String message) {
String substr = message.substring(response.length());
response.append(substr);
System.out.print(substr);
}
}

View File

@ -1,5 +1,8 @@
package io.github.amithkoujalgi.ollama4j.core.models;
import java.time.LocalDateTime;
import java.time.OffsetDateTime;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
@ -11,7 +14,9 @@ public class Model {
private String name;
private String model;
@JsonProperty("modified_at")
private String modifiedAt;
private OffsetDateTime modifiedAt;
@JsonProperty("expires_at")
private OffsetDateTime expiresAt;
private String digest;
private long size;
@JsonProperty("details")

View File

@ -1,14 +1,15 @@
package io.github.amithkoujalgi.ollama4j.core.models.chat;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Data;
import java.util.List;
import lombok.Data;
@Data
public class OllamaChatResponseModel {
private String model;
private @JsonProperty("created_at") String createdAt;
private @JsonProperty("done_reason") String doneReason;
private OllamaChatMessage message;
private boolean done;
private String error;

View File

@ -1,4 +1,4 @@
package io.github.amithkoujalgi.ollama4j.core.models;
package io.github.amithkoujalgi.ollama4j.core.models.embeddings;
import com.fasterxml.jackson.annotation.JsonProperty;
@ -7,7 +7,7 @@ import lombok.Data;
@SuppressWarnings("unused")
@Data
public class EmbeddingResponse {
public class OllamaEmbeddingResponseModel {
@JsonProperty("embedding")
private List<Double> embedding;
}

View File

@ -0,0 +1,31 @@
package io.github.amithkoujalgi.ollama4j.core.models.embeddings;
import io.github.amithkoujalgi.ollama4j.core.utils.Options;
public class OllamaEmbeddingsRequestBuilder {
private OllamaEmbeddingsRequestBuilder(String model, String prompt){
request = new OllamaEmbeddingsRequestModel(model, prompt);
}
private OllamaEmbeddingsRequestModel request;
public static OllamaEmbeddingsRequestBuilder getInstance(String model, String prompt){
return new OllamaEmbeddingsRequestBuilder(model, prompt);
}
public OllamaEmbeddingsRequestModel build(){
return request;
}
public OllamaEmbeddingsRequestBuilder withOptions(Options options){
this.request.setOptions(options.getOptionsMap());
return this;
}
public OllamaEmbeddingsRequestBuilder withKeepAlive(String keepAlive){
this.request.setKeepAlive(keepAlive);
return this;
}
}

View File

@ -0,0 +1,33 @@
package io.github.amithkoujalgi.ollama4j.core.models.embeddings;
import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper;
import java.util.Map;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
@Data
@RequiredArgsConstructor
@NoArgsConstructor
public class OllamaEmbeddingsRequestModel {
@NonNull
private String model;
@NonNull
private String prompt;
protected Map<String, Object> options;
@JsonProperty(value = "keep_alive")
private String keepAlive;
@Override
public String toString() {
try {
return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
}

View File

@ -1,23 +0,0 @@
package io.github.amithkoujalgi.ollama4j.core.models.request;
import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper;
import com.fasterxml.jackson.core.JsonProcessingException;
import lombok.AllArgsConstructor;
import lombok.Data;
@Data
@AllArgsConstructor
public class ModelEmbeddingsRequest {
private String model;
private String prompt;
@Override
public String toString() {
try {
return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
}

View File

@ -1,12 +1,6 @@
package io.github.amithkoujalgi.ollama4j.core.models.request;
import java.io.IOException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.fasterxml.jackson.core.JsonProcessingException;
import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth;
@ -15,11 +9,15 @@ import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResponseModel
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatStreamObserver;
import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
/**
* Specialization class for requests
*/
public class OllamaChatEndpointCaller extends OllamaEndpointCaller{
public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
private static final Logger LOG = LoggerFactory.getLogger(OllamaChatEndpointCaller.class);
@ -39,12 +37,12 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller{
try {
OllamaChatResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class);
responseBuffer.append(ollamaResponseModel.getMessage().getContent());
if(streamObserver != null) {
if (streamObserver != null) {
streamObserver.notify(ollamaResponseModel);
}
return ollamaResponseModel.isDone();
} catch (JsonProcessingException e) {
LOG.error("Error parsing the Ollama chat response!",e);
LOG.error("Error parsing the Ollama chat response!", e);
return true;
}
}
@ -54,7 +52,4 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller{
streamObserver = new OllamaChatStreamObserver(streamHandler);
return super.callSync(body);
}
}

View File

@ -1,5 +1,15 @@
package io.github.amithkoujalgi.ollama4j.core.models.request;
import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaErrorResponseModel;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
@ -12,17 +22,6 @@ import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Base64;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaErrorResponseModel;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
/**
* Abstract helperclass to call the ollama api server.
*/
@ -52,104 +51,102 @@ public abstract class OllamaEndpointCaller {
*
* @param body POST body payload
* @return result answer given by the assistant
* @throws OllamaBaseException any response code than 200 has been returned
* @throws IOException in case the responseStream can not be read
* @throws OllamaBaseException any response code than 200 has been returned
* @throws IOException in case the responseStream can not be read
* @throws InterruptedException in case the server is not reachable or network issues happen
*/
public OllamaResult callSync(OllamaRequestBody body) throws OllamaBaseException, IOException, InterruptedException{
public OllamaResult callSync(OllamaRequestBody body) throws OllamaBaseException, IOException, InterruptedException {
// Create Request
long startTime = System.currentTimeMillis();
HttpClient httpClient = HttpClient.newHttpClient();
URI uri = URI.create(this.host + getEndpointSuffix());
HttpRequest.Builder requestBuilder =
getRequestBuilderDefault(uri)
.POST(
body.getBodyPublisher());
HttpRequest request = requestBuilder.build();
if (this.verbose) LOG.info("Asking model: " + body.toString());
HttpResponse<InputStream> response =
httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
long startTime = System.currentTimeMillis();
HttpClient httpClient = HttpClient.newHttpClient();
URI uri = URI.create(this.host + getEndpointSuffix());
HttpRequest.Builder requestBuilder =
getRequestBuilderDefault(uri)
.POST(
body.getBodyPublisher());
HttpRequest request = requestBuilder.build();
if (this.verbose) LOG.info("Asking model: " + body.toString());
HttpResponse<InputStream> response =
httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
int statusCode = response.statusCode();
InputStream responseBodyStream = response.body();
StringBuilder responseBuffer = new StringBuilder();
try (BufferedReader reader =
new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
String line;
while ((line = reader.readLine()) != null) {
if (statusCode == 404) {
LOG.warn("Status code: 404 (Not Found)");
OllamaErrorResponseModel ollamaResponseModel =
Utils.getObjectMapper().readValue(line, OllamaErrorResponseModel.class);
responseBuffer.append(ollamaResponseModel.getError());
} else if (statusCode == 401) {
LOG.warn("Status code: 401 (Unauthorized)");
OllamaErrorResponseModel ollamaResponseModel =
Utils.getObjectMapper()
.readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponseModel.class);
responseBuffer.append(ollamaResponseModel.getError());
} else if (statusCode == 400) {
LOG.warn("Status code: 400 (Bad Request)");
OllamaErrorResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line,
OllamaErrorResponseModel.class);
responseBuffer.append(ollamaResponseModel.getError());
} else {
boolean finished = parseResponseAndAddToBuffer(line,responseBuffer);
if (finished) {
break;
InputStream responseBodyStream = response.body();
StringBuilder responseBuffer = new StringBuilder();
try (BufferedReader reader =
new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
String line;
while ((line = reader.readLine()) != null) {
if (statusCode == 404) {
LOG.warn("Status code: 404 (Not Found)");
OllamaErrorResponseModel ollamaResponseModel =
Utils.getObjectMapper().readValue(line, OllamaErrorResponseModel.class);
responseBuffer.append(ollamaResponseModel.getError());
} else if (statusCode == 401) {
LOG.warn("Status code: 401 (Unauthorized)");
OllamaErrorResponseModel ollamaResponseModel =
Utils.getObjectMapper()
.readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponseModel.class);
responseBuffer.append(ollamaResponseModel.getError());
} else if (statusCode == 400) {
LOG.warn("Status code: 400 (Bad Request)");
OllamaErrorResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line,
OllamaErrorResponseModel.class);
responseBuffer.append(ollamaResponseModel.getError());
} else {
boolean finished = parseResponseAndAddToBuffer(line, responseBuffer);
if (finished) {
break;
}
}
}
}
}
}
if (statusCode != 200) {
LOG.error("Status code " + statusCode);
throw new OllamaBaseException(responseBuffer.toString());
} else {
long endTime = System.currentTimeMillis();
OllamaResult ollamaResult =
new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode);
if (verbose) LOG.info("Model response: " + ollamaResult);
return ollamaResult;
if (statusCode != 200) {
LOG.error("Status code " + statusCode);
throw new OllamaBaseException(responseBuffer.toString());
} else {
long endTime = System.currentTimeMillis();
OllamaResult ollamaResult =
new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode);
if (verbose) LOG.info("Model response: " + ollamaResult);
return ollamaResult;
}
}
}
/**
* Get default request builder.
*
* @param uri URI to get a HttpRequest.Builder
* @return HttpRequest.Builder
*/
private HttpRequest.Builder getRequestBuilderDefault(URI uri) {
HttpRequest.Builder requestBuilder =
HttpRequest.newBuilder(uri)
.header("Content-Type", "application/json")
.timeout(Duration.ofSeconds(this.requestTimeoutSeconds));
if (isBasicAuthCredentialsSet()) {
requestBuilder.header("Authorization", getBasicAuthHeaderValue());
* Get default request builder.
*
* @param uri URI to get a HttpRequest.Builder
* @return HttpRequest.Builder
*/
private HttpRequest.Builder getRequestBuilderDefault(URI uri) {
HttpRequest.Builder requestBuilder =
HttpRequest.newBuilder(uri)
.header("Content-Type", "application/json")
.timeout(Duration.ofSeconds(this.requestTimeoutSeconds));
if (isBasicAuthCredentialsSet()) {
requestBuilder.header("Authorization", getBasicAuthHeaderValue());
}
return requestBuilder;
}
return requestBuilder;
}
/**
* Get basic authentication header value.
*
* @return basic authentication header value (encoded credentials)
*/
private String getBasicAuthHeaderValue() {
String credentialsToEncode = this.basicAuth.getUsername() + ":" + this.basicAuth.getPassword();
return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes());
}
/**
* Get basic authentication header value.
*
* @return basic authentication header value (encoded credentials)
*/
private String getBasicAuthHeaderValue() {
String credentialsToEncode = this.basicAuth.getUsername() + ":" + this.basicAuth.getPassword();
return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes());
}
/**
* Check if Basic Auth credentials set.
*
* @return true when Basic Auth credentials set
*/
private boolean isBasicAuthCredentialsSet() {
return this.basicAuth != null;
}
/**
* Check if Basic Auth credentials set.
*
* @return true when Basic Auth credentials set
*/
private boolean isBasicAuthCredentialsSet() {
return this.basicAuth != null;
}
}

View File

@ -1,9 +1,5 @@
package io.github.amithkoujalgi.ollama4j.core.models.request;
import java.io.IOException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.fasterxml.jackson.core.JsonProcessingException;
import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
@ -13,8 +9,12 @@ import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRespo
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateStreamObserver;
import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{
import java.io.IOException;
public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
private static final Logger LOG = LoggerFactory.getLogger(OllamaGenerateEndpointCaller.class);
@ -31,24 +31,22 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{
@Override
protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) {
try {
OllamaGenerateResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class);
responseBuffer.append(ollamaResponseModel.getResponse());
if(streamObserver != null) {
streamObserver.notify(ollamaResponseModel);
}
return ollamaResponseModel.isDone();
} catch (JsonProcessingException e) {
LOG.error("Error parsing the Ollama chat response!",e);
return true;
}
try {
OllamaGenerateResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class);
responseBuffer.append(ollamaResponseModel.getResponse());
if (streamObserver != null) {
streamObserver.notify(ollamaResponseModel);
}
return ollamaResponseModel.isDone();
} catch (JsonProcessingException e) {
LOG.error("Error parsing the Ollama chat response!", e);
return true;
}
}
public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler)
throws OllamaBaseException, IOException, InterruptedException {
streamObserver = new OllamaGenerateStreamObserver(streamHandler);
return super.callSync(body);
throws OllamaBaseException, IOException, InterruptedException {
streamObserver = new OllamaGenerateStreamObserver(streamHandler);
return super.callSync(body);
}
}

View File

@ -0,0 +1,8 @@
package io.github.amithkoujalgi.ollama4j.core.tools;
import java.util.Map;
@FunctionalInterface
public interface DynamicFunction {
Object apply(Map<String, Object> arguments);
}

View File

@ -0,0 +1,139 @@
package io.github.amithkoujalgi.ollama4j.core.tools;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
import lombok.Builder;
import lombok.Data;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class MistralTools {
@Data
@Builder
public static class ToolSpecification {
private String functionName;
private String functionDesc;
private Map<String, PromptFuncDefinition.Property> props;
private DynamicFunction toolDefinition;
}
@Data
@JsonIgnoreProperties(ignoreUnknown = true)
public static class PromptFuncDefinition {
private String type;
private PromptFuncSpec function;
@Data
public static class PromptFuncSpec {
private String name;
private String description;
private Parameters parameters;
}
@Data
public static class Parameters {
private String type;
private Map<String, Property> properties;
private List<String> required;
}
@Data
@Builder
public static class Property {
private String type;
private String description;
@JsonProperty("enum")
@JsonInclude(JsonInclude.Include.NON_NULL)
private List<String> enumValues;
@JsonIgnore
private boolean required;
}
}
public static class PropsBuilder {
private final Map<String, PromptFuncDefinition.Property> props = new HashMap<>();
public PropsBuilder withProperty(String key, PromptFuncDefinition.Property property) {
props.put(key, property);
return this;
}
public Map<String, PromptFuncDefinition.Property> build() {
return props;
}
}
public static class PromptBuilder {
private final List<PromptFuncDefinition> tools = new ArrayList<>();
private String promptText;
public String build() throws JsonProcessingException {
return "[AVAILABLE_TOOLS] " + Utils.getObjectMapper().writeValueAsString(tools) + "[/AVAILABLE_TOOLS][INST] " + promptText + " [/INST]";
}
public PromptBuilder withPrompt(String prompt) throws JsonProcessingException {
promptText = prompt;
return this;
}
public PromptBuilder withToolSpecification(ToolSpecification spec) {
PromptFuncDefinition def = new PromptFuncDefinition();
def.setType("function");
PromptFuncDefinition.PromptFuncSpec functionDetail = new PromptFuncDefinition.PromptFuncSpec();
functionDetail.setName(spec.getFunctionName());
functionDetail.setDescription(spec.getFunctionDesc());
PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters();
parameters.setType("object");
parameters.setProperties(spec.getProps());
List<String> requiredValues = new ArrayList<>();
for (Map.Entry<String, PromptFuncDefinition.Property> p : spec.getProps().entrySet()) {
if (p.getValue().isRequired()) {
requiredValues.add(p.getKey());
}
}
parameters.setRequired(requiredValues);
functionDetail.setParameters(parameters);
def.setFunction(functionDetail);
tools.add(def);
return this;
}
//
// public PromptBuilder withToolSpecification(String functionName, String functionDesc, Map<String, PromptFuncDefinition.Property> props) {
// PromptFuncDefinition def = new PromptFuncDefinition();
// def.setType("function");
//
// PromptFuncDefinition.PromptFuncSpec functionDetail = new PromptFuncDefinition.PromptFuncSpec();
// functionDetail.setName(functionName);
// functionDetail.setDescription(functionDesc);
//
// PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters();
// parameters.setType("object");
// parameters.setProperties(props);
//
// List<String> requiredValues = new ArrayList<>();
// for (Map.Entry<String, PromptFuncDefinition.Property> p : props.entrySet()) {
// if (p.getValue().isRequired()) {
// requiredValues.add(p.getKey());
// }
// }
// parameters.setRequired(requiredValues);
// functionDetail.setParameters(parameters);
// def.setFunction(functionDetail);
//
// tools.add(def);
// return this;
// }
}
}

View File

@ -0,0 +1,16 @@
package io.github.amithkoujalgi.ollama4j.core.tools;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.Map;
@Data
@NoArgsConstructor
@AllArgsConstructor
public class OllamaToolsResult {
private OllamaResult modelResult;
private Map<ToolDef, Object> toolResults;
}

View File

@ -0,0 +1,18 @@
package io.github.amithkoujalgi.ollama4j.core.tools;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.Map;
@Data
@AllArgsConstructor
@NoArgsConstructor
public class ToolDef {
private String name;
private Map<String, Object> arguments;
}

View File

@ -0,0 +1,17 @@
package io.github.amithkoujalgi.ollama4j.core.tools;
import java.util.HashMap;
import java.util.Map;
public class ToolRegistry {
private static final Map<String, DynamicFunction> functionMap = new HashMap<>();
public static DynamicFunction getFunction(String name) {
return functionMap.get(name);
}
public static void addFunction(String name, DynamicFunction function) {
functionMap.put(name, function);
}
}

View File

@ -8,57 +8,81 @@ package io.github.amithkoujalgi.ollama4j.core.types;
*/
@SuppressWarnings("ALL")
public class OllamaModelType {
public static final String LLAMA2 = "llama2";
public static final String MISTRAL = "mistral";
public static final String LLAVA = "llava";
public static final String MIXTRAL = "mixtral";
public static final String STARLING_LM = "starling-lm";
public static final String NEURAL_CHAT = "neural-chat";
public static final String CODELLAMA = "codellama";
public static final String LLAMA2_UNCENSORED = "llama2-uncensored";
public static final String DOLPHIN_MIXTRAL = "dolphin-mixtral";
public static final String ORCA_MINI = "orca-mini";
public static final String VICUNA = "vicuna";
public static final String WIZARD_VICUNA_UNCENSORED = "wizard-vicuna-uncensored";
public static final String PHIND_CODELLAMA = "phind-codellama";
public static final String PHI = "phi";
public static final String ZEPHYR = "zephyr";
public static final String WIZARDCODER = "wizardcoder";
public static final String MISTRAL_OPENORCA = "mistral-openorca";
public static final String NOUS_HERMES = "nous-hermes";
public static final String DEEPSEEK_CODER = "deepseek-coder";
public static final String WIZARD_MATH = "wizard-math";
public static final String LLAMA2_CHINESE = "llama2-chinese";
public static final String FALCON = "falcon";
public static final String ORCA2 = "orca2";
public static final String STABLE_BELUGA = "stable-beluga";
public static final String CODEUP = "codeup";
public static final String EVERYTHINGLM = "everythinglm";
public static final String MEDLLAMA2 = "medllama2";
public static final String WIZARDLM_UNCENSORED = "wizardlm-uncensored";
public static final String STARCODER = "starcoder";
public static final String DOLPHIN22_MISTRAL = "dolphin2.2-mistral";
public static final String OPENCHAT = "openchat";
public static final String WIZARD_VICUNA = "wizard-vicuna";
public static final String OPENHERMES25_MISTRAL = "openhermes2.5-mistral";
public static final String OPEN_ORCA_PLATYPUS2 = "open-orca-platypus2";
public static final String YI = "yi";
public static final String YARN_MISTRAL = "yarn-mistral";
public static final String SAMANTHA_MISTRAL = "samantha-mistral";
public static final String SQLCODER = "sqlcoder";
public static final String YARN_LLAMA2 = "yarn-llama2";
public static final String MEDITRON = "meditron";
public static final String STABLELM_ZEPHYR = "stablelm-zephyr";
public static final String OPENHERMES2_MISTRAL = "openhermes2-mistral";
public static final String DEEPSEEK_LLM = "deepseek-llm";
public static final String MISTRALLITE = "mistrallite";
public static final String DOLPHIN21_MISTRAL = "dolphin2.1-mistral";
public static final String WIZARDLM = "wizardlm";
public static final String CODEBOOGA = "codebooga";
public static final String MAGICODER = "magicoder";
public static final String GOLIATH = "goliath";
public static final String NEXUSRAVEN = "nexusraven";
public static final String ALFRED = "alfred";
public static final String XWINLM = "xwinlm";
public static final String BAKLLAVA = "bakllava";
public static final String GEMMA = "gemma";
public static final String GEMMA2 = "gemma2";
public static final String LLAMA2 = "llama2";
public static final String LLAMA3 = "llama3";
public static final String MISTRAL = "mistral";
public static final String MIXTRAL = "mixtral";
public static final String LLAVA = "llava";
public static final String LLAVA_PHI3 = "llava-phi3";
public static final String NEURAL_CHAT = "neural-chat";
public static final String CODELLAMA = "codellama";
public static final String DOLPHIN_MIXTRAL = "dolphin-mixtral";
public static final String MISTRAL_OPENORCA = "mistral-openorca";
public static final String LLAMA2_UNCENSORED = "llama2-uncensored";
public static final String PHI = "phi";
public static final String PHI3 = "phi3";
public static final String ORCA_MINI = "orca-mini";
public static final String DEEPSEEK_CODER = "deepseek-coder";
public static final String DOLPHIN_MISTRAL = "dolphin-mistral";
public static final String VICUNA = "vicuna";
public static final String WIZARD_VICUNA_UNCENSORED = "wizard-vicuna-uncensored";
public static final String ZEPHYR = "zephyr";
public static final String OPENHERMES = "openhermes";
public static final String QWEN = "qwen";
public static final String QWEN2 = "qwen2";
public static final String WIZARDCODER = "wizardcoder";
public static final String LLAMA2_CHINESE = "llama2-chinese";
public static final String TINYLLAMA = "tinyllama";
public static final String PHIND_CODELLAMA = "phind-codellama";
public static final String OPENCHAT = "openchat";
public static final String ORCA2 = "orca2";
public static final String FALCON = "falcon";
public static final String WIZARD_MATH = "wizard-math";
public static final String TINYDOLPHIN = "tinydolphin";
public static final String NOUS_HERMES = "nous-hermes";
public static final String YI = "yi";
public static final String DOLPHIN_PHI = "dolphin-phi";
public static final String STARLING_LM = "starling-lm";
public static final String STARCODER = "starcoder";
public static final String CODEUP = "codeup";
public static final String MEDLLAMA2 = "medllama2";
public static final String STABLE_CODE = "stable-code";
public static final String WIZARDLM_UNCENSORED = "wizardlm-uncensored";
public static final String BAKLLAVA = "bakllava";
public static final String EVERYTHINGLM = "everythinglm";
public static final String SOLAR = "solar";
public static final String STABLE_BELUGA = "stable-beluga";
public static final String SQLCODER = "sqlcoder";
public static final String YARN_MISTRAL = "yarn-mistral";
public static final String NOUS_HERMES2_MIXTRAL = "nous-hermes2-mixtral";
public static final String SAMANTHA_MISTRAL = "samantha-mistral";
public static final String STABLELM_ZEPHYR = "stablelm-zephyr";
public static final String MEDITRON = "meditron";
public static final String WIZARD_VICUNA = "wizard-vicuna";
public static final String STABLELM2 = "stablelm2";
public static final String MAGICODER = "magicoder";
public static final String YARN_LLAMA2 = "yarn-llama2";
public static final String NOUS_HERMES2 = "nous-hermes2";
public static final String DEEPSEEK_LLM = "deepseek-llm";
public static final String LLAMA_PRO = "llama-pro";
public static final String OPEN_ORCA_PLATYPUS2 = "open-orca-platypus2";
public static final String CODEBOOGA = "codebooga";
public static final String MISTRALLITE = "mistrallite";
public static final String NEXUSRAVEN = "nexusraven";
public static final String GOLIATH = "goliath";
public static final String NOMIC_EMBED_TEXT = "nomic-embed-text";
public static final String NOTUX = "notux";
public static final String ALFRED = "alfred";
public static final String MEGADOLPHIN = "megadolphin";
public static final String WIZARDLM = "wizardlm";
public static final String XWINLM = "xwinlm";
public static final String NOTUS = "notus";
public static final String DUCKDB_NSQL = "duckdb-nsql";
public static final String ALL_MINILM = "all-minilm";
public static final String CODESTRAL = "codestral";
}

View File

@ -8,10 +8,18 @@ import java.net.URISyntaxException;
import java.net.URL;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
public class Utils {
private static ObjectMapper objectMapper;
public static ObjectMapper getObjectMapper() {
return new ObjectMapper();
if(objectMapper == null) {
objectMapper = new ObjectMapper();
objectMapper.registerModule(new JavaTimeModule());
}
return objectMapper;
}
public static byte[] loadImageBytesFromUrl(String imageUrl)

View File

@ -1,7 +1,5 @@
package io.github.amithkoujalgi.ollama4j.integrationtests;
import static org.junit.jupiter.api.Assertions.*;
import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.ModelDetail;
@ -10,7 +8,16 @@ import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessageRole;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResult;
import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestBuilder;
import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestModel;
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
import lombok.Data;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Order;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
@ -20,355 +27,369 @@ import java.net.http.HttpConnectTimeoutException;
import java.util.List;
import java.util.Objects;
import java.util.Properties;
import lombok.Data;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Order;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import static org.junit.jupiter.api.Assertions.*;
class TestRealAPIs {
private static final Logger LOG = LoggerFactory.getLogger(TestRealAPIs.class);
private static final Logger LOG = LoggerFactory.getLogger(TestRealAPIs.class);
OllamaAPI ollamaAPI;
Config config;
OllamaAPI ollamaAPI;
Config config;
private File getImageFileFromClasspath(String fileName) {
ClassLoader classLoader = getClass().getClassLoader();
return new File(Objects.requireNonNull(classLoader.getResource(fileName)).getFile());
}
@BeforeEach
void setUp() {
config = new Config();
ollamaAPI = new OllamaAPI(config.getOllamaURL());
ollamaAPI.setRequestTimeoutSeconds(config.getRequestTimeoutSeconds());
}
@Test
@Order(1)
void testWrongEndpoint() {
OllamaAPI ollamaAPI = new OllamaAPI("http://wrong-host:11434");
assertThrows(ConnectException.class, ollamaAPI::listModels);
}
@Test
@Order(1)
void testEndpointReachability() {
try {
assertNotNull(ollamaAPI.listModels());
} catch (HttpConnectTimeoutException e) {
fail(e.getMessage());
} catch (Exception e) {
throw new RuntimeException(e);
private File getImageFileFromClasspath(String fileName) {
ClassLoader classLoader = getClass().getClassLoader();
return new File(Objects.requireNonNull(classLoader.getResource(fileName)).getFile());
}
}
@Test
@Order(2)
void testListModels() {
testEndpointReachability();
try {
assertNotNull(ollamaAPI.listModels());
ollamaAPI.listModels().forEach(System.out::println);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e);
@BeforeEach
void setUp() {
config = new Config();
ollamaAPI = new OllamaAPI(config.getOllamaURL());
ollamaAPI.setRequestTimeoutSeconds(config.getRequestTimeoutSeconds());
}
}
@Test
@Order(2)
void testPullModel() {
testEndpointReachability();
try {
ollamaAPI.pullModel(config.getModel());
boolean found =
ollamaAPI.listModels().stream()
.anyMatch(model -> model.getModel().equalsIgnoreCase(config.getModel()));
assertTrue(found);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e);
@Test
@Order(1)
void testWrongEndpoint() {
OllamaAPI ollamaAPI = new OllamaAPI("http://wrong-host:11434");
assertThrows(ConnectException.class, ollamaAPI::listModels);
}
}
@Test
@Order(3)
void testListDtails() {
testEndpointReachability();
try {
ModelDetail modelDetails = ollamaAPI.getModelDetails(config.getModel());
assertNotNull(modelDetails);
System.out.println(modelDetails);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e);
@Test
@Order(1)
void testEndpointReachability() {
try {
assertNotNull(ollamaAPI.listModels());
} catch (HttpConnectTimeoutException e) {
fail(e.getMessage());
} catch (Exception e) {
fail(e);
}
}
}
@Test
@Order(3)
void testAskModelWithDefaultOptions() {
testEndpointReachability();
try {
OllamaResult result =
ollamaAPI.generate(
config.getModel(),
"What is the capital of France? And what's France's connection with Mona Lisa?",
new OptionsBuilder().build());
assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
} catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e);
@Test
@Order(2)
void testListModels() {
testEndpointReachability();
try {
assertNotNull(ollamaAPI.listModels());
ollamaAPI.listModels().forEach(System.out::println);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
fail(e);
}
}
}
@Test
@Order(3)
void testAskModelWithDefaultOptionsStreamed() {
testEndpointReachability();
try {
StringBuffer sb = new StringBuffer("");
OllamaResult result = ollamaAPI.generate(config.getModel(),
"What is the capital of France? And what's France's connection with Mona Lisa?",
new OptionsBuilder().build(), (s) -> {
LOG.info(s);
String substring = s.substring(sb.toString().length(), s.length());
LOG.info(substring);
sb.append(substring);
});
assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
assertEquals(sb.toString().trim(), result.getResponse().trim());
} catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e);
@Test
@Order(2)
void testPullModel() {
testEndpointReachability();
try {
ollamaAPI.pullModel(config.getModel());
boolean found =
ollamaAPI.listModels().stream()
.anyMatch(model -> model.getModel().equalsIgnoreCase(config.getModel()));
assertTrue(found);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
fail(e);
}
}
}
@Test
@Order(3)
void testAskModelWithOptions() {
testEndpointReachability();
try {
OllamaResult result =
ollamaAPI.generate(
config.getModel(),
"What is the capital of France? And what's France's connection with Mona Lisa?",
new OptionsBuilder().setTemperature(0.9f).build());
assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
} catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e);
@Test
@Order(3)
void testListDtails() {
testEndpointReachability();
try {
ModelDetail modelDetails = ollamaAPI.getModelDetails(config.getModel());
assertNotNull(modelDetails);
System.out.println(modelDetails);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
fail(e);
}
}
}
@Test
@Order(3)
void testChat() {
testEndpointReachability();
try {
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What is the capital of France?")
.withMessage(OllamaChatMessageRole.ASSISTANT, "Should be Paris!")
.withMessage(OllamaChatMessageRole.USER,"And what is the second larges city?")
.build();
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
assertNotNull(chatResult);
assertFalse(chatResult.getResponse().isBlank());
assertEquals(4,chatResult.getChatHistory().size());
} catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e);
@Test
@Order(3)
void testAskModelWithDefaultOptions() {
testEndpointReachability();
try {
OllamaResult result =
ollamaAPI.generate(
config.getModel(),
"What is the capital of France? And what's France's connection with Mona Lisa?",
false,
new OptionsBuilder().build());
assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
}
}
}
@Test
@Order(3)
void testChatWithSystemPrompt() {
testEndpointReachability();
try {
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM,
"You are a silent bot that only says 'NI'. Do not say anything else under any circumstances!")
.withMessage(OllamaChatMessageRole.USER,
"What is the capital of France? And what's France's connection with Mona Lisa?")
.build();
@Test
@Order(3)
void testAskModelWithDefaultOptionsStreamed() {
testEndpointReachability();
try {
StringBuffer sb = new StringBuffer("");
OllamaResult result = ollamaAPI.generate(config.getModel(),
"What is the capital of France? And what's France's connection with Mona Lisa?",
false,
new OptionsBuilder().build(), (s) -> {
LOG.info(s);
String substring = s.substring(sb.toString().length(), s.length());
LOG.info(substring);
sb.append(substring);
});
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
assertNotNull(chatResult);
assertFalse(chatResult.getResponse().isBlank());
assertTrue(chatResult.getResponse().startsWith("NI"));
assertEquals(3, chatResult.getChatHistory().size());
} catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e);
assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
assertEquals(sb.toString().trim(), result.getResponse().trim());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
}
}
}
@Test
@Order(3)
void testChatWithStream() {
testEndpointReachability();
try {
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER,
"What is the capital of France? And what's France's connection with Mona Lisa?")
.build();
StringBuffer sb = new StringBuffer("");
OllamaChatResult chatResult = ollamaAPI.chat(requestModel,(s) -> {
LOG.info(s);
String substring = s.substring(sb.toString().length(), s.length());
LOG.info(substring);
sb.append(substring);
});
assertNotNull(chatResult);
assertEquals(sb.toString().trim(), chatResult.getResponse().trim());
} catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e);
@Test
@Order(3)
void testAskModelWithOptions() {
testEndpointReachability();
try {
OllamaResult result =
ollamaAPI.generate(
config.getModel(),
"What is the capital of France? And what's France's connection with Mona Lisa?",
true,
new OptionsBuilder().setTemperature(0.9f).build());
assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
}
}
}
@Test
@Order(3)
void testChatWithImageFromFileWithHistoryRecognition() {
testEndpointReachability();
try {
OllamaChatRequestBuilder builder =
OllamaChatRequestBuilder.getInstance(config.getImageModel());
OllamaChatRequestModel requestModel =
builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
List.of(getImageFileFromClasspath("dog-on-a-boat.jpg"))).build();
@Test
@Order(3)
void testChat() {
testEndpointReachability();
try {
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What is the capital of France?")
.withMessage(OllamaChatMessageRole.ASSISTANT, "Should be Paris!")
.withMessage(OllamaChatMessageRole.USER, "And what is the second larges city?")
.build();
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
assertNotNull(chatResult);
assertNotNull(chatResult.getResponse());
builder.reset();
requestModel =
builder.withMessages(chatResult.getChatHistory())
.withMessage(OllamaChatMessageRole.USER, "What's the dogs breed?").build();
chatResult = ollamaAPI.chat(requestModel);
assertNotNull(chatResult);
assertNotNull(chatResult.getResponse());
} catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e);
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
assertNotNull(chatResult);
assertFalse(chatResult.getResponse().isBlank());
assertEquals(4, chatResult.getChatHistory().size());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
}
}
}
@Test
@Order(3)
void testChatWithImageFromURL() {
testEndpointReachability();
try {
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getImageModel());
OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
"https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg")
.build();
@Test
@Order(3)
void testChatWithSystemPrompt() {
testEndpointReachability();
try {
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM,
"You are a silent bot that only says 'NI'. Do not say anything else under any circumstances!")
.withMessage(OllamaChatMessageRole.USER,
"What is the capital of France? And what's France's connection with Mona Lisa?")
.build();
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
assertNotNull(chatResult);
} catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e);
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
assertNotNull(chatResult);
assertFalse(chatResult.getResponse().isBlank());
assertTrue(chatResult.getResponse().startsWith("NI"));
assertEquals(3, chatResult.getChatHistory().size());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
}
}
}
@Test
@Order(3)
void testAskModelWithOptionsAndImageFiles() {
testEndpointReachability();
File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg");
try {
OllamaResult result =
ollamaAPI.generateWithImageFiles(
config.getImageModel(),
"What is in this image?",
List.of(imageFile),
new OptionsBuilder().build());
assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
} catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e);
@Test
@Order(3)
void testChatWithStream() {
testEndpointReachability();
try {
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER,
"What is the capital of France? And what's France's connection with Mona Lisa?")
.build();
StringBuffer sb = new StringBuffer("");
OllamaChatResult chatResult = ollamaAPI.chat(requestModel, (s) -> {
LOG.info(s);
String substring = s.substring(sb.toString().length(), s.length());
LOG.info(substring);
sb.append(substring);
});
assertNotNull(chatResult);
assertEquals(sb.toString().trim(), chatResult.getResponse().trim());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
}
}
}
@Test
@Order(3)
void testAskModelWithOptionsAndImageFilesStreamed() {
testEndpointReachability();
File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg");
try {
StringBuffer sb = new StringBuffer("");
@Test
@Order(3)
void testChatWithImageFromFileWithHistoryRecognition() {
testEndpointReachability();
try {
OllamaChatRequestBuilder builder =
OllamaChatRequestBuilder.getInstance(config.getImageModel());
OllamaChatRequestModel requestModel =
builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
List.of(getImageFileFromClasspath("dog-on-a-boat.jpg"))).build();
OllamaResult result = ollamaAPI.generateWithImageFiles(config.getImageModel(),
"What is in this image?", List.of(imageFile), new OptionsBuilder().build(), (s) -> {
LOG.info(s);
String substring = s.substring(sb.toString().length(), s.length());
LOG.info(substring);
sb.append(substring);
});
assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
assertEquals(sb.toString().trim(), result.getResponse().trim());
} catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e);
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
assertNotNull(chatResult);
assertNotNull(chatResult.getResponse());
builder.reset();
requestModel =
builder.withMessages(chatResult.getChatHistory())
.withMessage(OllamaChatMessageRole.USER, "What's the dogs breed?").build();
chatResult = ollamaAPI.chat(requestModel);
assertNotNull(chatResult);
assertNotNull(chatResult.getResponse());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
}
}
}
@Test
@Order(3)
void testAskModelWithOptionsAndImageURLs() {
testEndpointReachability();
try {
OllamaResult result =
ollamaAPI.generateWithImageURLs(
config.getImageModel(),
"What is in this image?",
List.of(
"https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg"),
new OptionsBuilder().build());
assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e);
@Test
@Order(3)
void testChatWithImageFromURL() {
testEndpointReachability();
try {
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getImageModel());
OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
"https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg")
.build();
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
assertNotNull(chatResult);
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
}
}
@Test
@Order(3)
void testAskModelWithOptionsAndImageFiles() {
testEndpointReachability();
File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg");
try {
OllamaResult result =
ollamaAPI.generateWithImageFiles(
config.getImageModel(),
"What is in this image?",
List.of(imageFile),
new OptionsBuilder().build());
assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
}
}
@Test
@Order(3)
void testAskModelWithOptionsAndImageFilesStreamed() {
testEndpointReachability();
File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg");
try {
StringBuffer sb = new StringBuffer("");
OllamaResult result = ollamaAPI.generateWithImageFiles(config.getImageModel(),
"What is in this image?", List.of(imageFile), new OptionsBuilder().build(), (s) -> {
LOG.info(s);
String substring = s.substring(sb.toString().length(), s.length());
LOG.info(substring);
sb.append(substring);
});
assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
assertEquals(sb.toString().trim(), result.getResponse().trim());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
}
}
@Test
@Order(3)
void testAskModelWithOptionsAndImageURLs() {
testEndpointReachability();
try {
OllamaResult result =
ollamaAPI.generateWithImageURLs(
config.getImageModel(),
"What is in this image?",
List.of(
"https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg"),
new OptionsBuilder().build());
assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
fail(e);
}
}
@Test
@Order(3)
public void testEmbedding() {
testEndpointReachability();
try {
OllamaEmbeddingsRequestModel request = OllamaEmbeddingsRequestBuilder
.getInstance(config.getModel(), "What is the capital of France?").build();
List<Double> embeddings = ollamaAPI.generateEmbeddings(request);
assertNotNull(embeddings);
assertFalse(embeddings.isEmpty());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
}
}
}
}
@Data
class Config {
private String ollamaURL;
private String model;
private String imageModel;
private int requestTimeoutSeconds;
private String ollamaURL;
private String model;
private String imageModel;
private int requestTimeoutSeconds;
public Config() {
Properties properties = new Properties();
try (InputStream input =
getClass().getClassLoader().getResourceAsStream("test-config.properties")) {
if (input == null) {
throw new RuntimeException("Sorry, unable to find test-config.properties");
}
properties.load(input);
this.ollamaURL = properties.getProperty("ollama.url");
this.model = properties.getProperty("ollama.model");
this.imageModel = properties.getProperty("ollama.model.image");
this.requestTimeoutSeconds =
Integer.parseInt(properties.getProperty("ollama.request-timeout-seconds"));
} catch (IOException e) {
throw new RuntimeException("Error loading properties", e);
public Config() {
Properties properties = new Properties();
try (InputStream input =
getClass().getClassLoader().getResourceAsStream("test-config.properties")) {
if (input == null) {
throw new RuntimeException("Sorry, unable to find test-config.properties");
}
properties.load(input);
this.ollamaURL = properties.getProperty("ollama.url");
this.model = properties.getProperty("ollama.model");
this.imageModel = properties.getProperty("ollama.model.image");
this.requestTimeoutSeconds =
Integer.parseInt(properties.getProperty("ollama.request-timeout-seconds"));
} catch (IOException e) {
throw new RuntimeException("Error loading properties", e);
}
}
}
}

View File

@ -1,7 +1,5 @@
package io.github.amithkoujalgi.ollama4j.unittests;
import static org.mockito.Mockito.*;
import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.ModelDetail;
@ -9,155 +7,158 @@ import io.github.amithkoujalgi.ollama4j.core.models.OllamaAsyncResultCallback;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
import io.github.amithkoujalgi.ollama4j.core.types.OllamaModelType;
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import java.io.IOException;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.Collections;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import static org.mockito.Mockito.*;
class TestMockedAPIs {
@Test
void testPullModel() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
try {
doNothing().when(ollamaAPI).pullModel(model);
ollamaAPI.pullModel(model);
verify(ollamaAPI, times(1)).pullModel(model);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e);
@Test
void testPullModel() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
try {
doNothing().when(ollamaAPI).pullModel(model);
ollamaAPI.pullModel(model);
verify(ollamaAPI, times(1)).pullModel(model);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e);
}
}
}
@Test
void testListModels() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
try {
when(ollamaAPI.listModels()).thenReturn(new ArrayList<>());
ollamaAPI.listModels();
verify(ollamaAPI, times(1)).listModels();
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e);
@Test
void testListModels() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
try {
when(ollamaAPI.listModels()).thenReturn(new ArrayList<>());
ollamaAPI.listModels();
verify(ollamaAPI, times(1)).listModels();
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e);
}
}
}
@Test
void testCreateModel() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
String modelFilePath = "FROM llama2\nSYSTEM You are mario from Super Mario Bros.";
try {
doNothing().when(ollamaAPI).createModelWithModelFileContents(model, modelFilePath);
ollamaAPI.createModelWithModelFileContents(model, modelFilePath);
verify(ollamaAPI, times(1)).createModelWithModelFileContents(model, modelFilePath);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e);
@Test
void testCreateModel() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
String modelFilePath = "FROM llama2\nSYSTEM You are mario from Super Mario Bros.";
try {
doNothing().when(ollamaAPI).createModelWithModelFileContents(model, modelFilePath);
ollamaAPI.createModelWithModelFileContents(model, modelFilePath);
verify(ollamaAPI, times(1)).createModelWithModelFileContents(model, modelFilePath);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e);
}
}
}
@Test
void testDeleteModel() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
try {
doNothing().when(ollamaAPI).deleteModel(model, true);
ollamaAPI.deleteModel(model, true);
verify(ollamaAPI, times(1)).deleteModel(model, true);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e);
@Test
void testDeleteModel() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
try {
doNothing().when(ollamaAPI).deleteModel(model, true);
ollamaAPI.deleteModel(model, true);
verify(ollamaAPI, times(1)).deleteModel(model, true);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e);
}
}
}
@Test
void testGetModelDetails() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
try {
when(ollamaAPI.getModelDetails(model)).thenReturn(new ModelDetail());
ollamaAPI.getModelDetails(model);
verify(ollamaAPI, times(1)).getModelDetails(model);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e);
@Test
void testGetModelDetails() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
try {
when(ollamaAPI.getModelDetails(model)).thenReturn(new ModelDetail());
ollamaAPI.getModelDetails(model);
verify(ollamaAPI, times(1)).getModelDetails(model);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e);
}
}
}
@Test
void testGenerateEmbeddings() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text";
try {
when(ollamaAPI.generateEmbeddings(model, prompt)).thenReturn(new ArrayList<>());
ollamaAPI.generateEmbeddings(model, prompt);
verify(ollamaAPI, times(1)).generateEmbeddings(model, prompt);
} catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e);
@Test
void testGenerateEmbeddings() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text";
try {
when(ollamaAPI.generateEmbeddings(model, prompt)).thenReturn(new ArrayList<>());
ollamaAPI.generateEmbeddings(model, prompt);
verify(ollamaAPI, times(1)).generateEmbeddings(model, prompt);
} catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e);
}
}
}
@Test
void testAsk() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text";
OptionsBuilder optionsBuilder = new OptionsBuilder();
try {
when(ollamaAPI.generate(model, prompt, optionsBuilder.build()))
.thenReturn(new OllamaResult("", 0, 200));
ollamaAPI.generate(model, prompt, optionsBuilder.build());
verify(ollamaAPI, times(1)).generate(model, prompt, optionsBuilder.build());
} catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e);
@Test
void testAsk() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text";
OptionsBuilder optionsBuilder = new OptionsBuilder();
try {
when(ollamaAPI.generate(model, prompt, false, optionsBuilder.build()))
.thenReturn(new OllamaResult("", 0, 200));
ollamaAPI.generate(model, prompt, false, optionsBuilder.build());
verify(ollamaAPI, times(1)).generate(model, prompt, false, optionsBuilder.build());
} catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e);
}
}
}
@Test
void testAskWithImageFiles() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text";
try {
when(ollamaAPI.generateWithImageFiles(
model, prompt, Collections.emptyList(), new OptionsBuilder().build()))
.thenReturn(new OllamaResult("", 0, 200));
ollamaAPI.generateWithImageFiles(
model, prompt, Collections.emptyList(), new OptionsBuilder().build());
verify(ollamaAPI, times(1))
.generateWithImageFiles(
model, prompt, Collections.emptyList(), new OptionsBuilder().build());
} catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e);
@Test
void testAskWithImageFiles() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text";
try {
when(ollamaAPI.generateWithImageFiles(
model, prompt, Collections.emptyList(), new OptionsBuilder().build()))
.thenReturn(new OllamaResult("", 0, 200));
ollamaAPI.generateWithImageFiles(
model, prompt, Collections.emptyList(), new OptionsBuilder().build());
verify(ollamaAPI, times(1))
.generateWithImageFiles(
model, prompt, Collections.emptyList(), new OptionsBuilder().build());
} catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e);
}
}
}
@Test
void testAskWithImageURLs() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text";
try {
when(ollamaAPI.generateWithImageURLs(
model, prompt, Collections.emptyList(), new OptionsBuilder().build()))
.thenReturn(new OllamaResult("", 0, 200));
ollamaAPI.generateWithImageURLs(
model, prompt, Collections.emptyList(), new OptionsBuilder().build());
verify(ollamaAPI, times(1))
.generateWithImageURLs(
model, prompt, Collections.emptyList(), new OptionsBuilder().build());
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e);
@Test
void testAskWithImageURLs() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text";
try {
when(ollamaAPI.generateWithImageURLs(
model, prompt, Collections.emptyList(), new OptionsBuilder().build()))
.thenReturn(new OllamaResult("", 0, 200));
ollamaAPI.generateWithImageURLs(
model, prompt, Collections.emptyList(), new OptionsBuilder().build());
verify(ollamaAPI, times(1))
.generateWithImageURLs(
model, prompt, Collections.emptyList(), new OptionsBuilder().build());
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e);
}
}
}
@Test
void testAskAsync() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text";
when(ollamaAPI.generateAsync(model, prompt))
.thenReturn(new OllamaAsyncResultCallback(null, null, 3));
ollamaAPI.generateAsync(model, prompt);
verify(ollamaAPI, times(1)).generateAsync(model, prompt);
}
@Test
void testAskAsync() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text";
when(ollamaAPI.generateAsync(model, prompt, false))
.thenReturn(new OllamaAsyncResultCallback(null, null, 3));
ollamaAPI.generateAsync(model, prompt, false);
verify(ollamaAPI, times(1)).generateAsync(model, prompt, false);
}
}

View File

@ -0,0 +1,35 @@
package io.github.amithkoujalgi.ollama4j.unittests.jackson;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
public abstract class AbstractSerializationTest<T> {
protected ObjectMapper mapper = Utils.getObjectMapper();
protected String serialize(T obj) {
try {
return mapper.writeValueAsString(obj);
} catch (JsonProcessingException e) {
fail("Could not serialize request!", e);
return null;
}
}
protected T deserialize(String jsonObject, Class<T> deserializationClass) {
try {
return mapper.readValue(jsonObject, deserializationClass);
} catch (JsonProcessingException e) {
fail("Could not deserialize jsonObject!", e);
return null;
}
}
protected void assertEqualsAfterUnmarshalling(T unmarshalledObject,
T req) {
assertEquals(req, unmarshalledObject);
}
}

View File

@ -1,7 +1,6 @@
package io.github.amithkoujalgi.ollama4j.unittests.jackson;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
import java.io.File;
import java.util.List;
@ -10,21 +9,15 @@ import org.json.JSONObject;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessageRole;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel;
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
public class TestChatRequestSerialization {
public class TestChatRequestSerialization extends AbstractSerializationTest<OllamaChatRequestModel> {
private OllamaChatRequestBuilder builder;
private ObjectMapper mapper = Utils.getObjectMapper();
@BeforeEach
public void init() {
builder = OllamaChatRequestBuilder.getInstance("DummyModel");
@ -32,10 +25,9 @@ public class TestChatRequestSerialization {
@Test
public void testRequestOnlyMandatoryFields() {
OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt",
List.of(new File("src/test/resources/dog-on-a-boat.jpg"))).build();
String jsonRequest = serializeRequest(req);
assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req);
OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt").build();
String jsonRequest = serialize(req);
assertEqualsAfterUnmarshalling(deserialize(jsonRequest,OllamaChatRequestModel.class), req);
}
@Test
@ -43,28 +35,43 @@ public class TestChatRequestSerialization {
OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.SYSTEM, "System prompt")
.withMessage(OllamaChatMessageRole.USER, "Some prompt")
.build();
String jsonRequest = serializeRequest(req);
assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req);
String jsonRequest = serialize(req);
assertEqualsAfterUnmarshalling(deserialize(jsonRequest,OllamaChatRequestModel.class), req);
}
@Test
public void testRequestWithMessageAndImage() {
OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt",
List.of(new File("src/test/resources/dog-on-a-boat.jpg"))).build();
String jsonRequest = serializeRequest(req);
assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req);
String jsonRequest = serialize(req);
assertEqualsAfterUnmarshalling(deserialize(jsonRequest,OllamaChatRequestModel.class), req);
}
@Test
public void testRequestWithOptions() {
OptionsBuilder b = new OptionsBuilder();
OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt")
.withOptions(b.setMirostat(1).build()).build();
.withOptions(b.setMirostat(1).build())
.withOptions(b.setTemperature(1L).build())
.withOptions(b.setMirostatEta(1L).build())
.withOptions(b.setMirostatTau(1L).build())
.withOptions(b.setNumGpu(1).build())
.withOptions(b.setSeed(1).build())
.withOptions(b.setTopK(1).build())
.withOptions(b.setTopP(1).build())
.build();
String jsonRequest = serializeRequest(req);
OllamaChatRequestModel deserializeRequest = deserializeRequest(jsonRequest);
String jsonRequest = serialize(req);
OllamaChatRequestModel deserializeRequest = deserialize(jsonRequest, OllamaChatRequestModel.class);
assertEqualsAfterUnmarshalling(deserializeRequest, req);
assertEquals(1, deserializeRequest.getOptions().get("mirostat"));
assertEquals(1.0, deserializeRequest.getOptions().get("temperature"));
assertEquals(1.0, deserializeRequest.getOptions().get("mirostat_eta"));
assertEquals(1.0, deserializeRequest.getOptions().get("mirostat_tau"));
assertEquals(1, deserializeRequest.getOptions().get("num_gpu"));
assertEquals(1, deserializeRequest.getOptions().get("seed"));
assertEquals(1, deserializeRequest.getOptions().get("top_k"));
assertEquals(1.0, deserializeRequest.getOptions().get("top_p"));
}
@Test
@ -72,7 +79,7 @@ public class TestChatRequestSerialization {
OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt")
.withGetJsonResponse().build();
String jsonRequest = serializeRequest(req);
String jsonRequest = serialize(req);
// no jackson deserialization as format property is not boolean ==> omit as deserialization
// of request is never used in real code anyways
JSONObject jsonObject = new JSONObject(jsonRequest);
@ -80,27 +87,27 @@ public class TestChatRequestSerialization {
assertEquals("json", requestFormatProperty);
}
private String serializeRequest(OllamaChatRequestModel req) {
try {
return mapper.writeValueAsString(req);
} catch (JsonProcessingException e) {
fail("Could not serialize request!", e);
return null;
}
@Test
public void testWithTemplate() {
OllamaChatRequestModel req = builder.withTemplate("System Template")
.build();
String jsonRequest = serialize(req);
assertEqualsAfterUnmarshalling(deserialize(jsonRequest, OllamaChatRequestModel.class), req);
}
private OllamaChatRequestModel deserializeRequest(String jsonRequest) {
try {
return mapper.readValue(jsonRequest, OllamaChatRequestModel.class);
} catch (JsonProcessingException e) {
fail("Could not deserialize jsonRequest!", e);
return null;
}
@Test
public void testWithStreaming() {
OllamaChatRequestModel req = builder.withStreaming().build();
String jsonRequest = serialize(req);
assertEquals(deserialize(jsonRequest, OllamaChatRequestModel.class).isStream(), true);
}
private void assertEqualsAfterUnmarshalling(OllamaChatRequestModel unmarshalledRequest,
OllamaChatRequestModel req) {
assertEquals(req, unmarshalledRequest);
@Test
public void testWithKeepAlive() {
String expectedKeepAlive = "5m";
OllamaChatRequestModel req = builder.withKeepAlive(expectedKeepAlive)
.build();
String jsonRequest = serialize(req);
assertEquals(deserialize(jsonRequest, OllamaChatRequestModel.class).getKeepAlive(), expectedKeepAlive);
}
}

View File

@ -0,0 +1,37 @@
package io.github.amithkoujalgi.ollama4j.unittests.jackson;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestBuilder;
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
public class TestEmbeddingsRequestSerialization extends AbstractSerializationTest<OllamaEmbeddingsRequestModel> {
private OllamaEmbeddingsRequestBuilder builder;
@BeforeEach
public void init() {
builder = OllamaEmbeddingsRequestBuilder.getInstance("DummyModel","DummyPrompt");
}
@Test
public void testRequestOnlyMandatoryFields() {
OllamaEmbeddingsRequestModel req = builder.build();
String jsonRequest = serialize(req);
assertEqualsAfterUnmarshalling(deserialize(jsonRequest,OllamaEmbeddingsRequestModel.class), req);
}
@Test
public void testRequestWithOptions() {
OptionsBuilder b = new OptionsBuilder();
OllamaEmbeddingsRequestModel req = builder
.withOptions(b.setMirostat(1).build()).build();
String jsonRequest = serialize(req);
OllamaEmbeddingsRequestModel deserializeRequest = deserialize(jsonRequest,OllamaEmbeddingsRequestModel.class);
assertEqualsAfterUnmarshalling(deserializeRequest, req);
assertEquals(1, deserializeRequest.getOptions().get("mirostat"));
}
}

View File

@ -1,26 +1,20 @@
package io.github.amithkoujalgi.ollama4j.unittests.jackson;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
import org.json.JSONObject;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestBuilder;
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel;
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
public class TestGenerateRequestSerialization {
public class TestGenerateRequestSerialization extends AbstractSerializationTest<OllamaGenerateRequestModel> {
private OllamaGenerateRequestBuilder builder;
private ObjectMapper mapper = Utils.getObjectMapper();
@BeforeEach
public void init() {
builder = OllamaGenerateRequestBuilder.getInstance("DummyModel");
@ -30,8 +24,8 @@ public class TestGenerateRequestSerialization {
public void testRequestOnlyMandatoryFields() {
OllamaGenerateRequestModel req = builder.withPrompt("Some prompt").build();
String jsonRequest = serializeRequest(req);
assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req);
String jsonRequest = serialize(req);
assertEqualsAfterUnmarshalling(deserialize(jsonRequest, OllamaGenerateRequestModel.class), req);
}
@Test
@ -40,8 +34,8 @@ public class TestGenerateRequestSerialization {
OllamaGenerateRequestModel req =
builder.withPrompt("Some prompt").withOptions(b.setMirostat(1).build()).build();
String jsonRequest = serializeRequest(req);
OllamaGenerateRequestModel deserializeRequest = deserializeRequest(jsonRequest);
String jsonRequest = serialize(req);
OllamaGenerateRequestModel deserializeRequest = deserialize(jsonRequest, OllamaGenerateRequestModel.class);
assertEqualsAfterUnmarshalling(deserializeRequest, req);
assertEquals(1, deserializeRequest.getOptions().get("mirostat"));
}
@ -51,7 +45,7 @@ public class TestGenerateRequestSerialization {
OllamaGenerateRequestModel req =
builder.withPrompt("Some prompt").withGetJsonResponse().build();
String jsonRequest = serializeRequest(req);
String jsonRequest = serialize(req);
// no jackson deserialization as format property is not boolean ==> omit as deserialization
// of request is never used in real code anyways
JSONObject jsonObject = new JSONObject(jsonRequest);
@ -59,27 +53,4 @@ public class TestGenerateRequestSerialization {
assertEquals("json", requestFormatProperty);
}
private String serializeRequest(OllamaGenerateRequestModel req) {
try {
return mapper.writeValueAsString(req);
} catch (JsonProcessingException e) {
fail("Could not serialize request!", e);
return null;
}
}
private OllamaGenerateRequestModel deserializeRequest(String jsonRequest) {
try {
return mapper.readValue(jsonRequest, OllamaGenerateRequestModel.class);
} catch (JsonProcessingException e) {
fail("Could not deserialize jsonRequest!", e);
return null;
}
}
private void assertEqualsAfterUnmarshalling(OllamaGenerateRequestModel unmarshalledRequest,
OllamaGenerateRequestModel req) {
assertEquals(req, unmarshalledRequest);
}
}

View File

@ -0,0 +1,42 @@
package io.github.amithkoujalgi.ollama4j.unittests.jackson;
import io.github.amithkoujalgi.ollama4j.core.models.Model;
import org.junit.jupiter.api.Test;
public class TestModelRequestSerialization extends AbstractSerializationTest<Model> {
@Test
public void testDeserializationOfModelResponseWithOffsetTime(){
String serializedTestStringWithOffsetTime = "{\n"
+ "\"name\": \"codellama:13b\",\n"
+ "\"modified_at\": \"2023-11-04T14:56:49.277302595-07:00\",\n"
+ "\"size\": 7365960935,\n"
+ "\"digest\": \"9f438cb9cd581fc025612d27f7c1a6669ff83a8bb0ed86c94fcf4c5440555697\",\n"
+ "\"details\": {\n"
+ "\"format\": \"gguf\",\n"
+ "\"family\": \"llama\",\n"
+ "\"families\": null,\n"
+ "\"parameter_size\": \"13B\",\n"
+ "\"quantization_level\": \"Q4_0\"\n"
+ "}}";
deserialize(serializedTestStringWithOffsetTime,Model.class);
}
@Test
public void testDeserializationOfModelResponseWithZuluTime(){
String serializedTestStringWithZuluTimezone = "{\n"
+ "\"name\": \"codellama:13b\",\n"
+ "\"modified_at\": \"2023-11-04T14:56:49.277302595Z\",\n"
+ "\"size\": 7365960935,\n"
+ "\"digest\": \"9f438cb9cd581fc025612d27f7c1a6669ff83a8bb0ed86c94fcf4c5440555697\",\n"
+ "\"details\": {\n"
+ "\"format\": \"gguf\",\n"
+ "\"family\": \"llama\",\n"
+ "\"families\": null,\n"
+ "\"parameter_size\": \"13B\",\n"
+ "\"quantization_level\": \"Q4_0\"\n"
+ "}}";
deserialize(serializedTestStringWithZuluTimezone,Model.class);
}
}