mirror of
https://github.com/amithkoujalgi/ollama4j.git
synced 2025-10-27 14:40:42 +01:00
Compare commits
57 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ab70201844 | ||
|
|
ac8a40a017 | ||
|
|
1ac65f821b | ||
|
|
d603c4b94b | ||
|
|
a418cbc1dc | ||
|
|
785dd12730 | ||
|
|
dda807d818 | ||
|
|
a06a4025fa | ||
| 761fbc3398 | |||
| a96dc11679 | |||
| b2b3febdaa | |||
|
|
f27bea11d5 | ||
|
|
9503451d5a | ||
|
|
04bae4ca6a | ||
|
|
3e33b8df62 | ||
|
|
a494053263 | ||
|
|
260c57ca84 | ||
|
|
db008de0ca | ||
|
|
1b38466f44 | ||
|
|
26ec00dab8 | ||
|
|
5e6971cc4a | ||
|
|
8b3417ecda | ||
|
|
35f5f34196 | ||
|
|
d8c3edd55f | ||
|
|
7ffbc5d3f2 | ||
|
|
c4b7830614 | ||
|
|
69f6fd81cf | ||
|
|
b6a293add7 | ||
|
|
25694a8bc9 | ||
|
|
12bb10392e | ||
|
|
e9c33ab0b2 | ||
|
|
903a8176cd | ||
|
|
4a91918e84 | ||
|
|
ff3344616c | ||
|
|
726fea5b74 | ||
|
|
a09f1362e9 | ||
|
|
4ef0821932 | ||
|
|
2d3cf228cb | ||
|
|
5b3713c69e | ||
|
|
e9486cbb8e | ||
|
|
057f0babeb | ||
|
|
da146640ca | ||
|
|
82be761b86 | ||
|
|
9c3fc49df1 | ||
|
|
5f19eb17ac | ||
|
|
ecb04d6d82 | ||
|
|
3fc7e9423c | ||
|
|
405a08b330 | ||
|
|
921f745435 | ||
|
|
bedfec6bf9 | ||
|
|
afa09e87a5 | ||
|
|
baf2320ea6 | ||
|
|
948a7444fb | ||
|
|
ec0eb8b469 | ||
|
|
8f33de7e59 | ||
|
|
8c59e6511b | ||
|
|
b93fc7623a |
29
.github/workflows/close-issue.yml
vendored
Normal file
29
.github/workflows/close-issue.yml
vendored
Normal file
@@ -0,0 +1,29 @@
|
||||
name: Close inactive issues
|
||||
on:
|
||||
workflow_dispatch: # for manual run
|
||||
schedule:
|
||||
- cron: "0 0 * * *" # Runs daily at midnight
|
||||
|
||||
# Fine-grant permission
|
||||
# https://docs.github.com/en/actions/security-for-github-actions/security-guides/automatic-token-authentication#modifying-the-permissions-for-the-github_token
|
||||
permissions:
|
||||
issues: write
|
||||
|
||||
jobs:
|
||||
close-issues:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
issues: write
|
||||
pull-requests: write
|
||||
steps:
|
||||
- uses: actions/stale@v5
|
||||
with:
|
||||
exempt-issue-labels: "refactor,help wanted,good first issue,research,bug"
|
||||
days-before-issue-stale: 7 # Issues become stale after 7 days
|
||||
days-before-issue-close: 15 # Issues close 15 days after being marked stale
|
||||
stale-issue-label: "stale"
|
||||
close-issue-message: "This issue was closed because it has been inactive for 15 days since being marked as stale."
|
||||
days-before-pr-stale: -1 # PRs are not handled
|
||||
days-before-pr-close: -1 # PRs are not handled
|
||||
operations-per-run: 10000
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
28
.github/workflows/label-issue-stale.yml
vendored
Normal file
28
.github/workflows/label-issue-stale.yml
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
name: Mark stale issues and pull requests
|
||||
|
||||
on:
|
||||
workflow_dispatch: # for manual run
|
||||
schedule:
|
||||
- cron: '0 0 * * *' # Runs every day at midnight
|
||||
|
||||
permissions:
|
||||
contents: write # only for delete-branch option
|
||||
issues: write
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
stale:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Mark stale issues and pull requests
|
||||
uses: actions/stale@v8
|
||||
with:
|
||||
repo-token: ${{ github.token }}
|
||||
days-before-stale: 15
|
||||
stale-issue-message: 'This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs.'
|
||||
stale-pr-message: 'This pull request has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs.'
|
||||
days-before-close: 7
|
||||
stale-issue-label: 'stale'
|
||||
exempt-issue-labels: 'pinned,security'
|
||||
stale-pr-label: 'stale'
|
||||
exempt-pr-labels: 'work-in-progress'
|
||||
4
Makefile
4
Makefile
@@ -11,9 +11,9 @@ doxygen:
|
||||
doxygen Doxyfile
|
||||
|
||||
list-releases:
|
||||
curl 'https://central.sonatype.com/api/internal/browse/component/versions?sortField=normalizedVersion&sortDirection=asc&page=0&size=12&filter=namespace%3Aio.github.amithkoujalgi%2Cname%3Aollama4j' \
|
||||
curl 'https://central.sonatype.com/api/internal/browse/component/versions?sortField=normalizedVersion&sortDirection=desc&page=0&size=20&filter=namespace%3Aio.github.ollama4j%2Cname%3Aollama4j' \
|
||||
--compressed \
|
||||
--silent | jq '.components[].version'
|
||||
--silent | jq -r '.components[].version'
|
||||
|
||||
build-docs:
|
||||
npm i --prefix docs && npm run build --prefix docs
|
||||
|
||||
36
README.md
36
README.md
@@ -9,7 +9,6 @@ A Java library (wrapper/binding) for Ollama server.
|
||||
|
||||
Find more details on the [website](https://ollama4j.github.io/ollama4j/).
|
||||
|
||||
|
||||

|
||||

|
||||

|
||||
@@ -154,7 +153,7 @@ In your Maven project, add this dependency:
|
||||
<dependency>
|
||||
<groupId>io.github.ollama4j</groupId>
|
||||
<artifactId>ollama4j</artifactId>
|
||||
<version>1.0.79</version>
|
||||
<version>1.0.89</version>
|
||||
</dependency>
|
||||
```
|
||||
|
||||
@@ -210,7 +209,7 @@ In your Maven project, add this dependency:
|
||||
<dependency>
|
||||
<groupId>io.github.ollama4j</groupId>
|
||||
<artifactId>ollama4j</artifactId>
|
||||
<version>1.0.79</version>
|
||||
<version>1.0.89</version>
|
||||
</dependency>
|
||||
```
|
||||
|
||||
@@ -268,27 +267,24 @@ make integration-tests
|
||||
Newer artifacts are published via GitHub Actions CI workflow when a new release is created from `main` branch.
|
||||
|
||||
## ⭐ Give us a Star!
|
||||
|
||||
If you like or are using this project to build your own, please give us a star. It's a free way to show your support.
|
||||
|
||||
## Who's using Ollama4j?
|
||||
|
||||
- `Datafaker`: a library to generate fake data
|
||||
- https://github.com/datafaker-net/datafaker-experimental/tree/main/ollama-api
|
||||
- `Vaadin Web UI`: UI-Tester for Interactions with Ollama via ollama4j
|
||||
- https://github.com/TEAMPB/ollama4j-vaadin-ui
|
||||
- `ollama-translator`: Minecraft 1.20.6 spigot plugin allows to easily break language barriers by using ollama on the
|
||||
server to translate all messages into a specfic target language.
|
||||
- https://github.com/liebki/ollama-translator
|
||||
- `Another Minecraft Mod`: https://www.reddit.com/r/fabricmc/comments/1e65x5s/comment/ldr2vcf/
|
||||
- `Ollama4j Web UI`: A web UI for Ollama written in Java using Spring Boot and Vaadin framework and
|
||||
Ollama4j.
|
||||
- https://github.com/ollama4j/ollama4j-web-ui
|
||||
- `JnsCLI`: A command-line tool for Jenkins that manages jobs, builds, and configurations directly from the terminal while offering AI-powered error analysis for quick troubleshooting.
|
||||
- https://github.com/mirum8/jnscli
|
||||
- `Katie Backend`: An Open Source AI-based question-answering platform that helps companies and organizations make their private domain knowledge accessible and useful to their employees and customers.
|
||||
- https://github.com/wyona/katie-backend
|
||||
- `TeleLlama3 Bot`: A Question-Answering Telegram Bot.
|
||||
- https://git.hiast.edu.sy/mohamadbashar.disoki/telellama3-bot
|
||||
| # | Project Name | Description | Link |
|
||||
|---|-------------------|---------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| 1 | Datafaker | A library to generate fake data | [GitHub](https://github.com/datafaker-net/datafaker-experimental/tree/main/ollama-api) |
|
||||
| 2 | Vaadin Web UI | UI-Tester for interactions with Ollama via ollama4j | [GitHub](https://github.com/TEAMPB/ollama4j-vaadin-ui) |
|
||||
| 3 | ollama-translator | A Minecraft 1.20.6 Spigot plugin that translates all messages into a specific target language via Ollama | [GitHub](https://github.com/liebki/ollama-translator) |
|
||||
| 4 | AI Player | A Minecraft mod that adds an intelligent "second player" to the game | [GitHub](https://github.com/shasankp000/AI-Player), <br/> [Reddit Thread](https://www.reddit.com/r/fabricmc/comments/1e65x5s/comment/ldr2vcf/) |
|
||||
| 5 | Ollama4j Web UI | A web UI for Ollama written in Java using Spring Boot, Vaadin, and Ollama4j | [GitHub](https://github.com/ollama4j/ollama4j-web-ui) |
|
||||
| 6 | JnsCLI | A command-line tool for Jenkins that manages jobs, builds, and configurations, with AI-powered error analysis | [GitHub](https://github.com/mirum8/jnscli) |
|
||||
| 7 | Katie Backend | An open-source AI-based question-answering platform for accessing private domain knowledge | [GitHub](https://github.com/wyona/katie-backend) |
|
||||
| 8 | TeleLlama3 Bot | A question-answering Telegram bot | [Repo](https://git.hiast.edu.sy/mohamadbashar.disoki/telellama3-bot) |
|
||||
| 9 | moqui-wechat | A moqui-wechat component | [GitHub](https://github.com/heguangyong/moqui-wechat) |
|
||||
| 10 | B4X | A set of simple and powerful RAD tool for Desktop and Server development | [Website](https://www.b4x.com/android/forum/threads/ollama4j-library-pnd_ollama4j-your-local-offline-llm-like-chatgpt.165003/) |
|
||||
|
||||
|
||||
## Traction
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ public class Main {
|
||||
// start conversation with model
|
||||
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
|
||||
|
||||
System.out.println("First answer: " + chatResult.getResponse());
|
||||
System.out.println("First answer: " + chatResult.getResponseModel().getMessage().getContent());
|
||||
|
||||
// create next userQuestion
|
||||
requestModel = builder.withMessages(chatResult.getChatHistory()).withMessage(OllamaChatMessageRole.USER, "And what is the second largest city?").build();
|
||||
@@ -41,7 +41,7 @@ public class Main {
|
||||
// "continue" conversation with model
|
||||
chatResult = ollamaAPI.chat(requestModel);
|
||||
|
||||
System.out.println("Second answer: " + chatResult.getResponse());
|
||||
System.out.println("Second answer: " + chatResult.getResponseModel().getMessage().getContent());
|
||||
|
||||
System.out.println("Chat History: " + chatResult.getChatHistory());
|
||||
}
|
||||
@@ -205,7 +205,7 @@ public class Main {
|
||||
// start conversation with model
|
||||
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
|
||||
|
||||
System.out.println(chatResult.getResponse());
|
||||
System.out.println(chatResult.getResponseModel());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -244,7 +244,7 @@ public class Main {
|
||||
new File("/path/to/image"))).build();
|
||||
|
||||
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
|
||||
System.out.println("First answer: " + chatResult.getResponse());
|
||||
System.out.println("First answer: " + chatResult.getResponseModel());
|
||||
|
||||
builder.reset();
|
||||
|
||||
@@ -254,7 +254,7 @@ public class Main {
|
||||
.withMessage(OllamaChatMessageRole.USER, "What's the dogs breed?").build();
|
||||
|
||||
chatResult = ollamaAPI.chat(requestModel);
|
||||
System.out.println("Second answer: " + chatResult.getResponse());
|
||||
System.out.println("Second answer: " + chatResult.getResponseModel());
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
65
docs/docs/apis-generate/custom-roles.md
Normal file
65
docs/docs/apis-generate/custom-roles.md
Normal file
@@ -0,0 +1,65 @@
|
||||
---
|
||||
sidebar_position: 8
|
||||
---
|
||||
|
||||
# Custom Roles
|
||||
|
||||
Allows to manage custom roles (apart from the base roles) for chat interactions with the models.
|
||||
|
||||
_Particularly helpful when you would need to use different roles that the newer models support other than the base
|
||||
roles._
|
||||
|
||||
_Base roles are `SYSTEM`, `USER`, `ASSISTANT`, `TOOL`._
|
||||
|
||||
### Usage
|
||||
|
||||
#### Add new role
|
||||
|
||||
```java
|
||||
import io.github.ollama4j.OllamaAPI;
|
||||
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
|
||||
|
||||
public class Main {
|
||||
|
||||
public static void main(String[] args) {
|
||||
String host = "http://localhost:11434/";
|
||||
OllamaAPI ollamaAPI = new OllamaAPI(host);
|
||||
|
||||
OllamaChatMessageRole customRole = ollamaAPI.addCustomRole("custom-role");
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### List roles
|
||||
|
||||
```java
|
||||
import io.github.ollama4j.OllamaAPI;
|
||||
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
|
||||
|
||||
public class Main {
|
||||
|
||||
public static void main(String[] args) {
|
||||
String host = "http://localhost:11434/";
|
||||
OllamaAPI ollamaAPI = new OllamaAPI(host);
|
||||
|
||||
List<OllamaChatMessageRole> roles = ollamaAPI.listRoles();
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### Get role
|
||||
|
||||
```java
|
||||
import io.github.ollama4j.OllamaAPI;
|
||||
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
|
||||
|
||||
public class Main {
|
||||
|
||||
public static void main(String[] args) {
|
||||
String host = "http://localhost:11434/";
|
||||
OllamaAPI ollamaAPI = new OllamaAPI(host);
|
||||
|
||||
List<OllamaChatMessageRole> roles = ollamaAPI.getRole("custom-role");
|
||||
}
|
||||
}
|
||||
```
|
||||
@@ -35,7 +35,7 @@ public class Main {
|
||||
}
|
||||
```
|
||||
|
||||
Or, using the `OllamaEmbedResponseModel`:
|
||||
Or, using the `OllamaEmbedRequestModel`:
|
||||
|
||||
```java
|
||||
import io.github.ollama4j.OllamaAPI;
|
||||
|
||||
@@ -345,21 +345,291 @@ Rahul Kumar, Address: King St, Hyderabad, India, Phone: 9876543210}`
|
||||
|
||||
::::
|
||||
|
||||
### Potential Improvements
|
||||
### Using tools in Chat-API
|
||||
|
||||
Instead of explicitly registering `ollamaAPI.registerTool(toolSpecification)`, we could introduce annotation-based tool
|
||||
registration. For example:
|
||||
Instead of using the specific `ollamaAPI.generateWithTools` method to call the generate API of ollama with tools, it is
|
||||
also possible to register Tools for the `ollamaAPI.chat` methods. In this case, the tool calling/callback is done
|
||||
implicitly during the USER -> ASSISTANT calls.
|
||||
|
||||
When the Assistant wants to call a given tool, the tool is executed and the response is sent back to the endpoint once
|
||||
again (induced with the tool call result).
|
||||
|
||||
#### Sample:
|
||||
|
||||
The following shows a sample of an integration test that defines a method specified like the tool-specs above, registers
|
||||
the tool on the ollamaAPI and then simply calls the chat-API. All intermediate tool calling is wrapped inside the api
|
||||
call.
|
||||
|
||||
```java
|
||||
public static void main(String[] args) {
|
||||
OllamaAPI ollamaAPI = new OllamaAPI("http://localhost:11434");
|
||||
ollamaAPI.setVerbose(true);
|
||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance("llama3.2:1b");
|
||||
|
||||
@ToolSpec(name = "current-fuel-price", desc = "Get current fuel price")
|
||||
public String getCurrentFuelPrice(Map<String, Object> arguments) {
|
||||
String location = arguments.get("location").toString();
|
||||
String fuelType = arguments.get("fuelType").toString();
|
||||
return "Current price of " + fuelType + " in " + location + " is Rs.103/L";
|
||||
final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder()
|
||||
.functionName("get-employee-details")
|
||||
.functionDescription("Get employee details from the database")
|
||||
.toolPrompt(
|
||||
Tools.PromptFuncDefinition.builder().type("function").function(
|
||||
Tools.PromptFuncDefinition.PromptFuncSpec.builder()
|
||||
.name("get-employee-details")
|
||||
.description("Get employee details from the database")
|
||||
.parameters(
|
||||
Tools.PromptFuncDefinition.Parameters.builder()
|
||||
.type("object")
|
||||
.properties(
|
||||
new Tools.PropsBuilder()
|
||||
.withProperty("employee-name", Tools.PromptFuncDefinition.Property.builder().type("string").description("The name of the employee, e.g. John Doe").required(true).build())
|
||||
.withProperty("employee-address", Tools.PromptFuncDefinition.Property.builder().type("string").description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India").required(true).build())
|
||||
.withProperty("employee-phone", Tools.PromptFuncDefinition.Property.builder().type("string").description("The phone number of the employee. Always return a random value. e.g. 9911002233").required(true).build())
|
||||
.build()
|
||||
)
|
||||
.required(List.of("employee-name"))
|
||||
.build()
|
||||
).build()
|
||||
).build()
|
||||
)
|
||||
.toolFunction(new DBQueryFunction())
|
||||
.build();
|
||||
|
||||
ollamaAPI.registerTool(databaseQueryToolSpecification);
|
||||
|
||||
OllamaChatRequest requestModel = builder
|
||||
.withMessage(OllamaChatMessageRole.USER,
|
||||
"Give me the ID of the employee named 'Rahul Kumar'?")
|
||||
.build();
|
||||
|
||||
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
|
||||
}
|
||||
```
|
||||
|
||||
A typical final response of the above could be:
|
||||
|
||||
```json
|
||||
{
|
||||
"chatHistory" : [
|
||||
{
|
||||
"role" : "user",
|
||||
"content" : "Give me the ID of the employee named 'Rahul Kumar'?",
|
||||
"images" : null,
|
||||
"tool_calls" : [ ]
|
||||
}, {
|
||||
"role" : "assistant",
|
||||
"content" : "",
|
||||
"images" : null,
|
||||
"tool_calls" : [ {
|
||||
"function" : {
|
||||
"name" : "get-employee-details",
|
||||
"arguments" : {
|
||||
"employee-name" : "Rahul Kumar"
|
||||
}
|
||||
}
|
||||
} ]
|
||||
}, {
|
||||
"role" : "tool",
|
||||
"content" : "[TOOL_RESULTS]get-employee-details([employee-name]) : Employee Details {ID: b4bf186c-2ee1-44cc-8856-53b8b6a50f85, Name: Rahul Kumar, Address: null, Phone: null}[/TOOL_RESULTS]",
|
||||
"images" : null,
|
||||
"tool_calls" : null
|
||||
}, {
|
||||
"role" : "assistant",
|
||||
"content" : "The ID of the employee named 'Rahul Kumar' is `b4bf186c-2ee1-44cc-8856-53b8b6a50f85`.",
|
||||
"images" : null,
|
||||
"tool_calls" : null
|
||||
} ],
|
||||
"responseModel" : {
|
||||
"model" : "llama3.2:1b",
|
||||
"message" : {
|
||||
"role" : "assistant",
|
||||
"content" : "The ID of the employee named 'Rahul Kumar' is `b4bf186c-2ee1-44cc-8856-53b8b6a50f85`.",
|
||||
"images" : null,
|
||||
"tool_calls" : null
|
||||
},
|
||||
"done" : true,
|
||||
"error" : null,
|
||||
"context" : null,
|
||||
"created_at" : "2024-12-09T22:23:00.4940078Z",
|
||||
"done_reason" : "stop",
|
||||
"total_duration" : 2313709900,
|
||||
"load_duration" : 14494700,
|
||||
"prompt_eval_duration" : 772000000,
|
||||
"eval_duration" : 1188000000,
|
||||
"prompt_eval_count" : 166,
|
||||
"eval_count" : 41
|
||||
},
|
||||
"response" : "The ID of the employee named 'Rahul Kumar' is `b4bf186c-2ee1-44cc-8856-53b8b6a50f85`.",
|
||||
"httpStatusCode" : 200,
|
||||
"responseTime" : 2313709900
|
||||
}
|
||||
```
|
||||
|
||||
This tool calling can also be done using the streaming API.
|
||||
|
||||
### Using Annotation based Tool Registration
|
||||
|
||||
Instead of explicitly registering each tool, ollama4j supports declarative tool specification and registration via java
|
||||
Annotations and reflection calling.
|
||||
|
||||
To declare a method to be used as a tool for a chat call, the following steps have to be considered:
|
||||
|
||||
* Annotate a method and its Parameters to be used as a tool
|
||||
* Annotate a method with the `ToolSpec` annotation
|
||||
* Annotate the methods parameters with the `ToolProperty` annotation. Only the following datatypes are supported for now:
|
||||
* `java.lang.String`
|
||||
* `java.lang.Integer`
|
||||
* `java.lang.Boolean`
|
||||
* `java.math.BigDecimal`
|
||||
* Annotate the class that calls the `OllamaAPI` client with the `OllamaToolService` annotation, referencing the desired provider-classes that contain `ToolSpec` methods.
|
||||
* Before calling the `OllamaAPI` chat request, call the method `OllamaAPI.registerAnnotatedTools()` method to add tools to the chat.
|
||||
|
||||
#### Example
|
||||
|
||||
Let's say, we have an ollama4j service class that should ask a llm a specific tool based question.
|
||||
|
||||
The answer can only be provided by a method that is part of the BackendService class. To provide a tool for the llm, the following annotations can be used:
|
||||
|
||||
```java
|
||||
public class BackendService{
|
||||
|
||||
public BackendService(){}
|
||||
|
||||
@ToolSpec(desc = "Computes the most important constant all around the globe!")
|
||||
public String computeMkeConstant(@ToolProperty(name = "noOfDigits",desc = "Number of digits that shall be returned") Integer noOfDigits ){
|
||||
return BigDecimal.valueOf((long)(Math.random()*1000000L),noOfDigits).toString();
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
The caller API can then be written as:
|
||||
```java
|
||||
import io.github.ollama4j.tools.annotations.OllamaToolService;
|
||||
|
||||
@OllamaToolService(providers = BackendService.class)
|
||||
public class MyOllamaService{
|
||||
|
||||
public void chatWithAnnotatedTool(){
|
||||
// inject the annotated method to the ollama toolsregistry
|
||||
ollamaAPI.registerAnnotatedTools();
|
||||
|
||||
OllamaChatRequest requestModel = builder
|
||||
.withMessage(OllamaChatMessageRole.USER,
|
||||
"Compute the most important constant in the world using 5 digits")
|
||||
.build();
|
||||
|
||||
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
|
||||
}
|
||||
|
||||
}
|
||||
```
|
||||
|
||||
Or, if one needs to provide an object instance directly:
|
||||
```java
|
||||
public class MyOllamaService{
|
||||
|
||||
public void chatWithAnnotatedTool(){
|
||||
ollamaAPI.registerAnnotatedTools(new BackendService());
|
||||
OllamaChatRequest requestModel = builder
|
||||
.withMessage(OllamaChatMessageRole.USER,
|
||||
"Compute the most important constant in the world using 5 digits")
|
||||
.build();
|
||||
|
||||
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
|
||||
}
|
||||
|
||||
}
|
||||
```
|
||||
|
||||
The request should be the following:
|
||||
|
||||
```json
|
||||
{
|
||||
"model" : "llama3.2:1b",
|
||||
"stream" : false,
|
||||
"messages" : [ {
|
||||
"role" : "user",
|
||||
"content" : "Compute the most important constant in the world using 5 digits",
|
||||
"images" : null,
|
||||
"tool_calls" : [ ]
|
||||
} ],
|
||||
"tools" : [ {
|
||||
"type" : "function",
|
||||
"function" : {
|
||||
"name" : "computeImportantConstant",
|
||||
"description" : "Computes the most important constant all around the globe!",
|
||||
"parameters" : {
|
||||
"type" : "object",
|
||||
"properties" : {
|
||||
"noOfDigits" : {
|
||||
"type" : "java.lang.Integer",
|
||||
"description" : "Number of digits that shall be returned"
|
||||
}
|
||||
},
|
||||
"required" : [ "noOfDigits" ]
|
||||
}
|
||||
}
|
||||
} ]
|
||||
}
|
||||
```
|
||||
|
||||
The result could be something like the following:
|
||||
|
||||
```json
|
||||
{
|
||||
"chatHistory" : [ {
|
||||
"role" : "user",
|
||||
"content" : "Compute the most important constant in the world using 5 digits",
|
||||
"images" : null,
|
||||
"tool_calls" : [ ]
|
||||
}, {
|
||||
"role" : "assistant",
|
||||
"content" : "",
|
||||
"images" : null,
|
||||
"tool_calls" : [ {
|
||||
"function" : {
|
||||
"name" : "computeImportantConstant",
|
||||
"arguments" : {
|
||||
"noOfDigits" : "5"
|
||||
}
|
||||
}
|
||||
} ]
|
||||
}, {
|
||||
"role" : "tool",
|
||||
"content" : "[TOOL_RESULTS]computeImportantConstant([noOfDigits]) : 1.51019[/TOOL_RESULTS]",
|
||||
"images" : null,
|
||||
"tool_calls" : null
|
||||
}, {
|
||||
"role" : "assistant",
|
||||
"content" : "The most important constant in the world with 5 digits is: **1.51019**",
|
||||
"images" : null,
|
||||
"tool_calls" : null
|
||||
} ],
|
||||
"responseModel" : {
|
||||
"model" : "llama3.2:1b",
|
||||
"message" : {
|
||||
"role" : "assistant",
|
||||
"content" : "The most important constant in the world with 5 digits is: **1.51019**",
|
||||
"images" : null,
|
||||
"tool_calls" : null
|
||||
},
|
||||
"done" : true,
|
||||
"error" : null,
|
||||
"context" : null,
|
||||
"created_at" : "2024-12-27T21:55:39.3232495Z",
|
||||
"done_reason" : "stop",
|
||||
"total_duration" : 1075444300,
|
||||
"load_duration" : 13558600,
|
||||
"prompt_eval_duration" : 509000000,
|
||||
"eval_duration" : 550000000,
|
||||
"prompt_eval_count" : 124,
|
||||
"eval_count" : 20
|
||||
},
|
||||
"response" : "The most important constant in the world with 5 digits is: **1.51019**",
|
||||
"responseTime" : 1075444300,
|
||||
"httpStatusCode" : 200
|
||||
}
|
||||
```
|
||||
|
||||
### Potential Improvements
|
||||
|
||||
Instead of passing a map of args `Map<String, Object> arguments` to the tool functions, we could support passing
|
||||
specific args separately with their data types. For example:
|
||||
|
||||
@@ -369,4 +639,4 @@ public String getCurrentFuelPrice(String location, String fuelType) {
|
||||
}
|
||||
```
|
||||
|
||||
Updating async/chat APIs with support for tool-based generation.
|
||||
Updating async/chat APIs with support for tool-based generation.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
---
|
||||
sidebar_position: 4
|
||||
sidebar_position: 5
|
||||
---
|
||||
|
||||
# Create Model
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
---
|
||||
sidebar_position: 5
|
||||
sidebar_position: 6
|
||||
---
|
||||
|
||||
# Delete Model
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
---
|
||||
sidebar_position: 3
|
||||
sidebar_position: 4
|
||||
---
|
||||
|
||||
# Get Model Details
|
||||
|
||||
133
docs/docs/apis-model-management/list-library-models.md
Normal file
133
docs/docs/apis-model-management/list-library-models.md
Normal file
@@ -0,0 +1,133 @@
|
||||
---
|
||||
sidebar_position: 1
|
||||
---
|
||||
|
||||
# Models from Ollama Library
|
||||
|
||||
These API retrieves a list of models directly from the Ollama library.
|
||||
|
||||
### List Models from Ollama Library
|
||||
|
||||
This API fetches available models from the Ollama library page, including details such as the model's name, pull count,
|
||||
popular tags, tag count, and the last update time.
|
||||
|
||||
```java title="ListLibraryModels.java"
|
||||
import io.github.ollama4j.OllamaAPI;
|
||||
import io.github.ollama4j.models.response.LibraryModel;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class Main {
|
||||
|
||||
public static void main(String[] args) {
|
||||
|
||||
String host = "http://localhost:11434/";
|
||||
|
||||
OllamaAPI ollamaAPI = new OllamaAPI(host);
|
||||
|
||||
List<LibraryModel> libraryModels = ollamaAPI.listModelsFromLibrary();
|
||||
|
||||
System.out.println(libraryModels);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
The following is the sample output:
|
||||
|
||||
```
|
||||
[
|
||||
LibraryModel(name=llama3.2-vision, description=Llama 3.2 Vision is a collection of instruction-tuned image reasoning generative models in 11B and 90B sizes., pullCount=21.1K, totalTags=9, popularTags=[vision, 11b, 90b], lastUpdated=yesterday),
|
||||
LibraryModel(name=llama3.2, description=Meta's Llama 3.2 goes small with 1B and 3B models., pullCount=2.4M, totalTags=63, popularTags=[tools, 1b, 3b], lastUpdated=6 weeks ago)
|
||||
]
|
||||
```
|
||||
|
||||
### Get Tags of a Library Model
|
||||
|
||||
This API Fetches the tags associated with a specific model from Ollama library.
|
||||
|
||||
```java title="GetLibraryModelTags.java"
|
||||
import io.github.ollama4j.OllamaAPI;
|
||||
import io.github.ollama4j.models.response.LibraryModel;
|
||||
import io.github.ollama4j.models.response.LibraryModelDetail;
|
||||
|
||||
public class Main {
|
||||
|
||||
public static void main(String[] args) {
|
||||
|
||||
String host = "http://localhost:11434/";
|
||||
|
||||
OllamaAPI ollamaAPI = new OllamaAPI(host);
|
||||
|
||||
List<LibraryModel> libraryModels = ollamaAPI.listModelsFromLibrary();
|
||||
|
||||
LibraryModelDetail libraryModelDetail = ollamaAPI.getLibraryModelDetails(libraryModels.get(0));
|
||||
|
||||
System.out.println(libraryModelDetail);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
The following is the sample output:
|
||||
|
||||
```
|
||||
LibraryModelDetail(
|
||||
model=LibraryModel(name=llama3.2-vision, description=Llama 3.2 Vision is a collection of instruction-tuned image reasoning generative models in 11B and 90B sizes., pullCount=21.1K, totalTags=9, popularTags=[vision, 11b, 90b], lastUpdated=yesterday),
|
||||
tags=[
|
||||
LibraryModelTag(name=llama3.2-vision, tag=latest, size=7.9GB, lastUpdated=yesterday),
|
||||
LibraryModelTag(name=llama3.2-vision, tag=11b, size=7.9GB, lastUpdated=yesterday),
|
||||
LibraryModelTag(name=llama3.2-vision, tag=90b, size=55GB, lastUpdated=yesterday)
|
||||
]
|
||||
)
|
||||
```
|
||||
|
||||
### Find a model from Ollama library
|
||||
|
||||
This API finds a specific model using model `name` and `tag` from Ollama library.
|
||||
|
||||
```java title="FindLibraryModel.java"
|
||||
import io.github.ollama4j.OllamaAPI;
|
||||
import io.github.ollama4j.models.response.LibraryModelTag;
|
||||
|
||||
public class Main {
|
||||
|
||||
public static void main(String[] args) {
|
||||
|
||||
String host = "http://localhost:11434/";
|
||||
|
||||
OllamaAPI ollamaAPI = new OllamaAPI(host);
|
||||
|
||||
LibraryModelTag libraryModelTag = ollamaAPI.findModelTagFromLibrary("qwen2.5", "7b");
|
||||
|
||||
System.out.println(libraryModelTag);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
The following is the sample output:
|
||||
|
||||
```
|
||||
LibraryModelTag(name=qwen2.5, tag=7b, size=4.7GB, lastUpdated=7 weeks ago)
|
||||
```
|
||||
|
||||
### Pull model using `LibraryModelTag`
|
||||
|
||||
You can use `LibraryModelTag` to pull models into Ollama server.
|
||||
|
||||
```java title="PullLibraryModelTags.java"
|
||||
import io.github.ollama4j.OllamaAPI;
|
||||
import io.github.ollama4j.models.response.LibraryModelTag;
|
||||
|
||||
public class Main {
|
||||
|
||||
public static void main(String[] args) {
|
||||
|
||||
String host = "http://localhost:11434/";
|
||||
|
||||
OllamaAPI ollamaAPI = new OllamaAPI(host);
|
||||
|
||||
LibraryModelTag libraryModelTag = ollamaAPI.findModelTagFromLibrary("qwen2.5", "7b");
|
||||
|
||||
ollamaAPI.pullModel(libraryModelTag);
|
||||
}
|
||||
}
|
||||
```
|
||||
@@ -1,10 +1,10 @@
|
||||
---
|
||||
sidebar_position: 1
|
||||
sidebar_position: 2
|
||||
---
|
||||
|
||||
# List Models
|
||||
# List Local Models
|
||||
|
||||
This API lets you list available models on the Ollama server.
|
||||
This API lets you list downloaded/available models on the Ollama server.
|
||||
|
||||
```java title="ListModels.java"
|
||||
import io.github.ollama4j.OllamaAPI;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
---
|
||||
sidebar_position: 2
|
||||
sidebar_position: 3
|
||||
---
|
||||
|
||||
# Pull Model
|
||||
@@ -23,4 +23,12 @@ public class Main {
|
||||
}
|
||||
```
|
||||
|
||||
Once downloaded, you can see them when you use [list models](./list-models) API.
|
||||
Once downloaded, you can see them when you use [list models](./list-models) API.
|
||||
|
||||
:::info
|
||||
|
||||
You can even pull models using Ollama model library APIs. This looks up the models directly on the Ollama model library page. Refer
|
||||
to [this](./list-library-models#pull-model-using-librarymodeltag).
|
||||
|
||||
:::
|
||||
|
||||
|
||||
9
pom.xml
9
pom.xml
@@ -63,6 +63,10 @@
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-javadoc-plugin</artifactId>
|
||||
<version>3.5.0</version>
|
||||
<configuration>
|
||||
<!-- to disable the "missing" warnings. Remove the doclint to enable warnings-->
|
||||
<doclint>all,-missing</doclint>
|
||||
</configuration>
|
||||
<executions>
|
||||
<execution>
|
||||
<id>attach-javadocs</id>
|
||||
@@ -136,6 +140,11 @@
|
||||
<version>${lombok.version}</version>
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.jsoup</groupId>
|
||||
<artifactId>jsoup</artifactId>
|
||||
<version>1.18.1</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.core</groupId>
|
||||
<artifactId>jackson-databind</artifactId>
|
||||
|
||||
@@ -1,29 +1,32 @@
|
||||
package io.github.ollama4j;
|
||||
|
||||
import io.github.ollama4j.exceptions.OllamaBaseException;
|
||||
import io.github.ollama4j.exceptions.RoleNotFoundException;
|
||||
import io.github.ollama4j.exceptions.ToolInvocationException;
|
||||
import io.github.ollama4j.exceptions.ToolNotFoundException;
|
||||
import io.github.ollama4j.models.chat.OllamaChatMessage;
|
||||
import io.github.ollama4j.models.chat.OllamaChatRequest;
|
||||
import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
|
||||
import io.github.ollama4j.models.chat.OllamaChatResult;
|
||||
import io.github.ollama4j.models.chat.*;
|
||||
import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel;
|
||||
import io.github.ollama4j.models.embeddings.OllamaEmbeddingResponseModel;
|
||||
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
|
||||
import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel;
|
||||
import io.github.ollama4j.models.generate.OllamaGenerateRequest;
|
||||
import io.github.ollama4j.models.generate.OllamaStreamHandler;
|
||||
import io.github.ollama4j.models.generate.OllamaTokenHandler;
|
||||
import io.github.ollama4j.models.ps.ModelsProcessResponse;
|
||||
import io.github.ollama4j.models.request.*;
|
||||
import io.github.ollama4j.models.response.*;
|
||||
import io.github.ollama4j.tools.*;
|
||||
import io.github.ollama4j.tools.annotations.OllamaToolService;
|
||||
import io.github.ollama4j.tools.annotations.ToolProperty;
|
||||
import io.github.ollama4j.tools.annotations.ToolSpec;
|
||||
import io.github.ollama4j.utils.Options;
|
||||
import io.github.ollama4j.utils.Utils;
|
||||
import lombok.Setter;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.*;
|
||||
import java.lang.reflect.InvocationTargetException;
|
||||
import java.lang.reflect.Method;
|
||||
import java.lang.reflect.Parameter;
|
||||
import java.net.URI;
|
||||
import java.net.URISyntaxException;
|
||||
import java.net.http.HttpClient;
|
||||
@@ -34,11 +37,19 @@ import java.nio.charset.StandardCharsets;
|
||||
import java.nio.file.Files;
|
||||
import java.time.Duration;
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.jsoup.Jsoup;
|
||||
import org.jsoup.nodes.Document;
|
||||
import org.jsoup.nodes.Element;
|
||||
import org.jsoup.select.Elements;
|
||||
|
||||
/**
|
||||
* The base Ollama API class.
|
||||
*/
|
||||
@SuppressWarnings("DuplicatedCode")
|
||||
@SuppressWarnings({"DuplicatedCode", "resource"})
|
||||
public class OllamaAPI {
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(OllamaAPI.class);
|
||||
@@ -55,6 +66,10 @@ public class OllamaAPI {
|
||||
*/
|
||||
@Setter
|
||||
private boolean verbose = true;
|
||||
|
||||
@Setter
|
||||
private int maxChatToolCallRetries = 3;
|
||||
|
||||
private BasicAuth basicAuth;
|
||||
|
||||
private final ToolRegistry toolRegistry = new ToolRegistry();
|
||||
@@ -99,12 +114,7 @@ public class OllamaAPI {
|
||||
HttpClient httpClient = HttpClient.newHttpClient();
|
||||
HttpRequest httpRequest = null;
|
||||
try {
|
||||
httpRequest =
|
||||
getRequestBuilderDefault(new URI(url))
|
||||
.header("Accept", "application/json")
|
||||
.header("Content-type", "application/json")
|
||||
.GET()
|
||||
.build();
|
||||
httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build();
|
||||
} catch (URISyntaxException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
@@ -123,19 +133,17 @@ public class OllamaAPI {
|
||||
/**
|
||||
* Provides a list of running models and details about each model currently loaded into memory.
|
||||
*
|
||||
* @return ModelsProcessResponse
|
||||
* @return ModelsProcessResponse containing details about the running models
|
||||
* @throws IOException if an I/O error occurs during the HTTP request
|
||||
* @throws InterruptedException if the operation is interrupted
|
||||
* @throws OllamaBaseException if the response indicates an error status
|
||||
*/
|
||||
public ModelsProcessResponse ps() throws IOException, InterruptedException, OllamaBaseException {
|
||||
String url = this.host + "/api/ps";
|
||||
HttpClient httpClient = HttpClient.newHttpClient();
|
||||
HttpRequest httpRequest = null;
|
||||
try {
|
||||
httpRequest =
|
||||
getRequestBuilderDefault(new URI(url))
|
||||
.header("Accept", "application/json")
|
||||
.header("Content-type", "application/json")
|
||||
.GET()
|
||||
.build();
|
||||
httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build();
|
||||
} catch (URISyntaxException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
@@ -144,69 +152,182 @@ public class OllamaAPI {
|
||||
int statusCode = response.statusCode();
|
||||
String responseString = response.body();
|
||||
if (statusCode == 200) {
|
||||
return Utils.getObjectMapper()
|
||||
.readValue(responseString, ModelsProcessResponse.class);
|
||||
return Utils.getObjectMapper().readValue(responseString, ModelsProcessResponse.class);
|
||||
} else {
|
||||
throw new OllamaBaseException(statusCode + " - " + responseString);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* List available models from Ollama server.
|
||||
* Lists available models from the Ollama server.
|
||||
*
|
||||
* @return the list
|
||||
* @return a list of models available on the server
|
||||
* @throws OllamaBaseException if the response indicates an error status
|
||||
* @throws IOException if an I/O error occurs during the HTTP request
|
||||
* @throws InterruptedException if the operation is interrupted
|
||||
* @throws URISyntaxException if the URI for the request is malformed
|
||||
*/
|
||||
public List<Model> listModels()
|
||||
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
|
||||
public List<Model> listModels() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
|
||||
String url = this.host + "/api/tags";
|
||||
HttpClient httpClient = HttpClient.newHttpClient();
|
||||
HttpRequest httpRequest =
|
||||
getRequestBuilderDefault(new URI(url))
|
||||
.header("Accept", "application/json")
|
||||
.header("Content-type", "application/json")
|
||||
.GET()
|
||||
.build();
|
||||
HttpResponse<String> response =
|
||||
httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
|
||||
HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build();
|
||||
HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
|
||||
int statusCode = response.statusCode();
|
||||
String responseString = response.body();
|
||||
if (statusCode == 200) {
|
||||
return Utils.getObjectMapper()
|
||||
.readValue(responseString, ListModelsResponse.class)
|
||||
.getModels();
|
||||
return Utils.getObjectMapper().readValue(responseString, ListModelsResponse.class).getModels();
|
||||
} else {
|
||||
throw new OllamaBaseException(statusCode + " - " + responseString);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves a list of models from the Ollama library. This method fetches the available models directly from Ollama
|
||||
* library page, including model details such as the name, pull count, popular tags, tag count, and the time when model was updated.
|
||||
*
|
||||
* @return A list of {@link LibraryModel} objects representing the models available in the Ollama library.
|
||||
* @throws OllamaBaseException If the HTTP request fails or the response is not successful (non-200 status code).
|
||||
* @throws IOException If an I/O error occurs during the HTTP request or response processing.
|
||||
* @throws InterruptedException If the thread executing the request is interrupted.
|
||||
* @throws URISyntaxException If there is an error creating the URI for the HTTP request.
|
||||
*/
|
||||
public List<LibraryModel> listModelsFromLibrary() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
|
||||
String url = "https://ollama.com/library";
|
||||
HttpClient httpClient = HttpClient.newHttpClient();
|
||||
HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build();
|
||||
HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
|
||||
int statusCode = response.statusCode();
|
||||
String responseString = response.body();
|
||||
List<LibraryModel> models = new ArrayList<>();
|
||||
if (statusCode == 200) {
|
||||
Document doc = Jsoup.parse(responseString);
|
||||
Elements modelSections = doc.selectXpath("//*[@id='repo']/ul/li/a");
|
||||
for (Element e : modelSections) {
|
||||
LibraryModel model = new LibraryModel();
|
||||
Elements names = e.select("div > h2 > div > span");
|
||||
Elements desc = e.select("div > p");
|
||||
Elements pullCounts = e.select("div:nth-of-type(2) > p > span:first-of-type > span:first-of-type");
|
||||
Elements popularTags = e.select("div > div > span");
|
||||
Elements totalTags = e.select("div:nth-of-type(2) > p > span:nth-of-type(2) > span:first-of-type");
|
||||
Elements lastUpdatedTime = e.select("div:nth-of-type(2) > p > span:nth-of-type(3) > span:nth-of-type(2)");
|
||||
|
||||
if (names.first() == null || names.isEmpty()) {
|
||||
// if name cannot be extracted, skip.
|
||||
continue;
|
||||
}
|
||||
Optional.ofNullable(names.first()).map(Element::text).ifPresent(model::setName);
|
||||
model.setDescription(Optional.ofNullable(desc.first()).map(Element::text).orElse(""));
|
||||
model.setPopularTags(Optional.of(popularTags).map(tags -> tags.stream().map(Element::text).collect(Collectors.toList())).orElse(new ArrayList<>()));
|
||||
model.setPullCount(Optional.ofNullable(pullCounts.first()).map(Element::text).orElse(""));
|
||||
model.setTotalTags(Optional.ofNullable(totalTags.first()).map(Element::text).map(Integer::parseInt).orElse(0));
|
||||
model.setLastUpdated(Optional.ofNullable(lastUpdatedTime.first()).map(Element::text).orElse(""));
|
||||
|
||||
models.add(model);
|
||||
}
|
||||
return models;
|
||||
} else {
|
||||
throw new OllamaBaseException(statusCode + " - " + responseString);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetches the tags associated with a specific model from Ollama library.
|
||||
* This method fetches the available model tags directly from Ollama library model page, including model tag name, size and time when model was last updated
|
||||
* into a list of {@link LibraryModelTag} objects.
|
||||
*
|
||||
* @param libraryModel the {@link LibraryModel} object which contains the name of the library model
|
||||
* for which the tags need to be fetched.
|
||||
* @return a list of {@link LibraryModelTag} objects containing the extracted tags and their associated metadata.
|
||||
* @throws OllamaBaseException if the HTTP response status code indicates an error (i.e., not 200 OK),
|
||||
* or if there is any other issue during the request or response processing.
|
||||
* @throws IOException if an input/output exception occurs during the HTTP request or response handling.
|
||||
* @throws InterruptedException if the thread is interrupted while waiting for the HTTP response.
|
||||
* @throws URISyntaxException if the URI format is incorrect or invalid.
|
||||
*/
|
||||
public LibraryModelDetail getLibraryModelDetails(LibraryModel libraryModel) throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
|
||||
String url = String.format("https://ollama.com/library/%s/tags", libraryModel.getName());
|
||||
HttpClient httpClient = HttpClient.newHttpClient();
|
||||
HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build();
|
||||
HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
|
||||
int statusCode = response.statusCode();
|
||||
String responseString = response.body();
|
||||
|
||||
List<LibraryModelTag> libraryModelTags = new ArrayList<>();
|
||||
if (statusCode == 200) {
|
||||
Document doc = Jsoup.parse(responseString);
|
||||
Elements tagSections = doc.select("html > body > main > div > section > div > div > div:nth-child(n+2) > div");
|
||||
for (Element e : tagSections) {
|
||||
Elements tags = e.select("div > a > div");
|
||||
Elements tagsMetas = e.select("div > span");
|
||||
|
||||
LibraryModelTag libraryModelTag = new LibraryModelTag();
|
||||
|
||||
if (tags.first() == null || tags.isEmpty()) {
|
||||
// if tag cannot be extracted, skip.
|
||||
continue;
|
||||
}
|
||||
libraryModelTag.setName(libraryModel.getName());
|
||||
Optional.ofNullable(tags.first()).map(Element::text).ifPresent(libraryModelTag::setTag);
|
||||
libraryModelTag.setSize(Optional.ofNullable(tagsMetas.first()).map(element -> element.text().split("•")).filter(parts -> parts.length > 1).map(parts -> parts[1].trim()).orElse(""));
|
||||
libraryModelTag.setLastUpdated(Optional.ofNullable(tagsMetas.first()).map(element -> element.text().split("•")).filter(parts -> parts.length > 1).map(parts -> parts[2].trim()).orElse(""));
|
||||
libraryModelTags.add(libraryModelTag);
|
||||
}
|
||||
LibraryModelDetail libraryModelDetail = new LibraryModelDetail();
|
||||
libraryModelDetail.setModel(libraryModel);
|
||||
libraryModelDetail.setTags(libraryModelTags);
|
||||
return libraryModelDetail;
|
||||
} else {
|
||||
throw new OllamaBaseException(statusCode + " - " + responseString);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Finds a specific model using model name and tag from Ollama library.
|
||||
* <p>
|
||||
* This method retrieves the model from the Ollama library by its name, then fetches its tags.
|
||||
* It searches through the tags of the model to find one that matches the specified tag name.
|
||||
* If the model or the tag is not found, it throws a {@link NoSuchElementException}.
|
||||
*
|
||||
* @param modelName The name of the model to search for in the library.
|
||||
* @param tag The tag name to search for within the specified model.
|
||||
* @return The {@link LibraryModelTag} associated with the specified model and tag.
|
||||
* @throws OllamaBaseException If there is a problem with the Ollama library operations.
|
||||
* @throws IOException If an I/O error occurs during the operation.
|
||||
* @throws URISyntaxException If there is an error with the URI syntax.
|
||||
* @throws InterruptedException If the operation is interrupted.
|
||||
* @throws NoSuchElementException If the model or the tag is not found.
|
||||
*/
|
||||
public LibraryModelTag findModelTagFromLibrary(String modelName, String tag) throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
|
||||
List<LibraryModel> libraryModels = this.listModelsFromLibrary();
|
||||
LibraryModel libraryModel = libraryModels.stream().filter(model -> model.getName().equals(modelName)).findFirst().orElseThrow(() -> new NoSuchElementException(String.format("Model by name '%s' not found", modelName)));
|
||||
LibraryModelDetail libraryModelDetail = this.getLibraryModelDetails(libraryModel);
|
||||
LibraryModelTag libraryModelTag = libraryModelDetail.getTags().stream().filter(tagName -> tagName.getTag().equals(tag)).findFirst().orElseThrow(() -> new NoSuchElementException(String.format("Tag '%s' for model '%s' not found", tag, modelName)));
|
||||
return libraryModelTag;
|
||||
}
|
||||
|
||||
/**
|
||||
* Pull a model on the Ollama server from the list of <a
|
||||
* href="https://ollama.ai/library">available models</a>.
|
||||
*
|
||||
* @param modelName the name of the model
|
||||
* @throws OllamaBaseException if the response indicates an error status
|
||||
* @throws IOException if an I/O error occurs during the HTTP request
|
||||
* @throws InterruptedException if the operation is interrupted
|
||||
* @throws URISyntaxException if the URI for the request is malformed
|
||||
*/
|
||||
public void pullModel(String modelName)
|
||||
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
|
||||
public void pullModel(String modelName) throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
|
||||
String url = this.host + "/api/pull";
|
||||
String jsonData = new ModelRequest(modelName).toString();
|
||||
HttpRequest request =
|
||||
getRequestBuilderDefault(new URI(url))
|
||||
.POST(HttpRequest.BodyPublishers.ofString(jsonData))
|
||||
.header("Accept", "application/json")
|
||||
.header("Content-type", "application/json")
|
||||
.build();
|
||||
HttpRequest request = getRequestBuilderDefault(new URI(url)).POST(HttpRequest.BodyPublishers.ofString(jsonData)).header("Accept", "application/json").header("Content-type", "application/json").build();
|
||||
HttpClient client = HttpClient.newHttpClient();
|
||||
HttpResponse<InputStream> response =
|
||||
client.send(request, HttpResponse.BodyHandlers.ofInputStream());
|
||||
HttpResponse<InputStream> response = client.send(request, HttpResponse.BodyHandlers.ofInputStream());
|
||||
int statusCode = response.statusCode();
|
||||
InputStream responseBodyStream = response.body();
|
||||
String responseString = "";
|
||||
try (BufferedReader reader =
|
||||
new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
|
||||
try (BufferedReader reader = new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
|
||||
String line;
|
||||
while ((line = reader.readLine()) != null) {
|
||||
ModelPullResponse modelPullResponse =
|
||||
Utils.getObjectMapper().readValue(line, ModelPullResponse.class);
|
||||
ModelPullResponse modelPullResponse = Utils.getObjectMapper().readValue(line, ModelPullResponse.class);
|
||||
if (verbose) {
|
||||
logger.info(modelPullResponse.getStatus());
|
||||
}
|
||||
@@ -217,22 +338,37 @@ public class OllamaAPI {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Pulls a model using the specified Ollama library model tag.
|
||||
* The model is identified by a name and a tag, which are combined into a single identifier
|
||||
* in the format "name:tag" to pull the corresponding model.
|
||||
*
|
||||
* @param libraryModelTag the {@link LibraryModelTag} object containing the name and tag
|
||||
* of the model to be pulled.
|
||||
* @throws OllamaBaseException if the response indicates an error status
|
||||
* @throws IOException if an I/O error occurs during the HTTP request
|
||||
* @throws InterruptedException if the operation is interrupted
|
||||
* @throws URISyntaxException if the URI for the request is malformed
|
||||
*/
|
||||
public void pullModel(LibraryModelTag libraryModelTag) throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
|
||||
String tagToPull = String.format("%s:%s", libraryModelTag.getName(), libraryModelTag.getTag());
|
||||
pullModel(tagToPull);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets model details from the Ollama server.
|
||||
*
|
||||
* @param modelName the model
|
||||
* @return the model details
|
||||
* @throws OllamaBaseException if the response indicates an error status
|
||||
* @throws IOException if an I/O error occurs during the HTTP request
|
||||
* @throws InterruptedException if the operation is interrupted
|
||||
* @throws URISyntaxException if the URI for the request is malformed
|
||||
*/
|
||||
public ModelDetail getModelDetails(String modelName)
|
||||
throws IOException, OllamaBaseException, InterruptedException, URISyntaxException {
|
||||
public ModelDetail getModelDetails(String modelName) throws IOException, OllamaBaseException, InterruptedException, URISyntaxException {
|
||||
String url = this.host + "/api/show";
|
||||
String jsonData = new ModelRequest(modelName).toString();
|
||||
HttpRequest request =
|
||||
getRequestBuilderDefault(new URI(url))
|
||||
.header("Accept", "application/json")
|
||||
.header("Content-type", "application/json")
|
||||
.POST(HttpRequest.BodyPublishers.ofString(jsonData))
|
||||
.build();
|
||||
HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
|
||||
HttpClient client = HttpClient.newHttpClient();
|
||||
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
|
||||
int statusCode = response.statusCode();
|
||||
@@ -250,17 +386,15 @@ public class OllamaAPI {
|
||||
*
|
||||
* @param modelName the name of the custom model to be created.
|
||||
* @param modelFilePath the path to model file that exists on the Ollama server.
|
||||
* @throws OllamaBaseException if the response indicates an error status
|
||||
* @throws IOException if an I/O error occurs during the HTTP request
|
||||
* @throws InterruptedException if the operation is interrupted
|
||||
* @throws URISyntaxException if the URI for the request is malformed
|
||||
*/
|
||||
public void createModelWithFilePath(String modelName, String modelFilePath)
|
||||
throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
|
||||
public void createModelWithFilePath(String modelName, String modelFilePath) throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
|
||||
String url = this.host + "/api/create";
|
||||
String jsonData = new CustomModelFilePathRequest(modelName, modelFilePath).toString();
|
||||
HttpRequest request =
|
||||
getRequestBuilderDefault(new URI(url))
|
||||
.header("Accept", "application/json")
|
||||
.header("Content-Type", "application/json")
|
||||
.POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8))
|
||||
.build();
|
||||
HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-Type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build();
|
||||
HttpClient client = HttpClient.newHttpClient();
|
||||
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
|
||||
int statusCode = response.statusCode();
|
||||
@@ -284,17 +418,15 @@ public class OllamaAPI {
|
||||
*
|
||||
* @param modelName the name of the custom model to be created.
|
||||
* @param modelFileContents the path to model file that exists on the Ollama server.
|
||||
* @throws OllamaBaseException if the response indicates an error status
|
||||
* @throws IOException if an I/O error occurs during the HTTP request
|
||||
* @throws InterruptedException if the operation is interrupted
|
||||
* @throws URISyntaxException if the URI for the request is malformed
|
||||
*/
|
||||
public void createModelWithModelFileContents(String modelName, String modelFileContents)
|
||||
throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
|
||||
public void createModelWithModelFileContents(String modelName, String modelFileContents) throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
|
||||
String url = this.host + "/api/create";
|
||||
String jsonData = new CustomModelFileContentsRequest(modelName, modelFileContents).toString();
|
||||
HttpRequest request =
|
||||
getRequestBuilderDefault(new URI(url))
|
||||
.header("Accept", "application/json")
|
||||
.header("Content-Type", "application/json")
|
||||
.POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8))
|
||||
.build();
|
||||
HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-Type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build();
|
||||
HttpClient client = HttpClient.newHttpClient();
|
||||
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
|
||||
int statusCode = response.statusCode();
|
||||
@@ -315,17 +447,15 @@ public class OllamaAPI {
|
||||
*
|
||||
* @param modelName the name of the model to be deleted.
|
||||
* @param ignoreIfNotPresent ignore errors if the specified model is not present on Ollama server.
|
||||
* @throws OllamaBaseException if the response indicates an error status
|
||||
* @throws IOException if an I/O error occurs during the HTTP request
|
||||
* @throws InterruptedException if the operation is interrupted
|
||||
* @throws URISyntaxException if the URI for the request is malformed
|
||||
*/
|
||||
public void deleteModel(String modelName, boolean ignoreIfNotPresent)
|
||||
throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
|
||||
public void deleteModel(String modelName, boolean ignoreIfNotPresent) throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
|
||||
String url = this.host + "/api/delete";
|
||||
String jsonData = new ModelRequest(modelName).toString();
|
||||
HttpRequest request =
|
||||
getRequestBuilderDefault(new URI(url))
|
||||
.method("DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8))
|
||||
.header("Accept", "application/json")
|
||||
.header("Content-type", "application/json")
|
||||
.build();
|
||||
HttpRequest request = getRequestBuilderDefault(new URI(url)).method("DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).header("Accept", "application/json").header("Content-type", "application/json").build();
|
||||
HttpClient client = HttpClient.newHttpClient();
|
||||
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
|
||||
int statusCode = response.statusCode();
|
||||
@@ -344,11 +474,13 @@ public class OllamaAPI {
|
||||
* @param model name of model to generate embeddings from
|
||||
* @param prompt text to generate embeddings for
|
||||
* @return embeddings
|
||||
* @deprecated Use {@link #embed(String, List<String>)} instead.
|
||||
* @throws OllamaBaseException if the response indicates an error status
|
||||
* @throws IOException if an I/O error occurs during the HTTP request
|
||||
* @throws InterruptedException if the operation is interrupted
|
||||
* @deprecated Use {@link #embed(String, List)} instead.
|
||||
*/
|
||||
@Deprecated
|
||||
public List<Double> generateEmbeddings(String model, String prompt)
|
||||
throws IOException, InterruptedException, OllamaBaseException {
|
||||
public List<Double> generateEmbeddings(String model, String prompt) throws IOException, InterruptedException, OllamaBaseException {
|
||||
return generateEmbeddings(new OllamaEmbeddingsRequestModel(model, prompt));
|
||||
}
|
||||
|
||||
@@ -357,6 +489,9 @@ public class OllamaAPI {
|
||||
*
|
||||
* @param modelRequest request for '/api/embeddings' endpoint
|
||||
* @return embeddings
|
||||
* @throws OllamaBaseException if the response indicates an error status
|
||||
* @throws IOException if an I/O error occurs during the HTTP request
|
||||
* @throws InterruptedException if the operation is interrupted
|
||||
* @deprecated Use {@link #embed(OllamaEmbedRequestModel)} instead.
|
||||
*/
|
||||
@Deprecated
|
||||
@@ -364,17 +499,13 @@ public class OllamaAPI {
|
||||
URI uri = URI.create(this.host + "/api/embeddings");
|
||||
String jsonData = modelRequest.toString();
|
||||
HttpClient httpClient = HttpClient.newHttpClient();
|
||||
HttpRequest.Builder requestBuilder =
|
||||
getRequestBuilderDefault(uri)
|
||||
.header("Accept", "application/json")
|
||||
.POST(HttpRequest.BodyPublishers.ofString(jsonData));
|
||||
HttpRequest.Builder requestBuilder = getRequestBuilderDefault(uri).header("Accept", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData));
|
||||
HttpRequest request = requestBuilder.build();
|
||||
HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
|
||||
int statusCode = response.statusCode();
|
||||
String responseBody = response.body();
|
||||
if (statusCode == 200) {
|
||||
OllamaEmbeddingResponseModel embeddingResponse =
|
||||
Utils.getObjectMapper().readValue(responseBody, OllamaEmbeddingResponseModel.class);
|
||||
OllamaEmbeddingResponseModel embeddingResponse = Utils.getObjectMapper().readValue(responseBody, OllamaEmbeddingResponseModel.class);
|
||||
return embeddingResponse.getEmbedding();
|
||||
} else {
|
||||
throw new OllamaBaseException(statusCode + " - " + responseBody);
|
||||
@@ -387,9 +518,11 @@ public class OllamaAPI {
|
||||
* @param model name of model to generate embeddings from
|
||||
* @param inputs text/s to generate embeddings for
|
||||
* @return embeddings
|
||||
* @throws OllamaBaseException if the response indicates an error status
|
||||
* @throws IOException if an I/O error occurs during the HTTP request
|
||||
* @throws InterruptedException if the operation is interrupted
|
||||
*/
|
||||
public OllamaEmbedResponseModel embed(String model, List<String> inputs)
|
||||
throws IOException, InterruptedException, OllamaBaseException {
|
||||
public OllamaEmbedResponseModel embed(String model, List<String> inputs) throws IOException, InterruptedException, OllamaBaseException {
|
||||
return embed(new OllamaEmbedRequestModel(model, inputs));
|
||||
}
|
||||
|
||||
@@ -398,26 +531,23 @@ public class OllamaAPI {
|
||||
*
|
||||
* @param modelRequest request for '/api/embed' endpoint
|
||||
* @return embeddings
|
||||
* @throws OllamaBaseException if the response indicates an error status
|
||||
* @throws IOException if an I/O error occurs during the HTTP request
|
||||
* @throws InterruptedException if the operation is interrupted
|
||||
*/
|
||||
public OllamaEmbedResponseModel embed(OllamaEmbedRequestModel modelRequest)
|
||||
throws IOException, InterruptedException, OllamaBaseException {
|
||||
public OllamaEmbedResponseModel embed(OllamaEmbedRequestModel modelRequest) throws IOException, InterruptedException, OllamaBaseException {
|
||||
URI uri = URI.create(this.host + "/api/embed");
|
||||
String jsonData = Utils.getObjectMapper().writeValueAsString(modelRequest);
|
||||
HttpClient httpClient = HttpClient.newHttpClient();
|
||||
|
||||
HttpRequest request = HttpRequest.newBuilder(uri)
|
||||
.header("Accept", "application/json")
|
||||
.POST(HttpRequest.BodyPublishers.ofString(jsonData))
|
||||
.build();
|
||||
HttpRequest request = HttpRequest.newBuilder(uri).header("Accept", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
|
||||
|
||||
HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
|
||||
int statusCode = response.statusCode();
|
||||
String responseBody = response.body();
|
||||
|
||||
if (statusCode == 200) {
|
||||
OllamaEmbedResponseModel embeddingResponse =
|
||||
Utils.getObjectMapper().readValue(responseBody, OllamaEmbedResponseModel.class);
|
||||
return embeddingResponse;
|
||||
return Utils.getObjectMapper().readValue(responseBody, OllamaEmbedResponseModel.class);
|
||||
} else {
|
||||
throw new OllamaBaseException(statusCode + " - " + responseBody);
|
||||
}
|
||||
@@ -434,9 +564,11 @@ public class OllamaAPI {
|
||||
* details on the options</a>
|
||||
* @param streamHandler optional callback consumer that will be applied every time a streamed response is received. If not set, the stream parameter of the request is set to false.
|
||||
* @return OllamaResult that includes response text and time taken for response
|
||||
* @throws OllamaBaseException if the response indicates an error status
|
||||
* @throws IOException if an I/O error occurs during the HTTP request
|
||||
* @throws InterruptedException if the operation is interrupted
|
||||
*/
|
||||
public OllamaResult generate(String model, String prompt, boolean raw, Options options, OllamaStreamHandler streamHandler)
|
||||
throws OllamaBaseException, IOException, InterruptedException {
|
||||
public OllamaResult generate(String model, String prompt, boolean raw, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
|
||||
OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt);
|
||||
ollamaRequestModel.setRaw(raw);
|
||||
ollamaRequestModel.setOptions(options.getOptionsMap());
|
||||
@@ -453,13 +585,14 @@ public class OllamaAPI {
|
||||
* @param raw In some cases, you may wish to bypass the templating system and provide a full prompt. In this case, you can use the raw parameter to disable templating. Also note that raw mode will not return a context.
|
||||
* @param options Additional options or configurations to use when generating the response.
|
||||
* @return {@link OllamaResult}
|
||||
* @throws OllamaBaseException if the response indicates an error status
|
||||
* @throws IOException if an I/O error occurs during the HTTP request
|
||||
* @throws InterruptedException if the operation is interrupted
|
||||
*/
|
||||
public OllamaResult generate(String model, String prompt, boolean raw, Options options)
|
||||
throws OllamaBaseException, IOException, InterruptedException {
|
||||
public OllamaResult generate(String model, String prompt, boolean raw, Options options) throws OllamaBaseException, IOException, InterruptedException {
|
||||
return generate(model, prompt, raw, options, null);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Generates response using the specified AI model and prompt (in blocking mode), and then invokes a set of tools
|
||||
* on the generated response.
|
||||
@@ -468,17 +601,24 @@ public class OllamaAPI {
|
||||
* @param prompt The input text or prompt to provide to the AI model.
|
||||
* @param options Additional options or configurations to use when generating the response.
|
||||
* @return {@link OllamaToolsResult} An OllamaToolsResult object containing the response from the AI model and the results of invoking the tools on that output.
|
||||
* @throws OllamaBaseException If there is an error related to the Ollama API or service.
|
||||
* @throws IOException If there is an error related to input/output operations.
|
||||
* @throws InterruptedException If the method is interrupted while waiting for the AI model
|
||||
* to generate the response or for the tools to be invoked.
|
||||
* @throws OllamaBaseException if the response indicates an error status
|
||||
* @throws IOException if an I/O error occurs during the HTTP request
|
||||
* @throws InterruptedException if the operation is interrupted
|
||||
*/
|
||||
public OllamaToolsResult generateWithTools(String model, String prompt, Options options)
|
||||
throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
|
||||
public OllamaToolsResult generateWithTools(String model, String prompt, Options options) throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
|
||||
boolean raw = true;
|
||||
OllamaToolsResult toolResult = new OllamaToolsResult();
|
||||
Map<ToolFunctionCallSpec, Object> toolResults = new HashMap<>();
|
||||
|
||||
if(!prompt.startsWith("[AVAILABLE_TOOLS]")){
|
||||
final Tools.PromptBuilder promptBuilder = new Tools.PromptBuilder();
|
||||
for(Tools.ToolSpecification spec : toolRegistry.getRegisteredSpecs()) {
|
||||
promptBuilder.withToolSpecification(spec);
|
||||
}
|
||||
promptBuilder.withPrompt(prompt);
|
||||
prompt = promptBuilder.build();
|
||||
}
|
||||
|
||||
OllamaResult result = generate(model, prompt, raw, options, null);
|
||||
toolResult.setModelResult(result);
|
||||
|
||||
@@ -495,7 +635,6 @@ public class OllamaAPI {
|
||||
return toolResult;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Generate response for a question to a model running on Ollama server and get a callback handle
|
||||
* that can be used to check for status and get the response from the model later. This would be
|
||||
@@ -509,9 +648,7 @@ public class OllamaAPI {
|
||||
OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt);
|
||||
ollamaRequestModel.setRaw(raw);
|
||||
URI uri = URI.create(this.host + "/api/generate");
|
||||
OllamaAsyncResultStreamer ollamaAsyncResultStreamer =
|
||||
new OllamaAsyncResultStreamer(
|
||||
getRequestBuilderDefault(uri), ollamaRequestModel, requestTimeoutSeconds);
|
||||
OllamaAsyncResultStreamer ollamaAsyncResultStreamer = new OllamaAsyncResultStreamer(getRequestBuilderDefault(uri), ollamaRequestModel, requestTimeoutSeconds);
|
||||
ollamaAsyncResultStreamer.start();
|
||||
return ollamaAsyncResultStreamer;
|
||||
}
|
||||
@@ -528,10 +665,11 @@ public class OllamaAPI {
|
||||
* details on the options</a>
|
||||
* @param streamHandler optional callback consumer that will be applied every time a streamed response is received. If not set, the stream parameter of the request is set to false.
|
||||
* @return OllamaResult that includes response text and time taken for response
|
||||
* @throws OllamaBaseException if the response indicates an error status
|
||||
* @throws IOException if an I/O error occurs during the HTTP request
|
||||
* @throws InterruptedException if the operation is interrupted
|
||||
*/
|
||||
public OllamaResult generateWithImageFiles(
|
||||
String model, String prompt, List<File> imageFiles, Options options, OllamaStreamHandler streamHandler)
|
||||
throws OllamaBaseException, IOException, InterruptedException {
|
||||
public OllamaResult generateWithImageFiles(String model, String prompt, List<File> imageFiles, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
|
||||
List<String> images = new ArrayList<>();
|
||||
for (File imageFile : imageFiles) {
|
||||
images.add(encodeFileToBase64(imageFile));
|
||||
@@ -545,10 +683,12 @@ public class OllamaAPI {
|
||||
* Convenience method to call Ollama API without streaming responses.
|
||||
* <p>
|
||||
* Uses {@link #generateWithImageFiles(String, String, List, Options, OllamaStreamHandler)}
|
||||
*
|
||||
* @throws OllamaBaseException if the response indicates an error status
|
||||
* @throws IOException if an I/O error occurs during the HTTP request
|
||||
* @throws InterruptedException if the operation is interrupted
|
||||
*/
|
||||
public OllamaResult generateWithImageFiles(
|
||||
String model, String prompt, List<File> imageFiles, Options options)
|
||||
throws OllamaBaseException, IOException, InterruptedException {
|
||||
public OllamaResult generateWithImageFiles(String model, String prompt, List<File> imageFiles, Options options) throws OllamaBaseException, IOException, InterruptedException {
|
||||
return generateWithImageFiles(model, prompt, imageFiles, options, null);
|
||||
}
|
||||
|
||||
@@ -564,10 +704,12 @@ public class OllamaAPI {
|
||||
* details on the options</a>
|
||||
* @param streamHandler optional callback consumer that will be applied every time a streamed response is received. If not set, the stream parameter of the request is set to false.
|
||||
* @return OllamaResult that includes response text and time taken for response
|
||||
* @throws OllamaBaseException if the response indicates an error status
|
||||
* @throws IOException if an I/O error occurs during the HTTP request
|
||||
* @throws InterruptedException if the operation is interrupted
|
||||
* @throws URISyntaxException if the URI for the request is malformed
|
||||
*/
|
||||
public OllamaResult generateWithImageURLs(
|
||||
String model, String prompt, List<String> imageURLs, Options options, OllamaStreamHandler streamHandler)
|
||||
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
|
||||
public OllamaResult generateWithImageURLs(String model, String prompt, List<String> imageURLs, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
|
||||
List<String> images = new ArrayList<>();
|
||||
for (String imageURL : imageURLs) {
|
||||
images.add(encodeByteArrayToBase64(Utils.loadImageBytesFromUrl(imageURL)));
|
||||
@@ -581,14 +723,16 @@ public class OllamaAPI {
|
||||
* Convenience method to call Ollama API without streaming responses.
|
||||
* <p>
|
||||
* Uses {@link #generateWithImageURLs(String, String, List, Options, OllamaStreamHandler)}
|
||||
*
|
||||
* @throws OllamaBaseException if the response indicates an error status
|
||||
* @throws IOException if an I/O error occurs during the HTTP request
|
||||
* @throws InterruptedException if the operation is interrupted
|
||||
* @throws URISyntaxException if the URI for the request is malformed
|
||||
*/
|
||||
public OllamaResult generateWithImageURLs(String model, String prompt, List<String> imageURLs,
|
||||
Options options)
|
||||
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
|
||||
public OllamaResult generateWithImageURLs(String model, String prompt, List<String> imageURLs, Options options) throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
|
||||
return generateWithImageURLs(model, prompt, imageURLs, options, null);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Ask a question to a model based on a given message stack (i.e. a chat history). Creates a synchronous call to the api
|
||||
* 'api/chat'.
|
||||
@@ -599,6 +743,9 @@ public class OllamaAPI {
|
||||
* @throws OllamaBaseException any response code than 200 has been returned
|
||||
* @throws IOException in case the responseStream can not be read
|
||||
* @throws InterruptedException in case the server is not reachable or network issues happen
|
||||
* @throws OllamaBaseException if the response indicates an error status
|
||||
* @throws IOException if an I/O error occurs during the HTTP request
|
||||
* @throws InterruptedException if the operation is interrupted
|
||||
*/
|
||||
public OllamaChatResult chat(String model, List<OllamaChatMessage> messages) throws OllamaBaseException, IOException, InterruptedException {
|
||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(model);
|
||||
@@ -615,6 +762,9 @@ public class OllamaAPI {
|
||||
* @throws OllamaBaseException any response code than 200 has been returned
|
||||
* @throws IOException in case the responseStream can not be read
|
||||
* @throws InterruptedException in case the server is not reachable or network issues happen
|
||||
* @throws OllamaBaseException if the response indicates an error status
|
||||
* @throws IOException if an I/O error occurs during the HTTP request
|
||||
* @throws InterruptedException if the operation is interrupted
|
||||
*/
|
||||
public OllamaChatResult chat(OllamaChatRequest request) throws OllamaBaseException, IOException, InterruptedException {
|
||||
return chat(request, null);
|
||||
@@ -631,23 +781,189 @@ public class OllamaAPI {
|
||||
* @throws OllamaBaseException any response code than 200 has been returned
|
||||
* @throws IOException in case the responseStream can not be read
|
||||
* @throws InterruptedException in case the server is not reachable or network issues happen
|
||||
* @throws OllamaBaseException if the response indicates an error status
|
||||
* @throws IOException if an I/O error occurs during the HTTP request
|
||||
* @throws InterruptedException if the operation is interrupted
|
||||
*/
|
||||
public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
|
||||
return chatStreaming(request, new OllamaChatStreamObserver(streamHandler));
|
||||
}
|
||||
|
||||
/**
|
||||
* Ask a question to a model using an {@link OllamaChatRequest}. This can be constructed using an {@link OllamaChatRequestBuilder}.
|
||||
* <p>
|
||||
* Hint: the OllamaChatRequestModel#getStream() property is not implemented.
|
||||
*
|
||||
* @param request request object to be sent to the server
|
||||
* @param tokenHandler callback handler to handle the last token from stream (caution: all previous messages from stream will be concatenated)
|
||||
* @return {@link OllamaChatResult}
|
||||
* @throws OllamaBaseException any response code than 200 has been returned
|
||||
* @throws IOException in case the responseStream can not be read
|
||||
* @throws InterruptedException in case the server is not reachable or network issues happen
|
||||
* @throws OllamaBaseException if the response indicates an error status
|
||||
* @throws IOException if an I/O error occurs during the HTTP request
|
||||
* @throws InterruptedException if the operation is interrupted
|
||||
*/
|
||||
public OllamaChatResult chatStreaming(OllamaChatRequest request, OllamaTokenHandler tokenHandler) throws OllamaBaseException, IOException, InterruptedException {
|
||||
OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
|
||||
OllamaResult result;
|
||||
if (streamHandler != null) {
|
||||
OllamaChatResult result;
|
||||
|
||||
// add all registered tools to Request
|
||||
request.setTools(toolRegistry.getRegisteredSpecs().stream().map(Tools.ToolSpecification::getToolPrompt).collect(Collectors.toList()));
|
||||
|
||||
if (tokenHandler != null) {
|
||||
request.setStream(true);
|
||||
result = requestCaller.call(request, streamHandler);
|
||||
result = requestCaller.call(request, tokenHandler);
|
||||
} else {
|
||||
result = requestCaller.callSync(request);
|
||||
}
|
||||
return new OllamaChatResult(result.getResponse(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages());
|
||||
|
||||
// check if toolCallIsWanted
|
||||
List<OllamaChatToolCalls> toolCalls = result.getResponseModel().getMessage().getToolCalls();
|
||||
int toolCallTries = 0;
|
||||
while(toolCalls != null && !toolCalls.isEmpty() && toolCallTries < maxChatToolCallRetries){
|
||||
for (OllamaChatToolCalls toolCall : toolCalls){
|
||||
String toolName = toolCall.getFunction().getName();
|
||||
ToolFunction toolFunction = toolRegistry.getToolFunction(toolName);
|
||||
Map<String, Object> arguments = toolCall.getFunction().getArguments();
|
||||
Object res = toolFunction.apply(arguments);
|
||||
request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL,"[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() +") : " + res + "[/TOOL_RESULTS]"));
|
||||
}
|
||||
|
||||
if (tokenHandler != null) {
|
||||
result = requestCaller.call(request, tokenHandler);
|
||||
} else {
|
||||
result = requestCaller.callSync(request);
|
||||
}
|
||||
toolCalls = result.getResponseModel().getMessage().getToolCalls();
|
||||
toolCallTries++;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
public void registerTool(Tools.ToolSpecification toolSpecification) {
|
||||
toolRegistry.addFunction(toolSpecification.getFunctionName(), toolSpecification.getToolDefinition());
|
||||
toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification);
|
||||
}
|
||||
|
||||
|
||||
public void registerAnnotatedTools() {
|
||||
try {
|
||||
Class<?> callerClass = null;
|
||||
try {
|
||||
callerClass = Class.forName(Thread.currentThread().getStackTrace()[2].getClassName());
|
||||
} catch (ClassNotFoundException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
OllamaToolService ollamaToolServiceAnnotation = callerClass.getDeclaredAnnotation(OllamaToolService.class);
|
||||
if (ollamaToolServiceAnnotation == null) {
|
||||
throw new IllegalStateException(callerClass + " is not annotated as " + OllamaToolService.class);
|
||||
}
|
||||
|
||||
Class<?>[] providers = ollamaToolServiceAnnotation.providers();
|
||||
for (Class<?> provider : providers) {
|
||||
registerAnnotatedTools(provider.getDeclaredConstructor().newInstance());
|
||||
}
|
||||
} catch (InstantiationException | NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
public void registerAnnotatedTools(Object object) {
|
||||
Class<?> objectClass = object.getClass();
|
||||
Method[] methods = objectClass.getMethods();
|
||||
for(Method m : methods) {
|
||||
ToolSpec toolSpec = m.getDeclaredAnnotation(ToolSpec.class);
|
||||
if(toolSpec == null){
|
||||
continue;
|
||||
}
|
||||
String operationName = !toolSpec.name().isBlank() ? toolSpec.name() : m.getName();
|
||||
String operationDesc = !toolSpec.desc().isBlank() ? toolSpec.desc() : operationName;
|
||||
|
||||
final Tools.PropsBuilder propsBuilder = new Tools.PropsBuilder();
|
||||
LinkedHashMap<String,String> methodParams = new LinkedHashMap<>();
|
||||
for (Parameter parameter : m.getParameters()) {
|
||||
final ToolProperty toolPropertyAnn = parameter.getDeclaredAnnotation(ToolProperty.class);
|
||||
String propType = parameter.getType().getTypeName();
|
||||
if(toolPropertyAnn == null) {
|
||||
methodParams.put(parameter.getName(),null);
|
||||
continue;
|
||||
}
|
||||
String propName = !toolPropertyAnn.name().isBlank() ? toolPropertyAnn.name() : parameter.getName();
|
||||
methodParams.put(propName,propType);
|
||||
propsBuilder.withProperty(propName,Tools.PromptFuncDefinition.Property.builder()
|
||||
.type(propType)
|
||||
.description(toolPropertyAnn.desc())
|
||||
.required(toolPropertyAnn.required())
|
||||
.build());
|
||||
}
|
||||
final Map<String, Tools.PromptFuncDefinition.Property> params = propsBuilder.build();
|
||||
List<String> reqProps = params.entrySet().stream()
|
||||
.filter(e -> e.getValue().isRequired())
|
||||
.map(Map.Entry::getKey)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
Tools.ToolSpecification toolSpecification = Tools.ToolSpecification.builder()
|
||||
.functionName(operationName)
|
||||
.functionDescription(operationDesc)
|
||||
.toolPrompt(
|
||||
Tools.PromptFuncDefinition.builder().type("function").function(
|
||||
Tools.PromptFuncDefinition.PromptFuncSpec.builder()
|
||||
.name(operationName)
|
||||
.description(operationDesc)
|
||||
.parameters(
|
||||
Tools.PromptFuncDefinition.Parameters.builder()
|
||||
.type("object")
|
||||
.properties(
|
||||
params
|
||||
)
|
||||
.required(reqProps)
|
||||
.build()
|
||||
).build()
|
||||
).build()
|
||||
)
|
||||
.build();
|
||||
|
||||
ReflectionalToolFunction reflectionalToolFunction =
|
||||
new ReflectionalToolFunction(object, m, methodParams);
|
||||
toolSpecification.setToolFunction(reflectionalToolFunction);
|
||||
toolRegistry.addTool(toolSpecification.getFunctionName(),toolSpecification);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds a custom role.
|
||||
*
|
||||
* @param roleName the name of the custom role to be added
|
||||
* @return the newly created OllamaChatMessageRole
|
||||
*/
|
||||
public OllamaChatMessageRole addCustomRole(String roleName) {
|
||||
return OllamaChatMessageRole.newCustomRole(roleName);
|
||||
}
|
||||
|
||||
/**
|
||||
* Lists all available roles.
|
||||
*
|
||||
* @return a list of available OllamaChatMessageRole objects
|
||||
*/
|
||||
public List<OllamaChatMessageRole> listRoles() {
|
||||
return OllamaChatMessageRole.getRoles();
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves a specific role by name.
|
||||
*
|
||||
* @param roleName the name of the role to retrieve
|
||||
* @return the OllamaChatMessageRole associated with the given name
|
||||
* @throws RoleNotFoundException if the role with the specified name does not exist
|
||||
*/
|
||||
public OllamaChatMessageRole getRole(String roleName) throws RoleNotFoundException {
|
||||
return OllamaChatMessageRole.getRole(roleName);
|
||||
}
|
||||
|
||||
|
||||
// technical private methods //
|
||||
|
||||
private static String encodeFileToBase64(File file) throws IOException {
|
||||
@@ -658,11 +974,8 @@ public class OllamaAPI {
|
||||
return Base64.getEncoder().encodeToString(bytes);
|
||||
}
|
||||
|
||||
private OllamaResult generateSyncForOllamaRequestModel(
|
||||
OllamaGenerateRequest ollamaRequestModel, OllamaStreamHandler streamHandler)
|
||||
throws OllamaBaseException, IOException, InterruptedException {
|
||||
OllamaGenerateEndpointCaller requestCaller =
|
||||
new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
|
||||
private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequest ollamaRequestModel, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
|
||||
OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
|
||||
OllamaResult result;
|
||||
if (streamHandler != null) {
|
||||
ollamaRequestModel.setStream(true);
|
||||
@@ -680,10 +993,7 @@ public class OllamaAPI {
|
||||
* @return HttpRequest.Builder
|
||||
*/
|
||||
private HttpRequest.Builder getRequestBuilderDefault(URI uri) {
|
||||
HttpRequest.Builder requestBuilder =
|
||||
HttpRequest.newBuilder(uri)
|
||||
.header("Content-Type", "application/json")
|
||||
.timeout(Duration.ofSeconds(requestTimeoutSeconds));
|
||||
HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(uri).header("Content-Type", "application/json").timeout(Duration.ofSeconds(requestTimeoutSeconds));
|
||||
if (isBasicAuthCredentialsSet()) {
|
||||
requestBuilder.header("Authorization", getBasicAuthHeaderValue());
|
||||
}
|
||||
@@ -709,12 +1019,11 @@ public class OllamaAPI {
|
||||
return basicAuth != null;
|
||||
}
|
||||
|
||||
|
||||
private Object invokeTool(ToolFunctionCallSpec toolFunctionCallSpec) throws ToolInvocationException {
|
||||
try {
|
||||
String methodName = toolFunctionCallSpec.getName();
|
||||
Map<String, Object> arguments = toolFunctionCallSpec.getArguments();
|
||||
ToolFunction function = toolRegistry.getFunction(methodName);
|
||||
ToolFunction function = toolRegistry.getToolFunction(methodName);
|
||||
if (verbose) {
|
||||
logger.debug("Invoking function {} with arguments {}", methodName, arguments);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
package io.github.ollama4j.exceptions;
|
||||
|
||||
public class RoleNotFoundException extends Exception {
|
||||
|
||||
public RoleNotFoundException(String s) {
|
||||
super(s);
|
||||
}
|
||||
}
|
||||
@@ -2,12 +2,14 @@ package io.github.ollama4j.models.chat;
|
||||
|
||||
import static io.github.ollama4j.utils.Utils.getObjectMapper;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
|
||||
|
||||
import io.github.ollama4j.utils.FileToBase64Serializer;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
@@ -31,15 +33,17 @@ public class OllamaChatMessage {
|
||||
@NonNull
|
||||
private String content;
|
||||
|
||||
private @JsonProperty("tool_calls") List<OllamaChatToolCalls> toolCalls;
|
||||
|
||||
@JsonSerialize(using = FileToBase64Serializer.class)
|
||||
private List<byte[]> images;
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
try {
|
||||
return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new RuntimeException(e);
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
try {
|
||||
return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,20 +1,53 @@
|
||||
package io.github.ollama4j.models.chat;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonValue;
|
||||
import io.github.ollama4j.exceptions.RoleNotFoundException;
|
||||
import lombok.Getter;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Defines the possible Chat Message roles.
|
||||
*/
|
||||
public enum OllamaChatMessageRole {
|
||||
SYSTEM("system"),
|
||||
USER("user"),
|
||||
ASSISTANT("assistant"),
|
||||
TOOL("tool");
|
||||
@Getter
|
||||
public class OllamaChatMessageRole {
|
||||
private static final List<OllamaChatMessageRole> roles = new ArrayList<>();
|
||||
|
||||
public static final OllamaChatMessageRole SYSTEM = new OllamaChatMessageRole("system");
|
||||
public static final OllamaChatMessageRole USER = new OllamaChatMessageRole("user");
|
||||
public static final OllamaChatMessageRole ASSISTANT = new OllamaChatMessageRole("assistant");
|
||||
public static final OllamaChatMessageRole TOOL = new OllamaChatMessageRole("tool");
|
||||
|
||||
@JsonValue
|
||||
private String roleName;
|
||||
private final String roleName;
|
||||
|
||||
private OllamaChatMessageRole(String roleName){
|
||||
private OllamaChatMessageRole(String roleName) {
|
||||
this.roleName = roleName;
|
||||
roles.add(this);
|
||||
}
|
||||
|
||||
public static OllamaChatMessageRole newCustomRole(String roleName) {
|
||||
OllamaChatMessageRole customRole = new OllamaChatMessageRole(roleName);
|
||||
roles.add(customRole);
|
||||
return customRole;
|
||||
}
|
||||
|
||||
public static List<OllamaChatMessageRole> getRoles() {
|
||||
return new ArrayList<>(roles);
|
||||
}
|
||||
|
||||
public static OllamaChatMessageRole getRole(String roleName) throws RoleNotFoundException {
|
||||
for (OllamaChatMessageRole role : roles) {
|
||||
if (role.roleName.equals(roleName)) {
|
||||
return role;
|
||||
}
|
||||
}
|
||||
throw new RoleNotFoundException("Invalid role name: " + roleName);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return roleName;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package io.github.ollama4j.models.chat;
|
||||
import java.util.List;
|
||||
|
||||
import io.github.ollama4j.models.request.OllamaCommonRequest;
|
||||
import io.github.ollama4j.tools.Tools;
|
||||
import io.github.ollama4j.utils.OllamaRequestBody;
|
||||
|
||||
import lombok.Getter;
|
||||
@@ -21,6 +22,8 @@ public class OllamaChatRequest extends OllamaCommonRequest implements OllamaRequ
|
||||
|
||||
private List<OllamaChatMessage> messages;
|
||||
|
||||
private List<Tools.PromptFuncDefinition> tools;
|
||||
|
||||
public OllamaChatRequest() {}
|
||||
|
||||
public OllamaChatRequest(String model, List<OllamaChatMessage> messages) {
|
||||
|
||||
@@ -10,6 +10,7 @@ import java.io.IOException;
|
||||
import java.net.URISyntaxException;
|
||||
import java.nio.file.Files;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@@ -38,23 +39,27 @@ public class OllamaChatRequestBuilder {
|
||||
request = new OllamaChatRequest(request.getModel(), new ArrayList<>());
|
||||
}
|
||||
|
||||
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List<File> images) {
|
||||
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content){
|
||||
return withMessage(role,content, Collections.emptyList());
|
||||
}
|
||||
|
||||
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List<OllamaChatToolCalls> toolCalls,List<File> images) {
|
||||
List<OllamaChatMessage> messages = this.request.getMessages();
|
||||
|
||||
List<byte[]> binaryImages = images.stream().map(file -> {
|
||||
try {
|
||||
return Files.readAllBytes(file.toPath());
|
||||
} catch (IOException e) {
|
||||
LOG.warn(String.format("File '%s' could not be accessed, will not add to message!", file.toPath()), e);
|
||||
LOG.warn("File '{}' could not be accessed, will not add to message!", file.toPath(), e);
|
||||
return new byte[0];
|
||||
}
|
||||
}).collect(Collectors.toList());
|
||||
|
||||
messages.add(new OllamaChatMessage(role, content, binaryImages));
|
||||
messages.add(new OllamaChatMessage(role, content,toolCalls, binaryImages));
|
||||
return this;
|
||||
}
|
||||
|
||||
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, String... imageUrls) {
|
||||
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content,List<OllamaChatToolCalls> toolCalls, String... imageUrls) {
|
||||
List<OllamaChatMessage> messages = this.request.getMessages();
|
||||
List<byte[]> binaryImages = null;
|
||||
if (imageUrls.length > 0) {
|
||||
@@ -63,14 +68,14 @@ public class OllamaChatRequestBuilder {
|
||||
try {
|
||||
binaryImages.add(Utils.loadImageBytesFromUrl(imageUrl));
|
||||
} catch (URISyntaxException e) {
|
||||
LOG.warn(String.format("URL '%s' could not be accessed, will not add to message!", imageUrl), e);
|
||||
LOG.warn("URL '{}' could not be accessed, will not add to message!", imageUrl, e);
|
||||
} catch (IOException e) {
|
||||
LOG.warn(String.format("Content of URL '%s' could not be read, will not add to message!", imageUrl), e);
|
||||
LOG.warn("Content of URL '{}' could not be read, will not add to message!", imageUrl, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
messages.add(new OllamaChatMessage(role, content, binaryImages));
|
||||
messages.add(new OllamaChatMessage(role, content,toolCalls, binaryImages));
|
||||
return this;
|
||||
}
|
||||
|
||||
|
||||
@@ -2,31 +2,54 @@ package io.github.ollama4j.models.chat;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import io.github.ollama4j.models.response.OllamaResult;
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import lombok.Getter;
|
||||
|
||||
import static io.github.ollama4j.utils.Utils.getObjectMapper;
|
||||
|
||||
/**
|
||||
* Specific chat-API result that contains the chat history sent to the model and appends the answer as {@link OllamaChatResult} given by the
|
||||
* {@link OllamaChatMessageRole#ASSISTANT} role.
|
||||
*/
|
||||
public class OllamaChatResult extends OllamaResult{
|
||||
@Getter
|
||||
public class OllamaChatResult {
|
||||
|
||||
|
||||
private List<OllamaChatMessage> chatHistory;
|
||||
|
||||
public OllamaChatResult(String response, long responseTime, int httpStatusCode,
|
||||
List<OllamaChatMessage> chatHistory) {
|
||||
super(response, responseTime, httpStatusCode);
|
||||
private OllamaChatResponseModel responseModel;
|
||||
|
||||
public OllamaChatResult(OllamaChatResponseModel responseModel, List<OllamaChatMessage> chatHistory) {
|
||||
this.chatHistory = chatHistory;
|
||||
appendAnswerToChatHistory(response);
|
||||
this.responseModel = responseModel;
|
||||
appendAnswerToChatHistory(responseModel);
|
||||
}
|
||||
|
||||
public List<OllamaChatMessage> getChatHistory() {
|
||||
return chatHistory;
|
||||
}
|
||||
|
||||
private void appendAnswerToChatHistory(String answer){
|
||||
OllamaChatMessage assistantMessage = new OllamaChatMessage(OllamaChatMessageRole.ASSISTANT, answer);
|
||||
this.chatHistory.add(assistantMessage);
|
||||
private void appendAnswerToChatHistory(OllamaChatResponseModel response) {
|
||||
this.chatHistory.add(response.getMessage());
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
try {
|
||||
return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
public String getResponse(){
|
||||
return responseModel != null ? responseModel.getMessage().getContent() : "";
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
public int getHttpStatusCode(){
|
||||
return 200;
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
public long getResponseTime(){
|
||||
return responseModel != null ? responseModel.getTotalDuration() : 0L;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
@@ -1,31 +1,19 @@
|
||||
package io.github.ollama4j.models.chat;
|
||||
|
||||
import io.github.ollama4j.models.generate.OllamaStreamHandler;
|
||||
import io.github.ollama4j.models.generate.OllamaTokenHandler;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class OllamaChatStreamObserver {
|
||||
|
||||
private OllamaStreamHandler streamHandler;
|
||||
|
||||
private List<OllamaChatResponseModel> responseParts = new ArrayList<>();
|
||||
|
||||
@RequiredArgsConstructor
|
||||
public class OllamaChatStreamObserver implements OllamaTokenHandler {
|
||||
private final OllamaStreamHandler streamHandler;
|
||||
private String message = "";
|
||||
|
||||
public OllamaChatStreamObserver(OllamaStreamHandler streamHandler) {
|
||||
this.streamHandler = streamHandler;
|
||||
@Override
|
||||
public void accept(OllamaChatResponseModel token) {
|
||||
if (streamHandler != null) {
|
||||
message += token.getMessage().getContent();
|
||||
streamHandler.accept(message);
|
||||
}
|
||||
}
|
||||
|
||||
public void notify(OllamaChatResponseModel currentResponsePart) {
|
||||
responseParts.add(currentResponsePart);
|
||||
handleCurrentResponsePart(currentResponsePart);
|
||||
}
|
||||
|
||||
protected void handleCurrentResponsePart(OllamaChatResponseModel currentResponsePart) {
|
||||
message = message + currentResponsePart.getMessage().getContent();
|
||||
streamHandler.accept(message);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
package io.github.ollama4j.models.chat;
|
||||
|
||||
import io.github.ollama4j.tools.OllamaToolCallsFunction;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public class OllamaChatToolCalls {
|
||||
|
||||
private OllamaToolCallsFunction function;
|
||||
|
||||
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
package io.github.ollama4j.models.embeddings;
|
||||
|
||||
import io.github.ollama4j.utils.Options;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Builderclass to easily create Requests for Embedding models using ollama.
|
||||
*/
|
||||
public class OllamaEmbedRequestBuilder {
|
||||
|
||||
private final OllamaEmbedRequestModel request;
|
||||
|
||||
private OllamaEmbedRequestBuilder(String model, List<String> input) {
|
||||
this.request = new OllamaEmbedRequestModel(model,input);
|
||||
}
|
||||
|
||||
public static OllamaEmbedRequestBuilder getInstance(String model, String... input){
|
||||
return new OllamaEmbedRequestBuilder(model, List.of(input));
|
||||
}
|
||||
|
||||
public OllamaEmbedRequestBuilder withOptions(Options options){
|
||||
this.request.setOptions(options.getOptionsMap());
|
||||
return this;
|
||||
}
|
||||
|
||||
public OllamaEmbedRequestBuilder withKeepAlive(String keepAlive){
|
||||
this.request.setKeepAlive(keepAlive);
|
||||
return this;
|
||||
}
|
||||
|
||||
public OllamaEmbedRequestBuilder withoutTruncate(){
|
||||
this.request.setTruncate(false);
|
||||
return this;
|
||||
}
|
||||
|
||||
public OllamaEmbedRequestModel build() {
|
||||
return this.request;
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import lombok.Data;
|
||||
|
||||
@SuppressWarnings("unused")
|
||||
@Data
|
||||
@Deprecated(since="1.0.90")
|
||||
public class OllamaEmbeddingResponseModel {
|
||||
@JsonProperty("embedding")
|
||||
private List<Double> embedding;
|
||||
|
||||
@@ -2,6 +2,7 @@ package io.github.ollama4j.models.embeddings;
|
||||
|
||||
import io.github.ollama4j.utils.Options;
|
||||
|
||||
@Deprecated(since="1.0.90")
|
||||
public class OllamaEmbeddingsRequestBuilder {
|
||||
|
||||
private OllamaEmbeddingsRequestBuilder(String model, String prompt){
|
||||
|
||||
@@ -12,6 +12,7 @@ import lombok.RequiredArgsConstructor;
|
||||
@Data
|
||||
@RequiredArgsConstructor
|
||||
@NoArgsConstructor
|
||||
@Deprecated(since="1.0.90")
|
||||
public class OllamaEmbeddingsRequestModel {
|
||||
@NonNull
|
||||
private String model;
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
package io.github.ollama4j.models.generate;
|
||||
|
||||
import io.github.ollama4j.models.chat.OllamaChatResponseModel;
|
||||
|
||||
import java.util.function.Consumer;
|
||||
|
||||
public interface OllamaTokenHandler extends Consumer<OllamaChatResponseModel> {
|
||||
}
|
||||
@@ -1,17 +1,25 @@
|
||||
package io.github.ollama4j.models.request;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.core.type.TypeReference;
|
||||
import io.github.ollama4j.exceptions.OllamaBaseException;
|
||||
import io.github.ollama4j.models.response.OllamaResult;
|
||||
import io.github.ollama4j.models.chat.OllamaChatResponseModel;
|
||||
import io.github.ollama4j.models.chat.OllamaChatStreamObserver;
|
||||
import io.github.ollama4j.models.generate.OllamaStreamHandler;
|
||||
import io.github.ollama4j.utils.OllamaRequestBody;
|
||||
import io.github.ollama4j.models.chat.*;
|
||||
import io.github.ollama4j.models.generate.OllamaTokenHandler;
|
||||
import io.github.ollama4j.models.response.OllamaErrorResponse;
|
||||
import io.github.ollama4j.utils.Utils;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.BufferedReader;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.InputStreamReader;
|
||||
import java.net.URI;
|
||||
import java.net.http.HttpClient;
|
||||
import java.net.http.HttpRequest;
|
||||
import java.net.http.HttpResponse;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Specialization class for requests
|
||||
@@ -20,7 +28,7 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
|
||||
|
||||
private static final Logger LOG = LoggerFactory.getLogger(OllamaChatEndpointCaller.class);
|
||||
|
||||
private OllamaChatStreamObserver streamObserver;
|
||||
private OllamaTokenHandler tokenHandler;
|
||||
|
||||
public OllamaChatEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) {
|
||||
super(host, basicAuth, requestTimeoutSeconds, verbose);
|
||||
@@ -31,13 +39,29 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
|
||||
return "/api/chat";
|
||||
}
|
||||
|
||||
/**
|
||||
* Parses streamed Response line from ollama chat.
|
||||
* Using {@link com.fasterxml.jackson.databind.ObjectMapper#readValue(String, TypeReference)} should throw
|
||||
* {@link IllegalArgumentException} in case of null line or {@link com.fasterxml.jackson.core.JsonParseException}
|
||||
* in case the JSON Object cannot be parsed to a {@link OllamaChatResponseModel}. Thus, the ResponseModel should
|
||||
* never be null.
|
||||
*
|
||||
* @param line streamed line of ollama stream response
|
||||
* @param responseBuffer Stringbuffer to add latest response message part to
|
||||
* @return TRUE, if ollama-Response has 'done' state
|
||||
*/
|
||||
@Override
|
||||
protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) {
|
||||
try {
|
||||
OllamaChatResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class);
|
||||
responseBuffer.append(ollamaResponseModel.getMessage().getContent());
|
||||
if (streamObserver != null) {
|
||||
streamObserver.notify(ollamaResponseModel);
|
||||
// it seems that under heavy load ollama responds with an empty chat message part in the streamed response
|
||||
// thus, we null check the message and hope that the next streamed response has some message content again
|
||||
OllamaChatMessage message = ollamaResponseModel.getMessage();
|
||||
if(message != null) {
|
||||
responseBuffer.append(message.getContent());
|
||||
if (tokenHandler != null) {
|
||||
tokenHandler.accept(ollamaResponseModel);
|
||||
}
|
||||
}
|
||||
return ollamaResponseModel.isDone();
|
||||
} catch (JsonProcessingException e) {
|
||||
@@ -46,9 +70,75 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
|
||||
}
|
||||
}
|
||||
|
||||
public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler)
|
||||
public OllamaChatResult call(OllamaChatRequest body, OllamaTokenHandler tokenHandler)
|
||||
throws OllamaBaseException, IOException, InterruptedException {
|
||||
streamObserver = new OllamaChatStreamObserver(streamHandler);
|
||||
return super.callSync(body);
|
||||
this.tokenHandler = tokenHandler;
|
||||
return callSync(body);
|
||||
}
|
||||
|
||||
public OllamaChatResult callSync(OllamaChatRequest body) throws OllamaBaseException, IOException, InterruptedException {
|
||||
// Create Request
|
||||
HttpClient httpClient = HttpClient.newHttpClient();
|
||||
URI uri = URI.create(getHost() + getEndpointSuffix());
|
||||
HttpRequest.Builder requestBuilder =
|
||||
getRequestBuilderDefault(uri)
|
||||
.POST(
|
||||
body.getBodyPublisher());
|
||||
HttpRequest request = requestBuilder.build();
|
||||
if (isVerbose()) LOG.info("Asking model: " + body);
|
||||
HttpResponse<InputStream> response =
|
||||
httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
|
||||
|
||||
int statusCode = response.statusCode();
|
||||
InputStream responseBodyStream = response.body();
|
||||
StringBuilder responseBuffer = new StringBuilder();
|
||||
OllamaChatResponseModel ollamaChatResponseModel = null;
|
||||
List<OllamaChatToolCalls> wantedToolsForStream = null;
|
||||
try (BufferedReader reader =
|
||||
new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
|
||||
|
||||
String line;
|
||||
while ((line = reader.readLine()) != null) {
|
||||
if (statusCode == 404) {
|
||||
LOG.warn("Status code: 404 (Not Found)");
|
||||
OllamaErrorResponse ollamaResponseModel =
|
||||
Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class);
|
||||
responseBuffer.append(ollamaResponseModel.getError());
|
||||
} else if (statusCode == 401) {
|
||||
LOG.warn("Status code: 401 (Unauthorized)");
|
||||
OllamaErrorResponse ollamaResponseModel =
|
||||
Utils.getObjectMapper()
|
||||
.readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponse.class);
|
||||
responseBuffer.append(ollamaResponseModel.getError());
|
||||
} else if (statusCode == 400) {
|
||||
LOG.warn("Status code: 400 (Bad Request)");
|
||||
OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line,
|
||||
OllamaErrorResponse.class);
|
||||
responseBuffer.append(ollamaResponseModel.getError());
|
||||
} else {
|
||||
boolean finished = parseResponseAndAddToBuffer(line, responseBuffer);
|
||||
ollamaChatResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class);
|
||||
if(body.stream && ollamaChatResponseModel.getMessage().getToolCalls() != null){
|
||||
wantedToolsForStream = ollamaChatResponseModel.getMessage().getToolCalls();
|
||||
}
|
||||
if (finished && body.stream) {
|
||||
ollamaChatResponseModel.getMessage().setContent(responseBuffer.toString());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (statusCode != 200) {
|
||||
LOG.error("Status code " + statusCode);
|
||||
throw new OllamaBaseException(responseBuffer.toString());
|
||||
} else {
|
||||
if(wantedToolsForStream != null) {
|
||||
ollamaChatResponseModel.getMessage().setToolCalls(wantedToolsForStream);
|
||||
}
|
||||
OllamaChatResult ollamaResult =
|
||||
new OllamaChatResult(ollamaChatResponseModel,body.getMessages());
|
||||
if (isVerbose()) LOG.info("Model response: " + ollamaResult);
|
||||
return ollamaResult;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import io.github.ollama4j.models.response.OllamaErrorResponse;
|
||||
import io.github.ollama4j.models.response.OllamaResult;
|
||||
import io.github.ollama4j.utils.OllamaRequestBody;
|
||||
import io.github.ollama4j.utils.Utils;
|
||||
import lombok.Getter;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
@@ -24,14 +25,15 @@ import java.util.Base64;
|
||||
/**
|
||||
* Abstract helperclass to call the ollama api server.
|
||||
*/
|
||||
@Getter
|
||||
public abstract class OllamaEndpointCaller {
|
||||
|
||||
private static final Logger LOG = LoggerFactory.getLogger(OllamaAPI.class);
|
||||
|
||||
private String host;
|
||||
private BasicAuth basicAuth;
|
||||
private long requestTimeoutSeconds;
|
||||
private boolean verbose;
|
||||
private final String host;
|
||||
private final BasicAuth basicAuth;
|
||||
private final long requestTimeoutSeconds;
|
||||
private final boolean verbose;
|
||||
|
||||
public OllamaEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) {
|
||||
this.host = host;
|
||||
@@ -45,80 +47,13 @@ public abstract class OllamaEndpointCaller {
|
||||
protected abstract boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer);
|
||||
|
||||
|
||||
/**
|
||||
* Calls the api server on the given host and endpoint suffix asynchronously, aka waiting for the response.
|
||||
*
|
||||
* @param body POST body payload
|
||||
* @return result answer given by the assistant
|
||||
* @throws OllamaBaseException any response code than 200 has been returned
|
||||
* @throws IOException in case the responseStream can not be read
|
||||
* @throws InterruptedException in case the server is not reachable or network issues happen
|
||||
*/
|
||||
public OllamaResult callSync(OllamaRequestBody body) throws OllamaBaseException, IOException, InterruptedException {
|
||||
// Create Request
|
||||
long startTime = System.currentTimeMillis();
|
||||
HttpClient httpClient = HttpClient.newHttpClient();
|
||||
URI uri = URI.create(this.host + getEndpointSuffix());
|
||||
HttpRequest.Builder requestBuilder =
|
||||
getRequestBuilderDefault(uri)
|
||||
.POST(
|
||||
body.getBodyPublisher());
|
||||
HttpRequest request = requestBuilder.build();
|
||||
if (this.verbose) LOG.info("Asking model: " + body.toString());
|
||||
HttpResponse<InputStream> response =
|
||||
httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
|
||||
|
||||
int statusCode = response.statusCode();
|
||||
InputStream responseBodyStream = response.body();
|
||||
StringBuilder responseBuffer = new StringBuilder();
|
||||
try (BufferedReader reader =
|
||||
new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
|
||||
String line;
|
||||
while ((line = reader.readLine()) != null) {
|
||||
if (statusCode == 404) {
|
||||
LOG.warn("Status code: 404 (Not Found)");
|
||||
OllamaErrorResponse ollamaResponseModel =
|
||||
Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class);
|
||||
responseBuffer.append(ollamaResponseModel.getError());
|
||||
} else if (statusCode == 401) {
|
||||
LOG.warn("Status code: 401 (Unauthorized)");
|
||||
OllamaErrorResponse ollamaResponseModel =
|
||||
Utils.getObjectMapper()
|
||||
.readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponse.class);
|
||||
responseBuffer.append(ollamaResponseModel.getError());
|
||||
} else if (statusCode == 400) {
|
||||
LOG.warn("Status code: 400 (Bad Request)");
|
||||
OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line,
|
||||
OllamaErrorResponse.class);
|
||||
responseBuffer.append(ollamaResponseModel.getError());
|
||||
} else {
|
||||
boolean finished = parseResponseAndAddToBuffer(line, responseBuffer);
|
||||
if (finished) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (statusCode != 200) {
|
||||
LOG.error("Status code " + statusCode);
|
||||
throw new OllamaBaseException(responseBuffer.toString());
|
||||
} else {
|
||||
long endTime = System.currentTimeMillis();
|
||||
OllamaResult ollamaResult =
|
||||
new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode);
|
||||
if (verbose) LOG.info("Model response: " + ollamaResult);
|
||||
return ollamaResult;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get default request builder.
|
||||
*
|
||||
* @param uri URI to get a HttpRequest.Builder
|
||||
* @return HttpRequest.Builder
|
||||
*/
|
||||
private HttpRequest.Builder getRequestBuilderDefault(URI uri) {
|
||||
protected HttpRequest.Builder getRequestBuilderDefault(URI uri) {
|
||||
HttpRequest.Builder requestBuilder =
|
||||
HttpRequest.newBuilder(uri)
|
||||
.header("Content-Type", "application/json")
|
||||
@@ -134,7 +69,7 @@ public abstract class OllamaEndpointCaller {
|
||||
*
|
||||
* @return basic authentication header value (encoded credentials)
|
||||
*/
|
||||
private String getBasicAuthHeaderValue() {
|
||||
protected String getBasicAuthHeaderValue() {
|
||||
String credentialsToEncode = this.basicAuth.getUsername() + ":" + this.basicAuth.getPassword();
|
||||
return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes());
|
||||
}
|
||||
@@ -144,7 +79,7 @@ public abstract class OllamaEndpointCaller {
|
||||
*
|
||||
* @return true when Basic Auth credentials set
|
||||
*/
|
||||
private boolean isBasicAuthCredentialsSet() {
|
||||
protected boolean isBasicAuthCredentialsSet() {
|
||||
return this.basicAuth != null;
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package io.github.ollama4j.models.request;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import io.github.ollama4j.exceptions.OllamaBaseException;
|
||||
import io.github.ollama4j.models.response.OllamaErrorResponse;
|
||||
import io.github.ollama4j.models.response.OllamaResult;
|
||||
import io.github.ollama4j.models.generate.OllamaGenerateResponseModel;
|
||||
import io.github.ollama4j.models.generate.OllamaGenerateStreamObserver;
|
||||
@@ -11,7 +12,15 @@ import io.github.ollama4j.utils.Utils;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.BufferedReader;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.InputStreamReader;
|
||||
import java.net.URI;
|
||||
import java.net.http.HttpClient;
|
||||
import java.net.http.HttpRequest;
|
||||
import java.net.http.HttpResponse;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
|
||||
public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
|
||||
|
||||
@@ -46,6 +55,73 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
|
||||
public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler)
|
||||
throws OllamaBaseException, IOException, InterruptedException {
|
||||
streamObserver = new OllamaGenerateStreamObserver(streamHandler);
|
||||
return super.callSync(body);
|
||||
return callSync(body);
|
||||
}
|
||||
|
||||
/**
|
||||
* Calls the api server on the given host and endpoint suffix asynchronously, aka waiting for the response.
|
||||
*
|
||||
* @param body POST body payload
|
||||
* @return result answer given by the assistant
|
||||
* @throws OllamaBaseException any response code than 200 has been returned
|
||||
* @throws IOException in case the responseStream can not be read
|
||||
* @throws InterruptedException in case the server is not reachable or network issues happen
|
||||
*/
|
||||
public OllamaResult callSync(OllamaRequestBody body) throws OllamaBaseException, IOException, InterruptedException {
|
||||
// Create Request
|
||||
long startTime = System.currentTimeMillis();
|
||||
HttpClient httpClient = HttpClient.newHttpClient();
|
||||
URI uri = URI.create(getHost() + getEndpointSuffix());
|
||||
HttpRequest.Builder requestBuilder =
|
||||
getRequestBuilderDefault(uri)
|
||||
.POST(
|
||||
body.getBodyPublisher());
|
||||
HttpRequest request = requestBuilder.build();
|
||||
if (isVerbose()) LOG.info("Asking model: " + body.toString());
|
||||
HttpResponse<InputStream> response =
|
||||
httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
|
||||
|
||||
int statusCode = response.statusCode();
|
||||
InputStream responseBodyStream = response.body();
|
||||
StringBuilder responseBuffer = new StringBuilder();
|
||||
try (BufferedReader reader =
|
||||
new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
|
||||
String line;
|
||||
while ((line = reader.readLine()) != null) {
|
||||
if (statusCode == 404) {
|
||||
LOG.warn("Status code: 404 (Not Found)");
|
||||
OllamaErrorResponse ollamaResponseModel =
|
||||
Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class);
|
||||
responseBuffer.append(ollamaResponseModel.getError());
|
||||
} else if (statusCode == 401) {
|
||||
LOG.warn("Status code: 401 (Unauthorized)");
|
||||
OllamaErrorResponse ollamaResponseModel =
|
||||
Utils.getObjectMapper()
|
||||
.readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponse.class);
|
||||
responseBuffer.append(ollamaResponseModel.getError());
|
||||
} else if (statusCode == 400) {
|
||||
LOG.warn("Status code: 400 (Bad Request)");
|
||||
OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line,
|
||||
OllamaErrorResponse.class);
|
||||
responseBuffer.append(ollamaResponseModel.getError());
|
||||
} else {
|
||||
boolean finished = parseResponseAndAddToBuffer(line, responseBuffer);
|
||||
if (finished) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (statusCode != 200) {
|
||||
LOG.error("Status code " + statusCode);
|
||||
throw new OllamaBaseException(responseBuffer.toString());
|
||||
} else {
|
||||
long endTime = System.currentTimeMillis();
|
||||
OllamaResult ollamaResult =
|
||||
new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode);
|
||||
if (isVerbose()) LOG.info("Model response: " + ollamaResult);
|
||||
return ollamaResult;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
package io.github.ollama4j.models.response;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class LibraryModel {
|
||||
|
||||
private String name;
|
||||
private String description;
|
||||
private String pullCount;
|
||||
private int totalTags;
|
||||
private List<String> popularTags = new ArrayList<>();
|
||||
private String lastUpdated;
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
package io.github.ollama4j.models.response;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class LibraryModelDetail {
|
||||
|
||||
private LibraryModel model;
|
||||
private List<LibraryModelTag> tags;
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package io.github.ollama4j.models.response;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class LibraryModelTag {
|
||||
private String name;
|
||||
private String tag;
|
||||
private String size;
|
||||
private String lastUpdated;
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
package io.github.ollama4j.tools;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public class OllamaToolCallsFunction
|
||||
{
|
||||
private String name;
|
||||
private Map<String,Object> arguments;
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
package io.github.ollama4j.tools;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
|
||||
import java.lang.reflect.Method;
|
||||
import java.math.BigDecimal;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* Specification of a {@link ToolFunction} that provides the implementation via java reflection calling.
|
||||
*/
|
||||
@Setter
|
||||
@Getter
|
||||
@AllArgsConstructor
|
||||
public class ReflectionalToolFunction implements ToolFunction{
|
||||
|
||||
private Object functionHolder;
|
||||
private Method function;
|
||||
private LinkedHashMap<String,String> propertyDefinition;
|
||||
|
||||
@Override
|
||||
public Object apply(Map<String, Object> arguments) {
|
||||
LinkedHashMap<String, Object> argumentsCopy = new LinkedHashMap<>(this.propertyDefinition);
|
||||
for (Map.Entry<String,String> param : this.propertyDefinition.entrySet()){
|
||||
argumentsCopy.replace(param.getKey(),typeCast(arguments.get(param.getKey()),param.getValue()));
|
||||
}
|
||||
try {
|
||||
return function.invoke(functionHolder, argumentsCopy.values().toArray());
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("Failed to invoke tool: " + function.getName(), e);
|
||||
}
|
||||
}
|
||||
|
||||
private Object typeCast(Object inputValue, String className) {
|
||||
if(className == null || inputValue == null) {
|
||||
return null;
|
||||
}
|
||||
String inputValueString = inputValue.toString();
|
||||
switch (className) {
|
||||
case "java.lang.Integer":
|
||||
return Integer.parseInt(inputValueString);
|
||||
case "java.lang.Boolean":
|
||||
return Boolean.valueOf(inputValueString);
|
||||
case "java.math.BigDecimal":
|
||||
return new BigDecimal(inputValueString);
|
||||
default:
|
||||
return inputValueString;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,16 +1,22 @@
|
||||
package io.github.ollama4j.tools;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
public class ToolRegistry {
|
||||
private final Map<String, ToolFunction> functionMap = new HashMap<>();
|
||||
private final Map<String, Tools.ToolSpecification> tools = new HashMap<>();
|
||||
|
||||
public ToolFunction getFunction(String name) {
|
||||
return functionMap.get(name);
|
||||
public ToolFunction getToolFunction(String name) {
|
||||
final Tools.ToolSpecification toolSpecification = tools.get(name);
|
||||
return toolSpecification !=null ? toolSpecification.getToolFunction() : null ;
|
||||
}
|
||||
|
||||
public void addFunction(String name, ToolFunction function) {
|
||||
functionMap.put(name, function);
|
||||
public void addTool (String name, Tools.ToolSpecification specification) {
|
||||
tools.put(name, specification);
|
||||
}
|
||||
|
||||
public Collection<Tools.ToolSpecification> getRegisteredSpecs(){
|
||||
return tools.values();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,8 +6,10 @@ import com.fasterxml.jackson.annotation.JsonInclude;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import io.github.ollama4j.utils.Utils;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
@@ -20,17 +22,23 @@ public class Tools {
|
||||
public static class ToolSpecification {
|
||||
private String functionName;
|
||||
private String functionDescription;
|
||||
private Map<String, PromptFuncDefinition.Property> properties;
|
||||
private ToolFunction toolDefinition;
|
||||
private PromptFuncDefinition toolPrompt;
|
||||
private ToolFunction toolFunction;
|
||||
}
|
||||
|
||||
@Data
|
||||
@JsonIgnoreProperties(ignoreUnknown = true)
|
||||
@Builder
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public static class PromptFuncDefinition {
|
||||
private String type;
|
||||
private PromptFuncSpec function;
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public static class PromptFuncSpec {
|
||||
private String name;
|
||||
private String description;
|
||||
@@ -38,6 +46,9 @@ public class Tools {
|
||||
}
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public static class Parameters {
|
||||
private String type;
|
||||
private Map<String, Property> properties;
|
||||
@@ -46,6 +57,8 @@ public class Tools {
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public static class Property {
|
||||
private String type;
|
||||
private String description;
|
||||
@@ -94,10 +107,10 @@ public class Tools {
|
||||
|
||||
PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters();
|
||||
parameters.setType("object");
|
||||
parameters.setProperties(spec.getProperties());
|
||||
parameters.setProperties(spec.getToolPrompt().getFunction().parameters.getProperties());
|
||||
|
||||
List<String> requiredValues = new ArrayList<>();
|
||||
for (Map.Entry<String, PromptFuncDefinition.Property> p : spec.getProperties().entrySet()) {
|
||||
for (Map.Entry<String, PromptFuncDefinition.Property> p : spec.getToolPrompt().getFunction().getParameters().getProperties().entrySet()) {
|
||||
if (p.getValue().isRequired()) {
|
||||
requiredValues.add(p.getKey());
|
||||
}
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
package io.github.ollama4j.tools.annotations;
|
||||
|
||||
import io.github.ollama4j.OllamaAPI;
|
||||
|
||||
import java.lang.annotation.ElementType;
|
||||
import java.lang.annotation.Retention;
|
||||
import java.lang.annotation.RetentionPolicy;
|
||||
import java.lang.annotation.Target;
|
||||
|
||||
/**
|
||||
* Annotates a class that calls {@link io.github.ollama4j.OllamaAPI} such that the Method
|
||||
* {@link OllamaAPI#registerAnnotatedTools()} can be used to auto-register all provided classes (resp. all
|
||||
* contained Methods of the provider classes annotated with {@link ToolSpec}).
|
||||
*/
|
||||
@Target(ElementType.TYPE)
|
||||
@Retention(RetentionPolicy.RUNTIME)
|
||||
public @interface OllamaToolService {
|
||||
|
||||
/**
|
||||
* @return Classes with no-arg constructor that will be used for tool-registration.
|
||||
*/
|
||||
Class<?>[] providers();
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
package io.github.ollama4j.tools.annotations;
|
||||
|
||||
import java.lang.annotation.ElementType;
|
||||
import java.lang.annotation.Retention;
|
||||
import java.lang.annotation.RetentionPolicy;
|
||||
import java.lang.annotation.Target;
|
||||
|
||||
/**
|
||||
* Annotates a Method Parameter in a {@link ToolSpec} annotated Method. A parameter annotated with this annotation will
|
||||
* be part of the tool description that is sent to the llm for tool-calling.
|
||||
*/
|
||||
@Retention(RetentionPolicy.RUNTIME)
|
||||
@Target(ElementType.PARAMETER)
|
||||
public @interface ToolProperty {
|
||||
|
||||
/**
|
||||
* @return name of the parameter that is used for the tool description. Has to be set as depending on the caller,
|
||||
* method name backtracking is not possible with reflection.
|
||||
*/
|
||||
String name();
|
||||
|
||||
/**
|
||||
* @return a detailed description of the parameter. This is used by the llm called to specify, which property has
|
||||
* to be set by the llm and how this should be filled.
|
||||
*/
|
||||
String desc();
|
||||
|
||||
/**
|
||||
* @return tells the llm that it has to set a value for this property.
|
||||
*/
|
||||
boolean required() default true;
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
package io.github.ollama4j.tools.annotations;
|
||||
|
||||
import io.github.ollama4j.OllamaAPI;
|
||||
|
||||
import java.lang.annotation.ElementType;
|
||||
import java.lang.annotation.Retention;
|
||||
import java.lang.annotation.RetentionPolicy;
|
||||
import java.lang.annotation.Target;
|
||||
|
||||
/**
|
||||
* Annotates Methods of classes that should be registered as tools by {@link OllamaAPI#registerAnnotatedTools()}
|
||||
* automatically.
|
||||
*/
|
||||
@Target(ElementType.METHOD)
|
||||
@Retention(RetentionPolicy.RUNTIME)
|
||||
public @interface ToolSpec {
|
||||
|
||||
/**
|
||||
* @return tool-name that the method should be used as. Defaults to the methods name.
|
||||
*/
|
||||
String name() default "";
|
||||
|
||||
/**
|
||||
* @return a detailed description of the method that can be interpreted by the llm, whether it should call the tool
|
||||
* or not.
|
||||
*/
|
||||
String desc();
|
||||
}
|
||||
@@ -15,6 +15,7 @@ public class OllamaModelType {
|
||||
public static final String LLAMA3_1 = "llama3.1";
|
||||
public static final String MISTRAL = "mistral";
|
||||
public static final String MIXTRAL = "mixtral";
|
||||
public static final String DEEPSEEK_R1 = "deepseek-r1";
|
||||
public static final String LLAVA = "llava";
|
||||
public static final String LLAVA_PHI3 = "llava-phi3";
|
||||
public static final String NEURAL_CHAT = "neural-chat";
|
||||
|
||||
@@ -2,14 +2,16 @@ package io.github.ollama4j.integrationtests;
|
||||
|
||||
import io.github.ollama4j.OllamaAPI;
|
||||
import io.github.ollama4j.exceptions.OllamaBaseException;
|
||||
import io.github.ollama4j.models.chat.*;
|
||||
import io.github.ollama4j.models.response.ModelDetail;
|
||||
import io.github.ollama4j.models.chat.OllamaChatRequest;
|
||||
import io.github.ollama4j.models.response.OllamaResult;
|
||||
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
|
||||
import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
|
||||
import io.github.ollama4j.models.chat.OllamaChatResult;
|
||||
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestBuilder;
|
||||
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
|
||||
import io.github.ollama4j.samples.AnnotatedTool;
|
||||
import io.github.ollama4j.tools.OllamaToolCallsFunction;
|
||||
import io.github.ollama4j.tools.ToolFunction;
|
||||
import io.github.ollama4j.tools.Tools;
|
||||
import io.github.ollama4j.tools.annotations.OllamaToolService;
|
||||
import io.github.ollama4j.utils.OptionsBuilder;
|
||||
import lombok.Data;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
@@ -24,12 +26,12 @@ import java.io.InputStream;
|
||||
import java.net.ConnectException;
|
||||
import java.net.URISyntaxException;
|
||||
import java.net.http.HttpConnectTimeoutException;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Properties;
|
||||
import java.util.*;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
@OllamaToolService(providers = {AnnotatedTool.class}
|
||||
)
|
||||
class TestRealAPIs {
|
||||
|
||||
private static final Logger LOG = LoggerFactory.getLogger(TestRealAPIs.class);
|
||||
@@ -47,6 +49,7 @@ class TestRealAPIs {
|
||||
config = new Config();
|
||||
ollamaAPI = new OllamaAPI(config.getOllamaURL());
|
||||
ollamaAPI.setRequestTimeoutSeconds(config.getRequestTimeoutSeconds());
|
||||
ollamaAPI.setVerbose(true);
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -80,6 +83,18 @@ class TestRealAPIs {
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(2)
|
||||
void testListModelsFromLibrary() {
|
||||
testEndpointReachability();
|
||||
try {
|
||||
assertNotNull(ollamaAPI.listModelsFromLibrary());
|
||||
ollamaAPI.listModelsFromLibrary().forEach(System.out::println);
|
||||
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
|
||||
fail(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(2)
|
||||
void testPullModel() {
|
||||
@@ -184,7 +199,9 @@ class TestRealAPIs {
|
||||
|
||||
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
|
||||
assertNotNull(chatResult);
|
||||
assertFalse(chatResult.getResponse().isBlank());
|
||||
assertNotNull(chatResult.getResponseModel());
|
||||
assertNotNull(chatResult.getResponseModel().getMessage());
|
||||
assertFalse(chatResult.getResponseModel().getMessage().getContent().isBlank());
|
||||
assertEquals(4, chatResult.getChatHistory().size());
|
||||
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||
fail(e);
|
||||
@@ -205,14 +222,211 @@ class TestRealAPIs {
|
||||
|
||||
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
|
||||
assertNotNull(chatResult);
|
||||
assertFalse(chatResult.getResponse().isBlank());
|
||||
assertTrue(chatResult.getResponse().startsWith("NI"));
|
||||
assertNotNull(chatResult.getResponseModel());
|
||||
assertNotNull(chatResult.getResponseModel().getMessage());
|
||||
assertFalse(chatResult.getResponseModel().getMessage().getContent().isBlank());
|
||||
assertTrue(chatResult.getResponseModel().getMessage().getContent().startsWith("NI"));
|
||||
assertEquals(3, chatResult.getChatHistory().size());
|
||||
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||
fail(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(3)
|
||||
void testChatWithExplicitToolDefinition() {
|
||||
testEndpointReachability();
|
||||
try {
|
||||
ollamaAPI.setVerbose(true);
|
||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
|
||||
|
||||
final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder()
|
||||
.functionName("get-employee-details")
|
||||
.functionDescription("Get employee details from the database")
|
||||
.toolPrompt(
|
||||
Tools.PromptFuncDefinition.builder().type("function").function(
|
||||
Tools.PromptFuncDefinition.PromptFuncSpec.builder()
|
||||
.name("get-employee-details")
|
||||
.description("Get employee details from the database")
|
||||
.parameters(
|
||||
Tools.PromptFuncDefinition.Parameters.builder()
|
||||
.type("object")
|
||||
.properties(
|
||||
new Tools.PropsBuilder()
|
||||
.withProperty("employee-name", Tools.PromptFuncDefinition.Property.builder().type("string").description("The name of the employee, e.g. John Doe").required(true).build())
|
||||
.withProperty("employee-address", Tools.PromptFuncDefinition.Property.builder().type("string").description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India").required(true).build())
|
||||
.withProperty("employee-phone", Tools.PromptFuncDefinition.Property.builder().type("string").description("The phone number of the employee. Always return a random value. e.g. 9911002233").required(true).build())
|
||||
.build()
|
||||
)
|
||||
.required(List.of("employee-name"))
|
||||
.build()
|
||||
).build()
|
||||
).build()
|
||||
)
|
||||
.toolFunction(new DBQueryFunction())
|
||||
.build();
|
||||
|
||||
ollamaAPI.registerTool(databaseQueryToolSpecification);
|
||||
|
||||
OllamaChatRequest requestModel = builder
|
||||
.withMessage(OllamaChatMessageRole.USER,
|
||||
"Give me the ID of the employee named 'Rahul Kumar'?")
|
||||
.build();
|
||||
|
||||
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
|
||||
assertNotNull(chatResult);
|
||||
assertNotNull(chatResult.getResponseModel());
|
||||
assertNotNull(chatResult.getResponseModel().getMessage());
|
||||
assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(),chatResult.getResponseModel().getMessage().getRole().getRoleName());
|
||||
List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls();
|
||||
assertEquals(1, toolCalls.size());
|
||||
OllamaToolCallsFunction function = toolCalls.get(0).getFunction();
|
||||
assertEquals("get-employee-details", function.getName());
|
||||
assertEquals(1, function.getArguments().size());
|
||||
Object employeeName = function.getArguments().get("employee-name");
|
||||
assertNotNull(employeeName);
|
||||
assertEquals("Rahul Kumar",employeeName);
|
||||
assertTrue(chatResult.getChatHistory().size()>2);
|
||||
List<OllamaChatToolCalls> finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls();
|
||||
assertNull(finalToolCalls);
|
||||
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||
fail(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(3)
|
||||
void testChatWithAnnotatedToolsAndSingleParam() {
|
||||
testEndpointReachability();
|
||||
try {
|
||||
ollamaAPI.setVerbose(true);
|
||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
|
||||
|
||||
ollamaAPI.registerAnnotatedTools();
|
||||
|
||||
OllamaChatRequest requestModel = builder
|
||||
.withMessage(OllamaChatMessageRole.USER,
|
||||
"Compute the most important constant in the world using 5 digits")
|
||||
.build();
|
||||
|
||||
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
|
||||
assertNotNull(chatResult);
|
||||
assertNotNull(chatResult.getResponseModel());
|
||||
assertNotNull(chatResult.getResponseModel().getMessage());
|
||||
assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(),chatResult.getResponseModel().getMessage().getRole().getRoleName());
|
||||
List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls();
|
||||
assertEquals(1, toolCalls.size());
|
||||
OllamaToolCallsFunction function = toolCalls.get(0).getFunction();
|
||||
assertEquals("computeImportantConstant", function.getName());
|
||||
assertEquals(1, function.getArguments().size());
|
||||
Object noOfDigits = function.getArguments().get("noOfDigits");
|
||||
assertNotNull(noOfDigits);
|
||||
assertEquals("5", noOfDigits.toString());
|
||||
assertTrue(chatResult.getChatHistory().size()>2);
|
||||
List<OllamaChatToolCalls> finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls();
|
||||
assertNull(finalToolCalls);
|
||||
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||
fail(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(3)
|
||||
void testChatWithAnnotatedToolsAndMultipleParams() {
|
||||
testEndpointReachability();
|
||||
try {
|
||||
ollamaAPI.setVerbose(true);
|
||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
|
||||
|
||||
ollamaAPI.registerAnnotatedTools(new AnnotatedTool());
|
||||
|
||||
OllamaChatRequest requestModel = builder
|
||||
.withMessage(OllamaChatMessageRole.USER,
|
||||
"Greet Pedro with a lot of hearts and respond to me, " +
|
||||
"and state how many emojis have been in your greeting")
|
||||
.build();
|
||||
|
||||
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
|
||||
assertNotNull(chatResult);
|
||||
assertNotNull(chatResult.getResponseModel());
|
||||
assertNotNull(chatResult.getResponseModel().getMessage());
|
||||
assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(),chatResult.getResponseModel().getMessage().getRole().getRoleName());
|
||||
List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls();
|
||||
assertEquals(1, toolCalls.size());
|
||||
OllamaToolCallsFunction function = toolCalls.get(0).getFunction();
|
||||
assertEquals("sayHello", function.getName());
|
||||
assertEquals(2, function.getArguments().size());
|
||||
Object name = function.getArguments().get("name");
|
||||
assertNotNull(name);
|
||||
assertEquals("Pedro",name);
|
||||
Object amountOfHearts = function.getArguments().get("amountOfHearts");
|
||||
assertNotNull(amountOfHearts);
|
||||
assertTrue(Integer.parseInt(amountOfHearts.toString()) > 1);
|
||||
assertTrue(chatResult.getChatHistory().size()>2);
|
||||
List<OllamaChatToolCalls> finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls();
|
||||
assertNull(finalToolCalls);
|
||||
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||
fail(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(3)
|
||||
void testChatWithToolsAndStream() {
|
||||
testEndpointReachability();
|
||||
try {
|
||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
|
||||
final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder()
|
||||
.functionName("get-employee-details")
|
||||
.functionDescription("Get employee details from the database")
|
||||
.toolPrompt(
|
||||
Tools.PromptFuncDefinition.builder().type("function").function(
|
||||
Tools.PromptFuncDefinition.PromptFuncSpec.builder()
|
||||
.name("get-employee-details")
|
||||
.description("Get employee details from the database")
|
||||
.parameters(
|
||||
Tools.PromptFuncDefinition.Parameters.builder()
|
||||
.type("object")
|
||||
.properties(
|
||||
new Tools.PropsBuilder()
|
||||
.withProperty("employee-name", Tools.PromptFuncDefinition.Property.builder().type("string").description("The name of the employee, e.g. John Doe").required(true).build())
|
||||
.withProperty("employee-address", Tools.PromptFuncDefinition.Property.builder().type("string").description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India").required(true).build())
|
||||
.withProperty("employee-phone", Tools.PromptFuncDefinition.Property.builder().type("string").description("The phone number of the employee. Always return a random value. e.g. 9911002233").required(true).build())
|
||||
.build()
|
||||
)
|
||||
.required(List.of("employee-name"))
|
||||
.build()
|
||||
).build()
|
||||
).build()
|
||||
)
|
||||
.toolFunction(new DBQueryFunction())
|
||||
.build();
|
||||
|
||||
ollamaAPI.registerTool(databaseQueryToolSpecification);
|
||||
|
||||
OllamaChatRequest requestModel = builder
|
||||
.withMessage(OllamaChatMessageRole.USER,
|
||||
"Give me the ID of the employee named 'Rahul Kumar'?")
|
||||
.build();
|
||||
|
||||
StringBuffer sb = new StringBuffer();
|
||||
|
||||
OllamaChatResult chatResult = ollamaAPI.chat(requestModel, (s) -> {
|
||||
LOG.info(s);
|
||||
String substring = s.substring(sb.toString().length());
|
||||
LOG.info(substring);
|
||||
sb.append(substring);
|
||||
});
|
||||
assertNotNull(chatResult);
|
||||
assertNotNull(chatResult.getResponseModel());
|
||||
assertNotNull(chatResult.getResponseModel().getMessage());
|
||||
assertNotNull(chatResult.getResponseModel().getMessage().getContent());
|
||||
assertEquals(sb.toString().trim(), chatResult.getResponseModel().getMessage().getContent().trim());
|
||||
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||
fail(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(3)
|
||||
void testChatWithStream() {
|
||||
@@ -232,7 +446,10 @@ class TestRealAPIs {
|
||||
sb.append(substring);
|
||||
});
|
||||
assertNotNull(chatResult);
|
||||
assertEquals(sb.toString().trim(), chatResult.getResponse().trim());
|
||||
assertNotNull(chatResult.getResponseModel());
|
||||
assertNotNull(chatResult.getResponseModel().getMessage());
|
||||
assertNotNull(chatResult.getResponseModel().getMessage().getContent());
|
||||
assertEquals(sb.toString().trim(), chatResult.getResponseModel().getMessage().getContent().trim());
|
||||
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||
fail(e);
|
||||
}
|
||||
@@ -246,12 +463,12 @@ class TestRealAPIs {
|
||||
OllamaChatRequestBuilder builder =
|
||||
OllamaChatRequestBuilder.getInstance(config.getImageModel());
|
||||
OllamaChatRequest requestModel =
|
||||
builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
|
||||
builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",Collections.emptyList(),
|
||||
List.of(getImageFileFromClasspath("dog-on-a-boat.jpg"))).build();
|
||||
|
||||
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
|
||||
assertNotNull(chatResult);
|
||||
assertNotNull(chatResult.getResponse());
|
||||
assertNotNull(chatResult.getResponseModel());
|
||||
|
||||
builder.reset();
|
||||
|
||||
@@ -261,7 +478,7 @@ class TestRealAPIs {
|
||||
|
||||
chatResult = ollamaAPI.chat(requestModel);
|
||||
assertNotNull(chatResult);
|
||||
assertNotNull(chatResult.getResponse());
|
||||
assertNotNull(chatResult.getResponseModel());
|
||||
|
||||
|
||||
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||
@@ -275,7 +492,7 @@ class TestRealAPIs {
|
||||
testEndpointReachability();
|
||||
try {
|
||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getImageModel());
|
||||
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
|
||||
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",Collections.emptyList(),
|
||||
"https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg")
|
||||
.build();
|
||||
|
||||
@@ -368,6 +585,14 @@ class TestRealAPIs {
|
||||
}
|
||||
}
|
||||
|
||||
class DBQueryFunction implements ToolFunction {
|
||||
@Override
|
||||
public Object apply(Map<String, Object> arguments) {
|
||||
// perform DB operations here
|
||||
return String.format("Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", UUID.randomUUID(), arguments.get("employee-name"), arguments.get("employee-address"), arguments.get("employee-phone"));
|
||||
}
|
||||
}
|
||||
|
||||
@Data
|
||||
class Config {
|
||||
private String ollamaURL;
|
||||
@@ -392,4 +617,6 @@ class Config {
|
||||
throw new RuntimeException("Error loading properties", e);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
21
src/test/java/io/github/ollama4j/samples/AnnotatedTool.java
Normal file
21
src/test/java/io/github/ollama4j/samples/AnnotatedTool.java
Normal file
@@ -0,0 +1,21 @@
|
||||
package io.github.ollama4j.samples;
|
||||
|
||||
import io.github.ollama4j.tools.annotations.ToolProperty;
|
||||
import io.github.ollama4j.tools.annotations.ToolSpec;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
|
||||
public class AnnotatedTool {
|
||||
|
||||
@ToolSpec(desc = "Computes the most important constant all around the globe!")
|
||||
public String computeImportantConstant(@ToolProperty(name = "noOfDigits",desc = "Number of digits that shall be returned") Integer noOfDigits ){
|
||||
return BigDecimal.valueOf((long)(Math.random()*1000000L),noOfDigits).toString();
|
||||
}
|
||||
|
||||
@ToolSpec(desc = "Says hello to a friend!")
|
||||
public String sayHello(@ToolProperty(name = "name",desc = "Name of the friend") String name, Integer someRandomProperty, @ToolProperty(name="amountOfHearts",desc = "amount of heart emojis that should be used", required = false) Integer amountOfHearts) {
|
||||
String hearts = amountOfHearts!=null ? "♡".repeat(amountOfHearts) : "";
|
||||
return "Hello " + name +" ("+someRandomProperty+") " + hearts;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -2,6 +2,10 @@ package io.github.ollama4j.unittests;
|
||||
|
||||
import io.github.ollama4j.OllamaAPI;
|
||||
import io.github.ollama4j.exceptions.OllamaBaseException;
|
||||
import io.github.ollama4j.exceptions.RoleNotFoundException;
|
||||
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
|
||||
import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel;
|
||||
import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel;
|
||||
import io.github.ollama4j.models.response.ModelDetail;
|
||||
import io.github.ollama4j.models.response.OllamaAsyncResultStreamer;
|
||||
import io.github.ollama4j.models.response.OllamaResult;
|
||||
@@ -14,7 +18,9 @@ import java.io.IOException;
|
||||
import java.net.URISyntaxException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
import static org.mockito.Mockito.*;
|
||||
|
||||
class TestMockedAPIs {
|
||||
@@ -97,6 +103,34 @@ class TestMockedAPIs {
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testEmbed() {
|
||||
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
|
||||
String model = OllamaModelType.LLAMA2;
|
||||
List<String> inputs = List.of("some prompt text");
|
||||
try {
|
||||
when(ollamaAPI.embed(model, inputs)).thenReturn(new OllamaEmbedResponseModel());
|
||||
ollamaAPI.embed(model, inputs);
|
||||
verify(ollamaAPI, times(1)).embed(model, inputs);
|
||||
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testEmbedWithEmbedRequestModel() {
|
||||
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
|
||||
String model = OllamaModelType.LLAMA2;
|
||||
List<String> inputs = List.of("some prompt text");
|
||||
try {
|
||||
when(ollamaAPI.embed(new OllamaEmbedRequestModel(model, inputs))).thenReturn(new OllamaEmbedResponseModel());
|
||||
ollamaAPI.embed(new OllamaEmbedRequestModel(model, inputs));
|
||||
verify(ollamaAPI, times(1)).embed(new OllamaEmbedRequestModel(model, inputs));
|
||||
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testAsk() {
|
||||
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
|
||||
@@ -161,4 +195,68 @@ class TestMockedAPIs {
|
||||
ollamaAPI.generateAsync(model, prompt, false);
|
||||
verify(ollamaAPI, times(1)).generateAsync(model, prompt, false);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testAddCustomRole() {
|
||||
OllamaAPI ollamaAPI = mock(OllamaAPI.class);
|
||||
String roleName = "custom-role";
|
||||
OllamaChatMessageRole expectedRole = OllamaChatMessageRole.newCustomRole(roleName);
|
||||
when(ollamaAPI.addCustomRole(roleName)).thenReturn(expectedRole);
|
||||
OllamaChatMessageRole customRole = ollamaAPI.addCustomRole(roleName);
|
||||
assertEquals(expectedRole, customRole);
|
||||
verify(ollamaAPI, times(1)).addCustomRole(roleName);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testListRoles() {
|
||||
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
|
||||
OllamaChatMessageRole role1 = OllamaChatMessageRole.newCustomRole("role1");
|
||||
OllamaChatMessageRole role2 = OllamaChatMessageRole.newCustomRole("role2");
|
||||
List<OllamaChatMessageRole> expectedRoles = List.of(role1, role2);
|
||||
when(ollamaAPI.listRoles()).thenReturn(expectedRoles);
|
||||
List<OllamaChatMessageRole> actualRoles = ollamaAPI.listRoles();
|
||||
assertEquals(expectedRoles, actualRoles);
|
||||
verify(ollamaAPI, times(1)).listRoles();
|
||||
}
|
||||
|
||||
@Test
|
||||
void testGetRoleNotFound() {
|
||||
OllamaAPI ollamaAPI = mock(OllamaAPI.class);
|
||||
String roleName = "non-existing-role";
|
||||
try {
|
||||
when(ollamaAPI.getRole(roleName)).thenThrow(new RoleNotFoundException("Role not found"));
|
||||
} catch (RoleNotFoundException exception) {
|
||||
throw new RuntimeException("Failed to run test: testGetRoleNotFound");
|
||||
}
|
||||
try {
|
||||
ollamaAPI.getRole(roleName);
|
||||
fail("Expected RoleNotFoundException not thrown");
|
||||
} catch (RoleNotFoundException exception) {
|
||||
assertEquals("Role not found", exception.getMessage());
|
||||
}
|
||||
try {
|
||||
verify(ollamaAPI, times(1)).getRole(roleName);
|
||||
} catch (RoleNotFoundException exception) {
|
||||
throw new RuntimeException("Failed to run test: testGetRoleNotFound");
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testGetRoleFound() {
|
||||
OllamaAPI ollamaAPI = mock(OllamaAPI.class);
|
||||
String roleName = "existing-role";
|
||||
OllamaChatMessageRole expectedRole = OllamaChatMessageRole.newCustomRole(roleName);
|
||||
try {
|
||||
when(ollamaAPI.getRole(roleName)).thenReturn(expectedRole);
|
||||
} catch (RoleNotFoundException exception) {
|
||||
throw new RuntimeException("Failed to run test: testGetRoleFound");
|
||||
}
|
||||
try {
|
||||
OllamaChatMessageRole actualRole = ollamaAPI.getRole(roleName);
|
||||
assertEquals(expectedRole, actualRole);
|
||||
verify(ollamaAPI, times(1)).getRole(roleName);
|
||||
} catch (RoleNotFoundException exception) {
|
||||
throw new RuntimeException("Failed to run test: testGetRoleFound");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrowsExactly;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
import io.github.ollama4j.models.chat.OllamaChatRequest;
|
||||
@@ -42,7 +43,7 @@ public class TestChatRequestSerialization extends AbstractSerializationTest<Olla
|
||||
|
||||
@Test
|
||||
public void testRequestWithMessageAndImage() {
|
||||
OllamaChatRequest req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt",
|
||||
OllamaChatRequest req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt", Collections.emptyList(),
|
||||
List.of(new File("src/test/resources/dog-on-a-boat.jpg"))).build();
|
||||
String jsonRequest = serialize(req);
|
||||
assertEqualsAfterUnmarshalling(deserialize(jsonRequest, OllamaChatRequest.class), req);
|
||||
|
||||
@@ -1,36 +1,37 @@
|
||||
package io.github.ollama4j.unittests.jackson;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
import io.github.ollama4j.models.embeddings.OllamaEmbedRequestBuilder;
|
||||
import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
|
||||
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestBuilder;
|
||||
import io.github.ollama4j.utils.OptionsBuilder;
|
||||
|
||||
public class TestEmbeddingsRequestSerialization extends AbstractSerializationTest<OllamaEmbeddingsRequestModel> {
|
||||
public class TestEmbedRequestSerialization extends AbstractSerializationTest<OllamaEmbedRequestModel> {
|
||||
|
||||
private OllamaEmbeddingsRequestBuilder builder;
|
||||
private OllamaEmbedRequestBuilder builder;
|
||||
|
||||
@BeforeEach
|
||||
public void init() {
|
||||
builder = OllamaEmbeddingsRequestBuilder.getInstance("DummyModel","DummyPrompt");
|
||||
builder = OllamaEmbedRequestBuilder.getInstance("DummyModel","DummyPrompt");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRequestOnlyMandatoryFields() {
|
||||
OllamaEmbeddingsRequestModel req = builder.build();
|
||||
OllamaEmbedRequestModel req = builder.build();
|
||||
String jsonRequest = serialize(req);
|
||||
assertEqualsAfterUnmarshalling(deserialize(jsonRequest,OllamaEmbeddingsRequestModel.class), req);
|
||||
assertEqualsAfterUnmarshalling(deserialize(jsonRequest,OllamaEmbedRequestModel.class), req);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRequestWithOptions() {
|
||||
OptionsBuilder b = new OptionsBuilder();
|
||||
OllamaEmbeddingsRequestModel req = builder
|
||||
OllamaEmbedRequestModel req = builder
|
||||
.withOptions(b.setMirostat(1).build()).build();
|
||||
|
||||
String jsonRequest = serialize(req);
|
||||
OllamaEmbeddingsRequestModel deserializeRequest = deserialize(jsonRequest,OllamaEmbeddingsRequestModel.class);
|
||||
OllamaEmbedRequestModel deserializeRequest = deserialize(jsonRequest,OllamaEmbedRequestModel.class);
|
||||
assertEqualsAfterUnmarshalling(deserializeRequest, req);
|
||||
assertEquals(1, deserializeRequest.getOptions().get("mirostat"));
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
ollama.url=http://localhost:11434
|
||||
ollama.model=qwen:0.5b
|
||||
ollama.model.image=llava
|
||||
ollama.model=llama3.2:1b
|
||||
ollama.model.image=llava:latest
|
||||
ollama.request-timeout-seconds=120
|
||||
Reference in New Issue
Block a user