Merge pull request #157 from ollama4j/tests

Tests
This commit is contained in:
Amith Koujalgi 2025-09-10 21:03:38 +05:30 committed by GitHub
commit 74fbafeb3f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
37 changed files with 1178 additions and 151 deletions

10
.github/CODEOWNERS vendored Normal file
View File

@ -0,0 +1,10 @@
# See https://docs.github.com/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners
# Default owners for everything in the repo
* @amithkoujalgi
# Example for scoping ownership (uncomment and adjust as teams evolve)
# /docs/ @amithkoujalgi
# /src/ @amithkoujalgi

59
.github/ISSUE_TEMPLATE/bug_report.yml vendored Normal file
View File

@ -0,0 +1,59 @@
name: Bug report
description: File a bug report
labels: [bug]
assignees: []
body:
- type: markdown
attributes:
value: |
Thanks for taking the time to fill out this bug report!
- type: input
id: version
attributes:
label: ollama4j version
description: e.g., 1.1.0
placeholder: 1.1.0
validations:
required: true
- type: input
id: java
attributes:
label: Java version
description: Output of `java -version`
placeholder: 11/17/21
validations:
required: true
- type: input
id: environment
attributes:
label: Environment
description: OS, build tool, Docker/Testcontainers, etc.
placeholder: macOS 13, Maven 3.9.x, Docker 24.x
- type: textarea
id: what-happened
attributes:
label: What happened?
description: Also tell us what you expected to happen
validations:
required: true
- type: textarea
id: steps
attributes:
label: Steps to reproduce
description: Be as specific as possible
placeholder: |
1. Setup ...
2. Run ...
3. Observe ...
validations:
required: true
- type: textarea
id: logs
attributes:
label: Relevant logs/stack traces
render: shell
- type: textarea
id: additional
attributes:
label: Additional context

6
.github/ISSUE_TEMPLATE/config.yml vendored Normal file
View File

@ -0,0 +1,6 @@
blank_issues_enabled: false
contact_links:
- name: Questions / Discussions
url: https://github.com/ollama4j/ollama4j/discussions
about: Ask questions and discuss ideas here

View File

@ -0,0 +1,31 @@
name: Feature request
description: Suggest an idea or enhancement
labels: [enhancement]
assignees: []
body:
- type: markdown
attributes:
value: |
Thanks for suggesting an improvement!
- type: textarea
id: problem
attributes:
label: Is your feature request related to a problem?
description: A clear and concise description of the problem
placeholder: I'm frustrated when...
- type: textarea
id: solution
attributes:
label: Describe the solution you'd like
placeholder: I'd like...
validations:
required: true
- type: textarea
id: alternatives
attributes:
label: Describe alternatives you've considered
- type: textarea
id: context
attributes:
label: Additional context

34
.github/PULL_REQUEST_TEMPLATE.md vendored Normal file
View File

@ -0,0 +1,34 @@
## Description
Describe what this PR does and why.
## Type of change
- [ ] feat: New feature
- [ ] fix: Bug fix
- [ ] docs: Documentation update
- [ ] refactor: Refactoring
- [ ] test: Tests only
- [ ] build/ci: Build or CI changes
## How has this been tested?
Explain the testing done. Include commands, screenshots, logs.
## Checklist
- [ ] I ran `pre-commit run -a` locally
- [ ] `make build` succeeds locally
- [ ] Unit/integration tests added or updated as needed
- [ ] Docs updated (README/docs site) if user-facing changes
- [ ] PR title follows Conventional Commits
## Breaking changes
List any breaking changes and migration notes.
## Related issues
Fixes #

View File

@ -1,11 +1,34 @@
# To get started with Dependabot version updates, you'll need to specify which # To get started with Dependabot version updates, you'll need to specify which
# package ecosystems to update and where the package manifests are located. ## package ecosystems to update and where the package manifests are located.
# Please see the documentation for all configuration options: ## Please see the documentation for all configuration options:
# https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file ## https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file
#
#version: 2
#updates:
# - package-ecosystem: "" # See documentation for possible values
# directory: "/" # Location of package manifests
# schedule:
# interval: "weekly"
version: 2 version: 2
updates: updates:
- package-ecosystem: "" # See documentation for possible values - package-ecosystem: "maven"
directory: "/" # Location of package manifests directory: "/"
schedule: schedule:
interval: "weekly" interval: "weekly"
open-pull-requests-limit: 5
labels: ["dependencies"]
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "weekly"
open-pull-requests-limit: 5
labels: ["dependencies"]
- package-ecosystem: "npm"
directory: "/docs"
schedule:
interval: "weekly"
open-pull-requests-limit: 5
labels: ["dependencies"]
#

44
.github/workflows/codeql.yml vendored Normal file
View File

@ -0,0 +1,44 @@
name: CodeQL
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
schedule:
- cron: '0 3 * * 1'
jobs:
analyze:
name: Analyze
runs-on: ubuntu-latest
permissions:
actions: read
contents: read
security-events: write
strategy:
fail-fast: false
matrix:
language: [ 'java', 'javascript' ]
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up JDK
if: matrix.language == 'java'
uses: actions/setup-java@v4
with:
distribution: temurin
java-version: '11'
- name: Initialize CodeQL
uses: github/codeql-action/init@v3
with:
languages: ${{ matrix.language }}
- name: Autobuild
uses: github/codeql-action/autobuild@v3
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v3

30
.github/workflows/pre-commit.yml vendored Normal file
View File

@ -0,0 +1,30 @@
name: Pre-commit Check on PR
on:
pull_request:
types: [opened, reopened, synchronize]
branches:
- main
#on:
# pull_request:
# branches: [ main ]
# push:
# branches: [ main ]
jobs:
run:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: '3.x'
- name: Install pre-commit
run: |
python -m pip install --upgrade pip
pip install pre-commit
# - name: Run pre-commit
# run: |
# pre-commit run --all-files --show-diff-on-failure

33
.github/workflows/stale.yml vendored Normal file
View File

@ -0,0 +1,33 @@
name: Mark stale issues and PRs
on:
schedule:
- cron: '0 2 * * *'
permissions:
issues: write
pull-requests: write
jobs:
stale:
runs-on: ubuntu-latest
steps:
- uses: actions/stale@v9
with:
days-before-stale: 60
days-before-close: 14
stale-issue-label: 'stale'
stale-pr-label: 'stale'
exempt-issue-labels: 'pinned,security'
exempt-pr-labels: 'pinned,security'
stale-issue-message: >
This issue has been automatically marked as stale because it has not had
recent activity. It will be closed if no further activity occurs.
close-issue-message: >
Closing this stale issue. Feel free to reopen if this is still relevant.
stale-pr-message: >
This pull request has been automatically marked as stale due to inactivity.
It will be closed if no further activity occurs.
close-pr-message: >
Closing this stale pull request. Please reopen when you're ready to continue.

125
CONTRIBUTING.md Normal file
View File

@ -0,0 +1,125 @@
## Contributing to Ollama4j
Thanks for your interest in contributing! This guide explains how to set up your environment, make changes, and submit pull requests.
### Code of Conduct
By participating, you agree to abide by our [Code of Conduct](CODE_OF_CONDUCT.md).
### Quick Start
Prerequisites:
- Java 11+
- Maven 3.8+
- Docker (required for integration tests)
- Make (for convenience targets)
- pre-commit (for Git hooks)
Setup:
```bash
# 1) Fork the repo and clone your fork
git clone https://github.com/<your-username>/ollama4j.git
cd ollama4j
# 2) Install and enable git hooks
pre-commit install --hook-type pre-commit --hook-type commit-msg
# 3) Prepare dev environment (installs husk deps/tools if needed)
make dev
```
Build and test:
```bash
# Build
make build
# Run unit tests
make unit-tests
# Run integration tests (requires Docker running)
make integration-tests
```
If you prefer raw Maven:
```bash
# Unit tests profile
mvn -P unit-tests clean test
# Integration tests profile (Docker required)
mvn -P integration-tests -DskipUnitTests=true clean verify
```
### Commit Style
We use Conventional Commits. Commit messages and PR titles should follow:
```
<type>(optional scope): <short summary>
[optional body]
[optional footer(s)]
```
Common types: `feat`, `fix`, `docs`, `refactor`, `test`, `build`, `chore`.
Commit message formatting is enforced via `commitizen` through `pre-commit` hooks.
### Pre-commit Hooks
Before pushing, run:
```bash
pre-commit run -a
```
Hooks will check for merge conflicts, large files, YAML/XML/JSON validity, line endings, and basic formatting. Fix reported issues before opening a PR.
### Coding Guidelines
- Target Java 11+; match existing style and formatting.
- Prefer clear, descriptive names over abbreviations.
- Add Javadoc for public APIs and non-obvious logic.
- Include meaningful tests for new features and bug fixes.
- Avoid introducing new dependencies without discussion.
### Tests
- Unit tests: place under `src/test/java/**/unittests/`.
- Integration tests: place under `src/test/java/**/integrationtests/` (uses Testcontainers; ensure Docker is running).
### Documentation
- Update `README.md`, Javadoc, and `docs/` when you change public APIs or user-facing behavior.
- Add example snippets where useful. Keep API references consistent with the website content when applicable.
### Pull Requests
Before opening a PR:
- Ensure `make build` and all tests pass locally.
- Run `pre-commit run -a` and fix any issues.
- Keep PRs focused and reasonably small. Link related issues (e.g., "Closes #123").
- Describe the change, rationale, and any trade-offs in the PR description.
Review process:
- Maintainers will review for correctness, scope, tests, and docs.
- You may be asked to iterate; please be responsive to comments.
### Security
If you discover a security issue, please do not open a public issue. Instead, email the maintainer at `koujalgi.amith@gmail.com` with details.
### License
By contributing, you agree that your contributions will be licensed under the projects [MIT License](LICENSE).
### Questions and Discussion
Have questions or ideas? Open a GitHub Discussion or issue. We welcome feedback and proposals!

39
SECURITY.md Normal file
View File

@ -0,0 +1,39 @@
## Security Policy
### Supported Versions
We aim to support the latest released version of `ollama4j` and the most recent minor version prior to it. Older versions may receive fixes on a best-effort basis.
### Reporting a Vulnerability
Please do not open public GitHub issues for security vulnerabilities.
Instead, email the maintainer at:
```
koujalgi.amith@gmail.com
```
Include as much detail as possible:
- A clear description of the issue and impact
- Steps to reproduce or proof-of-concept
- Affected version(s) and environment
- Any suggested mitigations or patches
You should receive an acknowledgement within 72 hours. We will work with you to validate the issue, determine severity, and prepare a fix.
### Disclosure
We follow a responsible disclosure process:
1. Receive and validate report privately.
2. Develop and test a fix.
3. Coordinate a release that includes the fix.
4. Publicly credit the reporter (if desired) in release notes.
### GPG Signatures
Releases may be signed as part of our CI pipeline. If verification fails or you have concerns about release integrity, please contact us via the email above.

View File

@ -373,7 +373,6 @@ public class CouchbaseToolCallingExample {
String modelName = Utilities.getFromConfig("tools_model_mistral"); String modelName = Utilities.getFromConfig("tools_model_mistral");
OllamaAPI ollamaAPI = new OllamaAPI(host); OllamaAPI ollamaAPI = new OllamaAPI(host);
ollamaAPI.setVerbose(false);
ollamaAPI.setRequestTimeoutSeconds(60); ollamaAPI.setRequestTimeoutSeconds(60);
Tools.ToolSpecification callSignFinderToolSpec = getCallSignFinderToolSpec(cluster, bucketName); Tools.ToolSpecification callSignFinderToolSpec = getCallSignFinderToolSpec(cluster, bucketName);

View File

@ -1,25 +0,0 @@
---
sidebar_position: 1
---
# Set Verbosity
This API lets you set the verbosity of the Ollama client.
## Try asking a question about the model.
```java
import io.github.ollama4j.OllamaAPI;
public class Main {
public static void main(String[] args) {
String host = "http://localhost:11434/";
OllamaAPI ollamaAPI = new OllamaAPI(host);
ollamaAPI.setVerbose(true);
}
}
```

View File

@ -139,8 +139,6 @@ public class OllamaAPITest {
OllamaAPI ollamaAPI = new OllamaAPI(host); OllamaAPI ollamaAPI = new OllamaAPI(host);
ollamaAPI.setVerbose(true);
boolean isOllamaServerReachable = ollamaAPI.ping(); boolean isOllamaServerReachable = ollamaAPI.ping();
System.out.println("Is Ollama server running: " + isOllamaServerReachable); System.out.println("Is Ollama server running: " + isOllamaServerReachable);

View File

@ -1,7 +1,6 @@
package io.github.ollama4j; package io.github.ollama4j;
import com.fasterxml.jackson.core.JsonParseException; import com.fasterxml.jackson.core.JsonParseException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import io.github.ollama4j.exceptions.OllamaBaseException; import io.github.ollama4j.exceptions.OllamaBaseException;
import io.github.ollama4j.exceptions.RoleNotFoundException; import io.github.ollama4j.exceptions.RoleNotFoundException;
@ -55,7 +54,7 @@ import java.util.stream.Collectors;
@SuppressWarnings({"DuplicatedCode", "resource"}) @SuppressWarnings({"DuplicatedCode", "resource"})
public class OllamaAPI { public class OllamaAPI {
private static final Logger logger = LoggerFactory.getLogger(OllamaAPI.class); private static final Logger LOG = LoggerFactory.getLogger(OllamaAPI.class);
private final String host; private final String host;
private Auth auth; private Auth auth;
@ -71,16 +70,6 @@ public class OllamaAPI {
@Setter @Setter
private long requestTimeoutSeconds = 10; private long requestTimeoutSeconds = 10;
/**
* Enables or disables verbose logging of responses.
* <p>
* If set to {@code true}, the API will log detailed information about requests
* and responses.
* Default is {@code true}.
*/
@Setter
private boolean verbose = true;
/** /**
* The maximum number of retries for tool calls during chat interactions. * The maximum number of retries for tool calls during chat interactions.
* <p> * <p>
@ -123,9 +112,7 @@ public class OllamaAPI {
} else { } else {
this.host = host; this.host = host;
} }
if (this.verbose) { LOG.info("Ollama API initialized with host: {}", this.host);
logger.info("Ollama API initialized with host: {}", this.host);
}
} }
/** /**
@ -463,7 +450,7 @@ public class OllamaAPI {
int attempt = currentRetry + 1; int attempt = currentRetry + 1;
if (attempt < maxRetries) { if (attempt < maxRetries) {
long backoffMillis = baseDelayMillis * (1L << currentRetry); long backoffMillis = baseDelayMillis * (1L << currentRetry);
logger.error("Failed to pull model {}, retrying in {}s... (attempt {}/{})", LOG.error("Failed to pull model {}, retrying in {}s... (attempt {}/{})",
modelName, backoffMillis / 1000, attempt, maxRetries); modelName, backoffMillis / 1000, attempt, maxRetries);
try { try {
Thread.sleep(backoffMillis); Thread.sleep(backoffMillis);
@ -472,7 +459,7 @@ public class OllamaAPI {
throw ie; throw ie;
} }
} else { } else {
logger.error("Failed to pull model {} after {} attempts, no more retries.", modelName, maxRetries); LOG.error("Failed to pull model {} after {} attempts, no more retries.", modelName, maxRetries);
} }
} }
@ -502,21 +489,19 @@ public class OllamaAPI {
} }
if (modelPullResponse.getStatus() != null) { if (modelPullResponse.getStatus() != null) {
if (verbose) { LOG.info("{}: {}", modelName, modelPullResponse.getStatus());
logger.info("{}: {}", modelName, modelPullResponse.getStatus());
}
// Check if status is "success" and set success flag to true. // Check if status is "success" and set success flag to true.
if ("success".equalsIgnoreCase(modelPullResponse.getStatus())) { if ("success".equalsIgnoreCase(modelPullResponse.getStatus())) {
success = true; success = true;
} }
} }
} else { } else {
logger.error("Received null response for model pull."); LOG.error("Received null response for model pull.");
} }
} }
} }
if (!success) { if (!success) {
logger.error("Model pull failed or returned invalid status."); LOG.error("Model pull failed or returned invalid status.");
throw new OllamaBaseException("Model pull failed or returned invalid status."); throw new OllamaBaseException("Model pull failed or returned invalid status.");
} }
if (statusCode != 200) { if (statusCode != 200) {
@ -625,9 +610,7 @@ public class OllamaAPI {
if (responseString.contains("error")) { if (responseString.contains("error")) {
throw new OllamaBaseException(responseString); throw new OllamaBaseException(responseString);
} }
if (verbose) { LOG.debug(responseString);
logger.info(responseString);
}
} }
/** /**
@ -663,9 +646,7 @@ public class OllamaAPI {
if (responseString.contains("error")) { if (responseString.contains("error")) {
throw new OllamaBaseException(responseString); throw new OllamaBaseException(responseString);
} }
if (verbose) { LOG.debug(responseString);
logger.info(responseString);
}
} }
/** /**
@ -697,9 +678,7 @@ public class OllamaAPI {
if (responseString.contains("error")) { if (responseString.contains("error")) {
throw new OllamaBaseException(responseString); throw new OllamaBaseException(responseString);
} }
if (verbose) { LOG.debug(responseString);
logger.info(responseString);
}
} }
/** /**
@ -967,15 +946,14 @@ public class OllamaAPI {
.header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON) .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON)
.POST(HttpRequest.BodyPublishers.ofString(jsonData)).build(); .POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
if (verbose) {
try { try {
String prettyJson = Utils.getObjectMapper().writerWithDefaultPrettyPrinter() String prettyJson = Utils.getObjectMapper().writerWithDefaultPrettyPrinter()
.writeValueAsString(Utils.getObjectMapper().readValue(jsonData, Object.class)); .writeValueAsString(Utils.getObjectMapper().readValue(jsonData, Object.class));
logger.info("Asking model:\n{}", prettyJson); LOG.debug("Asking model:\n{}", prettyJson);
} catch (Exception e) { } catch (Exception e) {
logger.info("Asking model: {}", jsonData); LOG.debug("Asking model: {}", jsonData);
}
} }
HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode(); int statusCode = response.statusCode();
String responseBody = response.body(); String responseBody = response.body();
@ -996,15 +974,11 @@ public class OllamaAPI {
ollamaResult.setPromptEvalDuration(structuredResult.getPromptEvalDuration()); ollamaResult.setPromptEvalDuration(structuredResult.getPromptEvalDuration());
ollamaResult.setEvalCount(structuredResult.getEvalCount()); ollamaResult.setEvalCount(structuredResult.getEvalCount());
ollamaResult.setEvalDuration(structuredResult.getEvalDuration()); ollamaResult.setEvalDuration(structuredResult.getEvalDuration());
if (verbose) { LOG.debug("Model response:\n{}", ollamaResult);
logger.info("Model response:\n{}", ollamaResult);
}
return ollamaResult; return ollamaResult;
} else { } else {
if (verbose) { LOG.debug("Model response:\n{}",
logger.info("Model response:\n{}",
Utils.getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(responseBody)); Utils.getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(responseBody));
}
throw new OllamaBaseException(statusCode + " - " + responseBody); throw new OllamaBaseException(statusCode + " - " + responseBody);
} }
} }
@ -1055,9 +1029,9 @@ public class OllamaAPI {
if (!toolsResponse.isEmpty()) { if (!toolsResponse.isEmpty()) {
try { try {
// Try to parse the string to see if it's a valid JSON // Try to parse the string to see if it's a valid JSON
JsonNode jsonNode = objectMapper.readTree(toolsResponse); objectMapper.readTree(toolsResponse);
} catch (JsonParseException e) { } catch (JsonParseException e) {
logger.warn("Response from model does not contain any tool calls. Returning the response as is."); LOG.warn("Response from model does not contain any tool calls. Returning the response as is.");
return toolResult; return toolResult;
} }
toolFunctionCallSpecs = objectMapper.readValue(toolsResponse, toolFunctionCallSpecs = objectMapper.readValue(toolsResponse,
@ -1361,8 +1335,7 @@ public class OllamaAPI {
*/ */
public OllamaChatResult chatStreaming(OllamaChatRequest request, OllamaTokenHandler tokenHandler) public OllamaChatResult chatStreaming(OllamaChatRequest request, OllamaTokenHandler tokenHandler)
throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException { throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, auth, requestTimeoutSeconds, OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, auth, requestTimeoutSeconds);
verbose);
OllamaChatResult result; OllamaChatResult result;
// add all registered tools to Request // add all registered tools to Request
@ -1417,9 +1390,7 @@ public class OllamaAPI {
*/ */
public void registerTool(Tools.ToolSpecification toolSpecification) { public void registerTool(Tools.ToolSpecification toolSpecification) {
toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification); toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification);
if (this.verbose) { LOG.debug("Registered tool: {}", toolSpecification.getFunctionName());
logger.debug("Registered tool: {}", toolSpecification.getFunctionName());
}
} }
/** /**
@ -1444,9 +1415,7 @@ public class OllamaAPI {
*/ */
public void deregisterTools() { public void deregisterTools() {
toolRegistry.clear(); toolRegistry.clear();
if (this.verbose) { LOG.debug("All tools have been deregistered.");
logger.debug("All tools have been deregistered.");
}
} }
/** /**
@ -1621,8 +1590,7 @@ public class OllamaAPI {
private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequest ollamaRequestModel, private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequest ollamaRequestModel,
OllamaStreamHandler thinkingStreamHandler, OllamaStreamHandler responseStreamHandler) OllamaStreamHandler thinkingStreamHandler, OllamaStreamHandler responseStreamHandler)
throws OllamaBaseException, IOException, InterruptedException { throws OllamaBaseException, IOException, InterruptedException {
OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, auth, requestTimeoutSeconds, OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, auth, requestTimeoutSeconds);
verbose);
OllamaResult result; OllamaResult result;
if (responseStreamHandler != null) { if (responseStreamHandler != null) {
ollamaRequestModel.setStream(true); ollamaRequestModel.setStream(true);
@ -1663,9 +1631,7 @@ public class OllamaAPI {
String methodName = toolFunctionCallSpec.getName(); String methodName = toolFunctionCallSpec.getName();
Map<String, Object> arguments = toolFunctionCallSpec.getArguments(); Map<String, Object> arguments = toolFunctionCallSpec.getArguments();
ToolFunction function = toolRegistry.getToolFunction(methodName); ToolFunction function = toolRegistry.getToolFunction(methodName);
if (verbose) { LOG.debug("Invoking function {} with arguments {}", methodName, arguments);
logger.debug("Invoking function {} with arguments {}", methodName, arguments);
}
if (function == null) { if (function == null) {
throw new ToolNotFoundException( throw new ToolNotFoundException(
"No such tool: " + methodName + ". Please register the tool before invoking it."); "No such tool: " + methodName + ". Please register the tool before invoking it.");

View File

@ -1,10 +1,14 @@
package io.github.ollama4j.impl; package io.github.ollama4j.impl;
import io.github.ollama4j.models.generate.OllamaStreamHandler; import io.github.ollama4j.models.generate.OllamaStreamHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class ConsoleOutputStreamHandler implements OllamaStreamHandler { public class ConsoleOutputStreamHandler implements OllamaStreamHandler {
private static final Logger LOG = LoggerFactory.getLogger(ConsoleOutputStreamHandler.class);
@Override @Override
public void accept(String message) { public void accept(String message) {
System.out.print(message); LOG.info(message);
} }
} }

View File

@ -31,8 +31,8 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
private OllamaTokenHandler tokenHandler; private OllamaTokenHandler tokenHandler;
public OllamaChatEndpointCaller(String host, Auth auth, long requestTimeoutSeconds, boolean verbose) { public OllamaChatEndpointCaller(String host, Auth auth, long requestTimeoutSeconds) {
super(host, auth, requestTimeoutSeconds, verbose); super(host, auth, requestTimeoutSeconds);
} }
@Override @Override
@ -91,7 +91,7 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
.POST( .POST(
body.getBodyPublisher()); body.getBodyPublisher());
HttpRequest request = requestBuilder.build(); HttpRequest request = requestBuilder.build();
if (isVerbose()) LOG.info("Asking model: {}", body); LOG.debug("Asking model: {}", body);
HttpResponse<InputStream> response = HttpResponse<InputStream> response =
httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream()); httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
@ -150,7 +150,7 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
} }
OllamaChatResult ollamaResult = OllamaChatResult ollamaResult =
new OllamaChatResult(ollamaChatResponseModel, body.getMessages()); new OllamaChatResult(ollamaChatResponseModel, body.getMessages());
if (isVerbose()) LOG.info("Model response: " + ollamaResult); LOG.debug("Model response: {}", ollamaResult);
return ollamaResult; return ollamaResult;
} }
} }

View File

@ -1,10 +1,7 @@
package io.github.ollama4j.models.request; package io.github.ollama4j.models.request;
import io.github.ollama4j.OllamaAPI;
import io.github.ollama4j.utils.Constants; import io.github.ollama4j.utils.Constants;
import lombok.Getter; import lombok.Getter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.net.URI; import java.net.URI;
import java.net.http.HttpRequest; import java.net.http.HttpRequest;
@ -16,18 +13,14 @@ import java.time.Duration;
@Getter @Getter
public abstract class OllamaEndpointCaller { public abstract class OllamaEndpointCaller {
private static final Logger LOG = LoggerFactory.getLogger(OllamaAPI.class);
private final String host; private final String host;
private final Auth auth; private final Auth auth;
private final long requestTimeoutSeconds; private final long requestTimeoutSeconds;
private final boolean verbose;
public OllamaEndpointCaller(String host, Auth auth, long requestTimeoutSeconds, boolean verbose) { public OllamaEndpointCaller(String host, Auth auth, long requestTimeoutSeconds) {
this.host = host; this.host = host;
this.auth = auth; this.auth = auth;
this.requestTimeoutSeconds = requestTimeoutSeconds; this.requestTimeoutSeconds = requestTimeoutSeconds;
this.verbose = verbose;
} }
protected abstract String getEndpointSuffix(); protected abstract String getEndpointSuffix();

View File

@ -29,8 +29,8 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
private OllamaGenerateStreamObserver responseStreamObserver; private OllamaGenerateStreamObserver responseStreamObserver;
public OllamaGenerateEndpointCaller(String host, Auth basicAuth, long requestTimeoutSeconds, boolean verbose) { public OllamaGenerateEndpointCaller(String host, Auth basicAuth, long requestTimeoutSeconds) {
super(host, basicAuth, requestTimeoutSeconds, verbose); super(host, basicAuth, requestTimeoutSeconds);
} }
@Override @Override
@ -80,7 +80,7 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
URI uri = URI.create(getHost() + getEndpointSuffix()); URI uri = URI.create(getHost() + getEndpointSuffix());
HttpRequest.Builder requestBuilder = getRequestBuilderDefault(uri).POST(body.getBodyPublisher()); HttpRequest.Builder requestBuilder = getRequestBuilderDefault(uri).POST(body.getBodyPublisher());
HttpRequest request = requestBuilder.build(); HttpRequest request = requestBuilder.build();
if (isVerbose()) LOG.info("Asking model: {}", body); LOG.debug("Asking model: {}", body);
HttpResponse<InputStream> response = httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream()); HttpResponse<InputStream> response = httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
int statusCode = response.statusCode(); int statusCode = response.statusCode();
@ -132,7 +132,7 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
ollamaResult.setEvalCount(ollamaGenerateResponseModel.getEvalCount()); ollamaResult.setEvalCount(ollamaGenerateResponseModel.getEvalCount());
ollamaResult.setEvalDuration(ollamaGenerateResponseModel.getEvalDuration()); ollamaResult.setEvalDuration(ollamaGenerateResponseModel.getEvalDuration());
if (isVerbose()) LOG.info("Model response: {}", ollamaResult); LOG.debug("Model response: {}", ollamaResult);
return ollamaResult; return ollamaResult;
} }
} }

View File

@ -75,7 +75,6 @@ class OllamaAPIIntegrationTest {
api = new OllamaAPI("http://" + ollama.getHost() + ":" + ollama.getMappedPort(internalPort)); api = new OllamaAPI("http://" + ollama.getHost() + ":" + ollama.getMappedPort(internalPort));
} }
api.setRequestTimeoutSeconds(120); api.setRequestTimeoutSeconds(120);
api.setVerbose(true);
api.setNumberOfRetriesForModelPull(5); api.setNumberOfRetriesForModelPull(5);
} }

View File

@ -61,7 +61,6 @@ public class WithAuth {
api = new OllamaAPI("http://" + nginx.getHost() + ":" + nginx.getMappedPort(NGINX_PORT)); api = new OllamaAPI("http://" + nginx.getHost() + ":" + nginx.getMappedPort(NGINX_PORT));
api.setRequestTimeoutSeconds(120); api.setRequestTimeoutSeconds(120);
api.setVerbose(true);
api.setNumberOfRetriesForModelPull(3); api.setNumberOfRetriesForModelPull(3);
String ollamaUrl = "http://" + ollama.getHost() + ":" + ollama.getMappedPort(OLLAMA_INTERNAL_PORT); String ollamaUrl = "http://" + ollama.getHost() + ":" + ollama.getMappedPort(OLLAMA_INTERNAL_PORT);

View File

@ -0,0 +1,50 @@
package io.github.ollama4j.unittests;
import io.github.ollama4j.tools.annotations.OllamaToolService;
import io.github.ollama4j.tools.annotations.ToolProperty;
import io.github.ollama4j.tools.annotations.ToolSpec;
import org.junit.jupiter.api.Test;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import static org.junit.jupiter.api.Assertions.*;
class TestAnnotations {
@OllamaToolService(providers = {SampleProvider.class})
static class SampleToolService {
}
static class SampleProvider {
@ToolSpec(name = "sum", desc = "adds two numbers")
public int sum(@ToolProperty(name = "a", desc = "first addend") int a,
@ToolProperty(name = "b", desc = "second addend", required = false) int b) {
return a + b;
}
}
@Test
void testOllamaToolServiceProvidersPresent() throws Exception {
OllamaToolService ann = SampleToolService.class.getAnnotation(OllamaToolService.class);
assertNotNull(ann);
assertArrayEquals(new Class<?>[]{SampleProvider.class}, ann.providers());
}
@Test
void testToolPropertyMetadataOnParameters() throws Exception {
Method m = SampleProvider.class.getDeclaredMethod("sum", int.class, int.class);
Parameter[] params = m.getParameters();
ToolProperty p0 = params[0].getAnnotation(ToolProperty.class);
ToolProperty p1 = params[1].getAnnotation(ToolProperty.class);
assertNotNull(p0);
assertEquals("a", p0.name());
assertEquals("first addend", p0.desc());
assertTrue(p0.required());
assertNotNull(p1);
assertEquals("b", p1.name());
assertEquals("second addend", p1.desc());
assertFalse(p1.required());
}
}

View File

@ -0,0 +1,26 @@
package io.github.ollama4j.unittests;
import io.github.ollama4j.models.request.BasicAuth;
import io.github.ollama4j.models.request.BearerAuth;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
class TestAuth {
@Test
void testBasicAuthHeaderEncoding() {
BasicAuth auth = new BasicAuth("alice", "s3cr3t");
String header = auth.getAuthHeaderValue();
assertTrue(header.startsWith("Basic "));
// "alice:s3cr3t" base64 is "YWxpY2U6czNjcjN0"
assertEquals("Basic YWxpY2U6czNjcjN0", header);
}
@Test
void testBearerAuthHeaderFormat() {
BearerAuth auth = new BearerAuth("abc.def.ghi");
String header = auth.getAuthHeaderValue();
assertEquals("Bearer abc.def.ghi", header);
}
}

View File

@ -0,0 +1,43 @@
package io.github.ollama4j.unittests;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import io.github.ollama4j.utils.BooleanToJsonFormatFlagSerializer;
import io.github.ollama4j.utils.Utils;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals;
class TestBooleanToJsonFormatFlagSerializer {
static class Holder {
@JsonSerialize(using = BooleanToJsonFormatFlagSerializer.class)
public Boolean formatJson;
}
@Test
void testSerializeTrueWritesJsonString() throws JsonProcessingException {
ObjectMapper mapper = Utils.getObjectMapper().copy();
mapper.setSerializationInclusion(JsonInclude.Include.NON_EMPTY);
Holder holder = new Holder();
holder.formatJson = true;
String json = mapper.writeValueAsString(holder);
assertEquals("{\"formatJson\":\"json\"}", json);
}
@Test
void testSerializeFalseOmittedByIsEmpty() throws JsonProcessingException {
ObjectMapper mapper = Utils.getObjectMapper().copy();
mapper.setSerializationInclusion(JsonInclude.Include.NON_EMPTY);
Holder holder = new Holder();
holder.formatJson = false;
String json = mapper.writeValueAsString(holder);
assertEquals("{}", json);
}
}

View File

@ -0,0 +1,32 @@
package io.github.ollama4j.unittests;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import io.github.ollama4j.utils.FileToBase64Serializer;
import io.github.ollama4j.utils.Utils;
import org.junit.jupiter.api.Test;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestFileToBase64Serializer {
static class Holder {
@JsonSerialize(using = FileToBase64Serializer.class)
public List<byte[]> images;
}
@Test
public void testSerializeByteArraysToBase64Array() throws JsonProcessingException {
ObjectMapper mapper = Utils.getObjectMapper();
Holder holder = new Holder();
holder.images = List.of("hello".getBytes(), "world".getBytes());
String json = mapper.writeValueAsString(holder);
// Base64 of "hello" = aGVsbG8=, of "world" = d29ybGQ=
assertEquals("{\"images\":[\"aGVsbG8=\",\"d29ybGQ=\"]}", json);
}
}

View File

@ -0,0 +1,22 @@
package io.github.ollama4j.unittests;
import io.github.ollama4j.models.chat.OllamaChatMessage;
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
import org.json.JSONObject;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
class TestOllamaChatMessage {
@Test
void testToStringProducesJson() {
OllamaChatMessage msg = new OllamaChatMessage(OllamaChatMessageRole.USER, "hello", null, null, null);
String json = msg.toString();
JSONObject obj = new JSONObject(json);
assertEquals("user", obj.getString("role"));
assertEquals("hello", obj.getString("content"));
assertTrue(obj.has("tool_calls"));
// thinking and images may or may not be present depending on null handling, just ensure no exception
}
}

View File

@ -0,0 +1,44 @@
package io.github.ollama4j.unittests;
import io.github.ollama4j.exceptions.RoleNotFoundException;
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
import org.junit.jupiter.api.Test;
import java.util.List;
import static org.junit.jupiter.api.Assertions.*;
class TestOllamaChatMessageRole {
@Test
void testStaticRolesRegistered() throws Exception {
List<OllamaChatMessageRole> roles = OllamaChatMessageRole.getRoles();
assertTrue(roles.contains(OllamaChatMessageRole.SYSTEM));
assertTrue(roles.contains(OllamaChatMessageRole.USER));
assertTrue(roles.contains(OllamaChatMessageRole.ASSISTANT));
assertTrue(roles.contains(OllamaChatMessageRole.TOOL));
assertEquals("system", OllamaChatMessageRole.SYSTEM.toString());
assertEquals("user", OllamaChatMessageRole.USER.toString());
assertEquals("assistant", OllamaChatMessageRole.ASSISTANT.toString());
assertEquals("tool", OllamaChatMessageRole.TOOL.toString());
assertSame(OllamaChatMessageRole.SYSTEM, OllamaChatMessageRole.getRole("system"));
assertSame(OllamaChatMessageRole.USER, OllamaChatMessageRole.getRole("user"));
assertSame(OllamaChatMessageRole.ASSISTANT, OllamaChatMessageRole.getRole("assistant"));
assertSame(OllamaChatMessageRole.TOOL, OllamaChatMessageRole.getRole("tool"));
}
@Test
void testCustomRoleCreationAndLookup() throws Exception {
OllamaChatMessageRole custom = OllamaChatMessageRole.newCustomRole("myrole");
assertEquals("myrole", custom.toString());
// custom roles are registered globally (per current implementation), so lookup should succeed
assertSame(custom, OllamaChatMessageRole.getRole("myrole"));
}
@Test
void testGetRoleThrowsOnUnknown() {
assertThrows(RoleNotFoundException.class, () -> OllamaChatMessageRole.getRole("does-not-exist"));
}
}

View File

@ -0,0 +1,49 @@
package io.github.ollama4j.unittests;
import io.github.ollama4j.models.chat.OllamaChatMessage;
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
import io.github.ollama4j.models.chat.OllamaChatRequest;
import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
import org.junit.jupiter.api.Test;
import java.util.Collections;
import static org.junit.jupiter.api.Assertions.*;
class TestOllamaChatRequestBuilder {
@Test
void testResetClearsMessagesButKeepsModelAndThink() {
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance("my-model")
.withThinking(true)
.withMessage(OllamaChatMessageRole.USER, "first");
OllamaChatRequest beforeReset = builder.build();
assertEquals("my-model", beforeReset.getModel());
assertTrue(beforeReset.isThink());
assertEquals(1, beforeReset.getMessages().size());
builder.reset();
OllamaChatRequest afterReset = builder.build();
assertEquals("my-model", afterReset.getModel());
assertTrue(afterReset.isThink());
assertNotNull(afterReset.getMessages());
assertEquals(0, afterReset.getMessages().size());
}
@Test
void testImageUrlFailuresAreIgnoredAndDoNotBreakBuild() {
// Provide clearly invalid URL, builder logs a warning and continues
OllamaChatRequest req = OllamaChatRequestBuilder.getInstance("m")
.withMessage(OllamaChatMessageRole.USER, "hi", Collections.emptyList(),
"ht!tp://invalid url \n not a uri")
.build();
assertNotNull(req.getMessages());
assertEquals(1, req.getMessages().size());
OllamaChatMessage msg = req.getMessages().get(0);
// images list will be initialized only if any valid URL was added; for invalid URL list can be null
// We just assert that builder didn't crash and message is present with content
assertEquals("hi", msg.getContent());
}
}

View File

@ -0,0 +1,58 @@
package io.github.ollama4j.unittests;
import io.github.ollama4j.utils.OllamaRequestBody;
import io.github.ollama4j.utils.Utils;
import org.junit.jupiter.api.Test;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.Flow;
import static org.junit.jupiter.api.Assertions.assertEquals;
class TestOllamaRequestBody {
static class SimpleRequest implements OllamaRequestBody {
public String name;
public int value;
SimpleRequest(String name, int value) {
this.name = name;
this.value = value;
}
}
@Test
void testGetBodyPublisherProducesSerializedJson() throws IOException {
SimpleRequest req = new SimpleRequest("abc", 123);
var publisher = req.getBodyPublisher();
StringBuilder data = new StringBuilder();
publisher.subscribe(new Flow.Subscriber<>() {
@Override
public void onSubscribe(Flow.Subscription subscription) {
subscription.request(Long.MAX_VALUE);
}
@Override
public void onNext(ByteBuffer item) {
data.append(StandardCharsets.UTF_8.decode(item));
}
@Override
public void onError(Throwable throwable) {
}
@Override
public void onComplete() {
}
});
// Trigger the publishing by converting it to a string via the same mapper for determinism
String expected = Utils.getObjectMapper().writeValueAsString(req);
// Due to asynchronous nature, expected content already delivered synchronously by StringPublisher
assertEquals(expected, data.toString());
}
}

View File

@ -0,0 +1,46 @@
package io.github.ollama4j.unittests;
import io.github.ollama4j.models.response.OllamaResult;
import io.github.ollama4j.tools.OllamaToolsResult;
import io.github.ollama4j.tools.ToolFunctionCallSpec;
import org.junit.jupiter.api.Test;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import static org.junit.jupiter.api.Assertions.*;
public class TestOllamaToolsResult {
@Test
public void testGetToolResultsTransformsMapToList() {
ToolFunctionCallSpec spec1 = new ToolFunctionCallSpec("fn1", Map.of("a", 1));
ToolFunctionCallSpec spec2 = new ToolFunctionCallSpec("fn2", Map.of("b", 2));
Map<ToolFunctionCallSpec, Object> toolMap = new LinkedHashMap<>();
toolMap.put(spec1, "r1");
toolMap.put(spec2, 123);
OllamaToolsResult tr = new OllamaToolsResult(new OllamaResult("", null, 0L, 200), toolMap);
List<OllamaToolsResult.ToolResult> list = tr.getToolResults();
assertEquals(2, list.size());
assertEquals("fn1", list.get(0).getFunctionName());
assertEquals(Map.of("a", 1), list.get(0).getFunctionArguments());
assertEquals("r1", list.get(0).getResult());
assertEquals("fn2", list.get(1).getFunctionName());
assertEquals(Map.of("b", 2), list.get(1).getFunctionArguments());
assertEquals(123, list.get(1).getResult());
}
@Test
public void testGetToolResultsReturnsEmptyListWhenNull() {
OllamaToolsResult tr = new OllamaToolsResult();
tr.setToolResults(null);
List<OllamaToolsResult.ToolResult> list = tr.getToolResults();
assertNotNull(list);
assertTrue(list.isEmpty());
}
}

View File

@ -0,0 +1,92 @@
package io.github.ollama4j.unittests;
import io.github.ollama4j.utils.Options;
import io.github.ollama4j.utils.OptionsBuilder;
import io.github.ollama4j.utils.PromptBuilder;
import io.github.ollama4j.utils.Utils;
import org.junit.jupiter.api.Test;
import java.io.File;
import java.util.Map;
import static org.junit.jupiter.api.Assertions.*;
class TestOptionsAndUtils {
@Test
void testOptionsBuilderSetsValues() {
Options options = new OptionsBuilder()
.setMirostat(1)
.setMirostatEta(0.2f)
.setMirostatTau(4.5f)
.setNumCtx(1024)
.setNumGqa(8)
.setNumGpu(2)
.setNumThread(6)
.setRepeatLastN(32)
.setRepeatPenalty(1.2f)
.setTemperature(0.7f)
.setSeed(42)
.setStop("STOP")
.setTfsZ(1.5f)
.setNumPredict(256)
.setTopK(50)
.setTopP(0.95f)
.setMinP(0.05f)
.setCustomOption("custom_param", 123)
.build();
Map<String, Object> map = options.getOptionsMap();
assertEquals(1, map.get("mirostat"));
assertEquals(0.2f, (Float) map.get("mirostat_eta"), 0.0001);
assertEquals(4.5f, (Float) map.get("mirostat_tau"), 0.0001);
assertEquals(1024, map.get("num_ctx"));
assertEquals(8, map.get("num_gqa"));
assertEquals(2, map.get("num_gpu"));
assertEquals(6, map.get("num_thread"));
assertEquals(32, map.get("repeat_last_n"));
assertEquals(1.2f, (Float) map.get("repeat_penalty"), 0.0001);
assertEquals(0.7f, (Float) map.get("temperature"), 0.0001);
assertEquals(42, map.get("seed"));
assertEquals("STOP", map.get("stop"));
assertEquals(1.5f, (Float) map.get("tfs_z"), 0.0001);
assertEquals(256, map.get("num_predict"));
assertEquals(50, map.get("top_k"));
assertEquals(0.95f, (Float) map.get("top_p"), 0.0001);
assertEquals(0.05f, (Float) map.get("min_p"), 0.0001);
assertEquals(123, map.get("custom_param"));
}
@Test
void testOptionsBuilderRejectsUnsupportedCustomType() {
OptionsBuilder builder = new OptionsBuilder();
assertThrows(IllegalArgumentException.class, () -> builder.setCustomOption("bad", new Object()));
}
@Test
void testPromptBuilderBuildsExpectedString() {
String prompt = new PromptBuilder()
.add("Hello")
.addLine(", world!")
.addSeparator()
.add("Continue.")
.build();
String expected = "Hello, world!\n\n--------------------------------------------------\nContinue.";
assertEquals(expected, prompt);
}
@Test
void testUtilsGetObjectMapperSingletonAndModule() {
assertSame(Utils.getObjectMapper(), Utils.getObjectMapper());
// Basic serialization sanity check with JavaTimeModule registered
assertDoesNotThrow(() -> Utils.getObjectMapper().writeValueAsString(java.time.OffsetDateTime.now()));
}
@Test
void testGetFileFromClasspath() {
File f = Utils.getFileFromClasspath("test-config.properties");
assertTrue(f.exists());
assertTrue(f.getName().contains("test-config.properties"));
}
}

View File

@ -0,0 +1,86 @@
package io.github.ollama4j.unittests;
import io.github.ollama4j.tools.ReflectionalToolFunction;
import org.junit.jupiter.api.Test;
import java.lang.reflect.Method;
import java.math.BigDecimal;
import java.util.LinkedHashMap;
import java.util.Map;
import static org.junit.jupiter.api.Assertions.*;
class TestReflectionalToolFunction {
public static class SampleToolHolder {
public String combine(Integer i, Boolean b, BigDecimal d, String s) {
return String.format("i=%s,b=%s,d=%s,s=%s", i, b, d, s);
}
public void alwaysThrows() {
throw new IllegalStateException("boom");
}
}
@Test
void testApplyInvokesMethodWithTypeCasting() throws Exception {
SampleToolHolder holder = new SampleToolHolder();
Method method = SampleToolHolder.class.getMethod("combine", Integer.class, Boolean.class, BigDecimal.class, String.class);
LinkedHashMap<String, String> propDef = new LinkedHashMap<>();
// preserve order to match method parameters
propDef.put("i", "java.lang.Integer");
propDef.put("b", "java.lang.Boolean");
propDef.put("d", "java.math.BigDecimal");
propDef.put("s", "java.lang.String");
ReflectionalToolFunction fn = new ReflectionalToolFunction(holder, method, propDef);
Map<String, Object> args = Map.of(
"i", "42",
"b", "true",
"d", "3.14",
"s", 123 // not a string; should be toString()'d by implementation
);
Object result = fn.apply(args);
assertEquals("i=42,b=true,d=3.14,s=123", result);
}
@Test
void testTypeCastNullsWhenClassOrValueIsNull() throws Exception {
SampleToolHolder holder = new SampleToolHolder();
Method method = SampleToolHolder.class.getMethod("combine", Integer.class, Boolean.class, BigDecimal.class, String.class);
LinkedHashMap<String, String> propDef = new LinkedHashMap<>();
propDef.put("i", null); // className null -> expect null passed
propDef.put("b", "java.lang.Boolean");
propDef.put("d", "java.math.BigDecimal");
propDef.put("s", "java.lang.String");
ReflectionalToolFunction fn = new ReflectionalToolFunction(holder, method, propDef);
Map<String, Object> args = new LinkedHashMap<>();
args.put("i", "100"); // ignored -> becomes null due to null className
args.put("b", null); // value null -> expect null passed
args.put("d", "1.00");
args.put("s", "ok");
Object result = fn.apply(args);
assertEquals("i=null,b=null,d=1.00,s=ok", result);
}
@Test
void testExceptionsAreWrappedWithMeaningfulMessage() throws Exception {
SampleToolHolder holder = new SampleToolHolder();
Method throwsMethod = SampleToolHolder.class.getMethod("alwaysThrows");
LinkedHashMap<String, String> propDef = new LinkedHashMap<>();
ReflectionalToolFunction fn = new ReflectionalToolFunction(holder, throwsMethod, propDef);
RuntimeException ex = assertThrows(RuntimeException.class, () -> fn.apply(Map.of()));
assertTrue(ex.getMessage().contains("Failed to invoke tool: alwaysThrows"));
assertNotNull(ex.getCause());
}
}

View File

@ -0,0 +1,48 @@
package io.github.ollama4j.unittests;
import io.github.ollama4j.tools.ToolFunction;
import io.github.ollama4j.tools.ToolRegistry;
import io.github.ollama4j.tools.Tools;
import org.junit.jupiter.api.Test;
import java.util.Map;
import static org.junit.jupiter.api.Assertions.*;
class TestToolRegistry {
@Test
void testAddAndGetToolFunction() {
ToolRegistry registry = new ToolRegistry();
ToolFunction fn = args -> "ok:" + args.get("x");
Tools.ToolSpecification spec = Tools.ToolSpecification.builder()
.functionName("test")
.functionDescription("desc")
.toolFunction(fn)
.build();
registry.addTool("test", spec);
ToolFunction retrieved = registry.getToolFunction("test");
assertNotNull(retrieved);
assertEquals("ok:42", retrieved.apply(Map.of("x", 42)));
}
@Test
void testGetUnknownReturnsNull() {
ToolRegistry registry = new ToolRegistry();
assertNull(registry.getToolFunction("nope"));
}
@Test
void testClearRemovesAll() {
ToolRegistry registry = new ToolRegistry();
registry.addTool("a", Tools.ToolSpecification.builder().toolFunction(args -> 1).build());
registry.addTool("b", Tools.ToolSpecification.builder().toolFunction(args -> 2).build());
assertFalse(registry.getRegisteredSpecs().isEmpty());
registry.clear();
assertTrue(registry.getRegisteredSpecs().isEmpty());
assertNull(registry.getToolFunction("a"));
assertNull(registry.getToolFunction("b"));
}
}

View File

@ -0,0 +1,64 @@
package io.github.ollama4j.unittests;
import com.fasterxml.jackson.core.JsonProcessingException;
import io.github.ollama4j.tools.Tools;
import org.junit.jupiter.api.Test;
import java.util.List;
import java.util.Map;
import static org.junit.jupiter.api.Assertions.*;
class TestToolsPromptBuilder {
@Test
void testPromptBuilderIncludesToolsAndPrompt() throws JsonProcessingException {
Tools.PromptFuncDefinition.Property cityProp = Tools.PromptFuncDefinition.Property.builder()
.type("string")
.description("city name")
.required(true)
.build();
Tools.PromptFuncDefinition.Property unitsProp = Tools.PromptFuncDefinition.Property.builder()
.type("string")
.description("units")
.enumValues(List.of("metric", "imperial"))
.required(false)
.build();
Tools.PromptFuncDefinition.Parameters params = Tools.PromptFuncDefinition.Parameters.builder()
.type("object")
.properties(Map.of("city", cityProp, "units", unitsProp))
.build();
Tools.PromptFuncDefinition.PromptFuncSpec spec = Tools.PromptFuncDefinition.PromptFuncSpec.builder()
.name("getWeather")
.description("Get weather for a city")
.parameters(params)
.build();
Tools.PromptFuncDefinition def = Tools.PromptFuncDefinition.builder()
.type("function")
.function(spec)
.build();
Tools.ToolSpecification toolSpec = Tools.ToolSpecification.builder()
.functionName("getWeather")
.functionDescription("Get weather for a city")
.toolPrompt(def)
.build();
Tools.PromptBuilder pb = new Tools.PromptBuilder()
.withToolSpecification(toolSpec)
.withPrompt("Tell me the weather.");
String built = pb.build();
assertTrue(built.contains("[AVAILABLE_TOOLS]"));
assertTrue(built.contains("[/AVAILABLE_TOOLS]"));
assertTrue(built.contains("[INST]"));
assertTrue(built.contains("Tell me the weather."));
assertTrue(built.contains("\"name\":\"getWeather\""));
assertTrue(built.contains("\"required\":[\"city\"]"));
assertTrue(built.contains("\"enum\":[\"metric\",\"imperial\"]"));
}
}