Merge pull request #145 from ollama4j/thinking-support

Thinking support
This commit is contained in:
Amith Koujalgi 2025-08-31 17:44:58 +05:30 committed by GitHub
commit 931d5dd520
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
55 changed files with 1420 additions and 1098 deletions

View File

@ -1,20 +1,21 @@
name: Run Tests name: Build and Test on Pull Request
on: on:
pull_request: pull_request:
# types: [opened, reopened, synchronize, edited] types: [opened, reopened, synchronize]
branches: [ "main" ] branches:
- main
paths: paths:
- 'src/**' # Run if changes occur in the 'src' folder - 'src/**'
- 'pom.xml' # Run if changes occur in the 'pom.xml' file - 'pom.xml'
concurrency: concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true cancel-in-progress: true
jobs: jobs:
run-tests: build:
name: Build Java Project
runs-on: ubuntu-latest runs-on: ubuntu-latest
permissions: permissions:
contents: read contents: read
@ -26,18 +27,26 @@ jobs:
with: with:
java-version: '11' java-version: '11'
distribution: 'adopt-hotspot' distribution: 'adopt-hotspot'
server-id: github # Value of the distributionManagement/repository/id field of the pom.xml server-id: github
settings-path: ${{ github.workspace }} # location for the settings.xml file settings-path: ${{ github.workspace }}
- name: Build with Maven - name: Build with Maven
run: mvn --file pom.xml -U clean package run: mvn --file pom.xml -U clean package
- name: Run unit tests run-tests:
run: mvn --file pom.xml -U clean test -Punit-tests name: Run Unit and Integration Tests
needs: build
uses: ./.github/workflows/run-tests.yml
with:
branch: ${{ github.head_ref || github.ref_name }}
- name: Run integration tests build-docs:
run: mvn --file pom.xml -U clean verify -Pintegration-tests name: Build Documentation
needs: [build, run-tests]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Use Node.js - name: Use Node.js
uses: actions/setup-node@v3 uses: actions/setup-node@v3
with: with:

View File

@ -1,18 +1,29 @@
name: Run Unit and Integration Tests name: Run Tests
on: on:
# push: # push:
# branches: # branches:
# - main # - main
workflow_call:
inputs:
branch:
description: 'Branch name to run the tests on'
required: true
default: 'main'
type: string
workflow_dispatch: workflow_dispatch:
inputs: inputs:
branch: branch:
description: 'Branch name to run the tests on' description: 'Branch name to run the tests on'
required: true required: true
default: 'main' default: 'main'
type: string
jobs: jobs:
run-tests: run-tests:
name: Unit and Integration Tests
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
@ -21,17 +32,6 @@ jobs:
with: with:
ref: ${{ github.event.inputs.branch }} ref: ${{ github.event.inputs.branch }}
- name: Use workflow from checked out branch
run: |
if [ -f .github/workflows/run-tests.yml ]; then
echo "Using workflow from checked out branch."
cp .github/workflows/run-tests.yml /tmp/run-tests.yml
exit 0
else
echo "Workflow file not found in checked out branch."
exit 1
fi
- name: Set up Ollama - name: Set up Ollama
run: | run: |
curl -fsSL https://ollama.com/install.sh | sh curl -fsSL https://ollama.com/install.sh | sh
@ -51,4 +51,4 @@ jobs:
run: mvn clean verify -Pintegration-tests run: mvn clean verify -Pintegration-tests
env: env:
USE_EXTERNAL_OLLAMA_HOST: "true" USE_EXTERNAL_OLLAMA_HOST: "true"
OLLAMA_HOST: "http://localhost:11434" OLLAMA_HOST: "http://localhost:11434"

View File

@ -29,7 +29,7 @@ You will get a response similar to:
### Try asking a question, receiving the answer streamed ### Try asking a question, receiving the answer streamed
<CodeEmbed src="https://raw.githubusercontent.com/ollama4j/ollama4j-examples/refs/heads/main/src/main/java/io/github/ollama4j/examples/GenerateStreamingWithTokenConcatenation.java" /> <CodeEmbed src="https://raw.githubusercontent.com/ollama4j/ollama4j-examples/refs/heads/main/src/main/java/io/github/ollama4j/examples/GenerateStreaming.java" />
You will get a response similar to: You will get a response similar to:

22
pom.xml
View File

@ -14,11 +14,12 @@
<properties> <properties>
<maven.compiler.release>11</maven.compiler.release> <maven.compiler.release>11</maven.compiler.release>
<project.build.outputTimestamp>${git.commit.time}</project.build.outputTimestamp><!-- populated via git-commit-id-plugin --> <project.build.outputTimestamp>${git.commit.time}
</project.build.outputTimestamp><!-- populated via git-commit-id-plugin -->
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<maven-surefire-plugin.version>3.0.0-M5</maven-surefire-plugin.version> <maven-surefire-plugin.version>3.0.0-M5</maven-surefire-plugin.version>
<maven-failsafe-plugin.version>3.0.0-M5</maven-failsafe-plugin.version> <maven-failsafe-plugin.version>3.0.0-M5</maven-failsafe-plugin.version>
<lombok.version>1.18.30</lombok.version> <lombok.version>1.18.38</lombok.version>
</properties> </properties>
<developers> <developers>
@ -46,6 +47,19 @@
<build> <build>
<plugins> <plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<annotationProcessorPaths>
<path>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<version>${lombok.version}</version>
</path>
</annotationProcessorPaths>
</configuration>
</plugin>
<plugin> <plugin>
<groupId>org.apache.maven.plugins</groupId> <groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-source-plugin</artifactId> <artifactId>maven-source-plugin</artifactId>
@ -146,7 +160,7 @@
</executions> </executions>
<configuration> <configuration>
<dateFormat>yyyy-MM-dd'T'HH:mm:ss'Z'</dateFormat> <dateFormat>yyyy-MM-dd'T'HH:mm:ss'Z'</dateFormat>
<dateFormatTimeZone>Etc/UTC</dateFormatTimeZone> <dateFormatTimeZone>Etc/UTC</dateFormatTimeZone>
</configuration> </configuration>
</plugin> </plugin>
</plugins> </plugins>
@ -412,4 +426,4 @@
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -22,6 +22,7 @@ import io.github.ollama4j.tools.*;
import io.github.ollama4j.tools.annotations.OllamaToolService; import io.github.ollama4j.tools.annotations.OllamaToolService;
import io.github.ollama4j.tools.annotations.ToolProperty; import io.github.ollama4j.tools.annotations.ToolProperty;
import io.github.ollama4j.tools.annotations.ToolSpec; import io.github.ollama4j.tools.annotations.ToolSpec;
import io.github.ollama4j.utils.Constants;
import io.github.ollama4j.utils.Options; import io.github.ollama4j.utils.Options;
import io.github.ollama4j.utils.Utils; import io.github.ollama4j.utils.Utils;
import lombok.Setter; import lombok.Setter;
@ -55,33 +56,54 @@ import java.util.stream.Collectors;
public class OllamaAPI { public class OllamaAPI {
private static final Logger logger = LoggerFactory.getLogger(OllamaAPI.class); private static final Logger logger = LoggerFactory.getLogger(OllamaAPI.class);
private final String host; private final String host;
private Auth auth;
private final ToolRegistry toolRegistry = new ToolRegistry();
/** /**
* -- SETTER -- * The request timeout in seconds for API calls.
* Set request timeout in seconds. Default is 3 seconds. * <p>
* Default is 10 seconds. This value determines how long the client will wait
* for a response
* from the Ollama server before timing out.
*/ */
@Setter @Setter
private long requestTimeoutSeconds = 10; private long requestTimeoutSeconds = 10;
/** /**
* -- SETTER -- * Enables or disables verbose logging of responses.
* Set/unset logging of responses * <p>
* If set to {@code true}, the API will log detailed information about requests
* and responses.
* Default is {@code true}.
*/ */
@Setter @Setter
private boolean verbose = true; private boolean verbose = true;
/**
* The maximum number of retries for tool calls during chat interactions.
* <p>
* This value controls how many times the API will attempt to call a tool in the
* event of a failure.
* Default is 3.
*/
@Setter @Setter
private int maxChatToolCallRetries = 3; private int maxChatToolCallRetries = 3;
private Auth auth; /**
* The number of retries to attempt when pulling a model from the Ollama server.
* <p>
* If set to 0, no retries will be performed. If greater than 0, the API will
* retry pulling the model
* up to the specified number of times in case of failure.
* <p>
* Default is 0 (no retries).
*/
@Setter
@SuppressWarnings({"FieldMayBeFinal", "FieldCanBeLocal"})
private int numberOfRetriesForModelPull = 0; private int numberOfRetriesForModelPull = 0;
public void setNumberOfRetriesForModelPull(int numberOfRetriesForModelPull) {
this.numberOfRetriesForModelPull = numberOfRetriesForModelPull;
}
private final ToolRegistry toolRegistry = new ToolRegistry();
/** /**
* Instantiates the Ollama API with default Ollama host: * Instantiates the Ollama API with default Ollama host:
* <a href="http://localhost:11434">http://localhost:11434</a> * <a href="http://localhost:11434">http://localhost:11434</a>
@ -102,7 +124,7 @@ public class OllamaAPI {
this.host = host; this.host = host;
} }
if (this.verbose) { if (this.verbose) {
logger.info("Ollama API initialized with host: " + this.host); logger.info("Ollama API initialized with host: {}", this.host);
} }
} }
@ -135,14 +157,17 @@ public class OllamaAPI {
public boolean ping() { public boolean ping() {
String url = this.host + "/api/tags"; String url = this.host + "/api/tags";
HttpClient httpClient = HttpClient.newHttpClient(); HttpClient httpClient = HttpClient.newHttpClient();
HttpRequest httpRequest = null; HttpRequest httpRequest;
try { try {
httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") httpRequest = getRequestBuilderDefault(new URI(url))
.header("Content-type", "application/json").GET().build(); .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
.header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON)
.GET()
.build();
} catch (URISyntaxException e) { } catch (URISyntaxException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
HttpResponse<String> response = null; HttpResponse<String> response;
try { try {
response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString()); response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
} catch (HttpConnectTimeoutException e) { } catch (HttpConnectTimeoutException e) {
@ -168,8 +193,10 @@ public class OllamaAPI {
HttpClient httpClient = HttpClient.newHttpClient(); HttpClient httpClient = HttpClient.newHttpClient();
HttpRequest httpRequest = null; HttpRequest httpRequest = null;
try { try {
httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") httpRequest = getRequestBuilderDefault(new URI(url))
.header("Content-type", "application/json").GET().build(); .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
.header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON)
.GET().build();
} catch (URISyntaxException e) { } catch (URISyntaxException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
@ -196,8 +223,10 @@ public class OllamaAPI {
public List<Model> listModels() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { public List<Model> listModels() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
String url = this.host + "/api/tags"; String url = this.host + "/api/tags";
HttpClient httpClient = HttpClient.newHttpClient(); HttpClient httpClient = HttpClient.newHttpClient();
HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") HttpRequest httpRequest = getRequestBuilderDefault(new URI(url))
.header("Content-type", "application/json").GET().build(); .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
.header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON).GET()
.build();
HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString()); HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode(); int statusCode = response.statusCode();
String responseString = response.body(); String responseString = response.body();
@ -229,8 +258,10 @@ public class OllamaAPI {
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
String url = "https://ollama.com/library"; String url = "https://ollama.com/library";
HttpClient httpClient = HttpClient.newHttpClient(); HttpClient httpClient = HttpClient.newHttpClient();
HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") HttpRequest httpRequest = getRequestBuilderDefault(new URI(url))
.header("Content-type", "application/json").GET().build(); .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
.header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON).GET()
.build();
HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString()); HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode(); int statusCode = response.statusCode();
String responseString = response.body(); String responseString = response.body();
@ -296,8 +327,10 @@ public class OllamaAPI {
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
String url = String.format("https://ollama.com/library/%s/tags", libraryModel.getName()); String url = String.format("https://ollama.com/library/%s/tags", libraryModel.getName());
HttpClient httpClient = HttpClient.newHttpClient(); HttpClient httpClient = HttpClient.newHttpClient();
HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") HttpRequest httpRequest = getRequestBuilderDefault(new URI(url))
.header("Content-type", "application/json").GET().build(); .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
.header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON).GET()
.build();
HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString()); HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode(); int statusCode = response.statusCode();
String responseString = response.body(); String responseString = response.body();
@ -338,6 +371,14 @@ public class OllamaAPI {
/** /**
* Finds a specific model using model name and tag from Ollama library. * Finds a specific model using model name and tag from Ollama library.
* <p> * <p>
* <b>Deprecated:</b> This method relies on the HTML structure of the Ollama
* website,
* which is subject to change at any time. As a result, it is difficult to keep
* this API
* method consistently updated and reliable. Therefore, this method is
* deprecated and
* may be removed in future releases.
* <p>
* This method retrieves the model from the Ollama library by its name, then * This method retrieves the model from the Ollama library by its name, then
* fetches its tags. * fetches its tags.
* It searches through the tags of the model to find one that matches the * It searches through the tags of the model to find one that matches the
@ -355,7 +396,11 @@ public class OllamaAPI {
* @throws URISyntaxException If there is an error with the URI syntax. * @throws URISyntaxException If there is an error with the URI syntax.
* @throws InterruptedException If the operation is interrupted. * @throws InterruptedException If the operation is interrupted.
* @throws NoSuchElementException If the model or the tag is not found. * @throws NoSuchElementException If the model or the tag is not found.
* @deprecated This method relies on the HTML structure of the Ollama website,
* which can change at any time and break this API. It is deprecated
* and may be removed in the future.
*/ */
@Deprecated
public LibraryModelTag findModelTagFromLibrary(String modelName, String tag) public LibraryModelTag findModelTagFromLibrary(String modelName, String tag)
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
List<LibraryModel> libraryModels = this.listModelsFromLibrary(); List<LibraryModel> libraryModels = this.listModelsFromLibrary();
@ -363,40 +408,71 @@ public class OllamaAPI {
.findFirst().orElseThrow( .findFirst().orElseThrow(
() -> new NoSuchElementException(String.format("Model by name '%s' not found", modelName))); () -> new NoSuchElementException(String.format("Model by name '%s' not found", modelName)));
LibraryModelDetail libraryModelDetail = this.getLibraryModelDetails(libraryModel); LibraryModelDetail libraryModelDetail = this.getLibraryModelDetails(libraryModel);
LibraryModelTag libraryModelTag = libraryModelDetail.getTags().stream() return libraryModelDetail.getTags().stream().filter(tagName -> tagName.getTag().equals(tag)).findFirst()
.filter(tagName -> tagName.getTag().equals(tag)).findFirst()
.orElseThrow(() -> new NoSuchElementException( .orElseThrow(() -> new NoSuchElementException(
String.format("Tag '%s' for model '%s' not found", tag, modelName))); String.format("Tag '%s' for model '%s' not found", tag, modelName)));
return libraryModelTag;
} }
/** /**
* Pull a model on the Ollama server from the list of <a * Pull a model on the Ollama server from the list of <a
* href="https://ollama.ai/library">available models</a>. * href="https://ollama.ai/library">available models</a>.
* <p>
* If {@code numberOfRetriesForModelPull} is greater than 0, this method will
* retry pulling the model
* up to the specified number of times if an {@link OllamaBaseException} occurs,
* using exponential backoff
* between retries (delay doubles after each failed attempt, starting at 1
* second).
* <p>
* The backoff is only applied between retries, not after the final attempt.
* *
* @param modelName the name of the model * @param modelName the name of the model
* @throws OllamaBaseException if the response indicates an error status * @throws OllamaBaseException if the response indicates an error status or all
* retries fail
* @throws IOException if an I/O error occurs during the HTTP request * @throws IOException if an I/O error occurs during the HTTP request
* @throws InterruptedException if the operation is interrupted * @throws InterruptedException if the operation is interrupted or the thread is
* interrupted during backoff
* @throws URISyntaxException if the URI for the request is malformed * @throws URISyntaxException if the URI for the request is malformed
*/ */
public void pullModel(String modelName) public void pullModel(String modelName)
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
if (numberOfRetriesForModelPull == 0) { if (numberOfRetriesForModelPull == 0) {
this.doPullModel(modelName); this.doPullModel(modelName);
} else { return;
int numberOfRetries = 0; }
while (numberOfRetries < numberOfRetriesForModelPull) { int numberOfRetries = 0;
try { long baseDelayMillis = 3000L; // 1 second base delay
this.doPullModel(modelName); while (numberOfRetries < numberOfRetriesForModelPull) {
return; try {
} catch (OllamaBaseException e) { this.doPullModel(modelName);
logger.error("Failed to pull model " + modelName + ", retrying..."); return;
numberOfRetries++; } catch (OllamaBaseException e) {
} handlePullRetry(modelName, numberOfRetries, numberOfRetriesForModelPull, baseDelayMillis);
numberOfRetries++;
} }
throw new OllamaBaseException( }
"Failed to pull model " + modelName + " after " + numberOfRetriesForModelPull + " retries"); throw new OllamaBaseException(
"Failed to pull model " + modelName + " after " + numberOfRetriesForModelPull + " retries");
}
/**
* Handles retry backoff for pullModel.
*/
private void handlePullRetry(String modelName, int currentRetry, int maxRetries, long baseDelayMillis)
throws InterruptedException {
int attempt = currentRetry + 1;
if (attempt < maxRetries) {
long backoffMillis = baseDelayMillis * (1L << currentRetry);
logger.error("Failed to pull model {}, retrying in {}s... (attempt {}/{})",
modelName, backoffMillis / 1000, attempt, maxRetries);
try {
Thread.sleep(backoffMillis);
} catch (InterruptedException ie) {
Thread.currentThread().interrupt();
throw ie;
}
} else {
logger.error("Failed to pull model {} after {} attempts, no more retries.", modelName, maxRetries);
} }
} }
@ -404,10 +480,9 @@ public class OllamaAPI {
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
String url = this.host + "/api/pull"; String url = this.host + "/api/pull";
String jsonData = new ModelRequest(modelName).toString(); String jsonData = new ModelRequest(modelName).toString();
HttpRequest request = getRequestBuilderDefault(new URI(url)) HttpRequest request = getRequestBuilderDefault(new URI(url)).POST(HttpRequest.BodyPublishers.ofString(jsonData))
.POST(HttpRequest.BodyPublishers.ofString(jsonData)) .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
.header("Accept", "application/json") .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON)
.header("Content-type", "application/json")
.build(); .build();
HttpClient client = HttpClient.newHttpClient(); HttpClient client = HttpClient.newHttpClient();
HttpResponse<InputStream> response = client.send(request, HttpResponse.BodyHandlers.ofInputStream()); HttpResponse<InputStream> response = client.send(request, HttpResponse.BodyHandlers.ofInputStream());
@ -428,7 +503,7 @@ public class OllamaAPI {
if (modelPullResponse.getStatus() != null) { if (modelPullResponse.getStatus() != null) {
if (verbose) { if (verbose) {
logger.info(modelName + ": " + modelPullResponse.getStatus()); logger.info("{}: {}", modelName, modelPullResponse.getStatus());
} }
// Check if status is "success" and set success flag to true. // Check if status is "success" and set success flag to true.
if ("success".equalsIgnoreCase(modelPullResponse.getStatus())) { if ("success".equalsIgnoreCase(modelPullResponse.getStatus())) {
@ -452,8 +527,10 @@ public class OllamaAPI {
public String getVersion() throws URISyntaxException, IOException, InterruptedException, OllamaBaseException { public String getVersion() throws URISyntaxException, IOException, InterruptedException, OllamaBaseException {
String url = this.host + "/api/version"; String url = this.host + "/api/version";
HttpClient httpClient = HttpClient.newHttpClient(); HttpClient httpClient = HttpClient.newHttpClient();
HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") HttpRequest httpRequest = getRequestBuilderDefault(new URI(url))
.header("Content-type", "application/json").GET().build(); .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
.header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON).GET()
.build();
HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString()); HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode(); int statusCode = response.statusCode();
String responseString = response.body(); String responseString = response.body();
@ -498,8 +575,10 @@ public class OllamaAPI {
throws IOException, OllamaBaseException, InterruptedException, URISyntaxException { throws IOException, OllamaBaseException, InterruptedException, URISyntaxException {
String url = this.host + "/api/show"; String url = this.host + "/api/show";
String jsonData = new ModelRequest(modelName).toString(); String jsonData = new ModelRequest(modelName).toString();
HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") HttpRequest request = getRequestBuilderDefault(new URI(url))
.header("Content-type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build(); .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
.header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON)
.POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
HttpClient client = HttpClient.newHttpClient(); HttpClient client = HttpClient.newHttpClient();
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString()); HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode(); int statusCode = response.statusCode();
@ -529,8 +608,9 @@ public class OllamaAPI {
throws IOException, InterruptedException, OllamaBaseException, URISyntaxException { throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
String url = this.host + "/api/create"; String url = this.host + "/api/create";
String jsonData = new CustomModelFilePathRequest(modelName, modelFilePath).toString(); String jsonData = new CustomModelFilePathRequest(modelName, modelFilePath).toString();
HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") HttpRequest request = getRequestBuilderDefault(new URI(url))
.header("Content-Type", "application/json") .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
.header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON)
.POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build(); .POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build();
HttpClient client = HttpClient.newHttpClient(); HttpClient client = HttpClient.newHttpClient();
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString()); HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
@ -569,8 +649,9 @@ public class OllamaAPI {
throws IOException, InterruptedException, OllamaBaseException, URISyntaxException { throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
String url = this.host + "/api/create"; String url = this.host + "/api/create";
String jsonData = new CustomModelFileContentsRequest(modelName, modelFileContents).toString(); String jsonData = new CustomModelFileContentsRequest(modelName, modelFileContents).toString();
HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") HttpRequest request = getRequestBuilderDefault(new URI(url))
.header("Content-Type", "application/json") .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
.header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON)
.POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build(); .POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build();
HttpClient client = HttpClient.newHttpClient(); HttpClient client = HttpClient.newHttpClient();
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString()); HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
@ -602,8 +683,9 @@ public class OllamaAPI {
throws IOException, InterruptedException, OllamaBaseException, URISyntaxException { throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
String url = this.host + "/api/create"; String url = this.host + "/api/create";
String jsonData = customModelRequest.toString(); String jsonData = customModelRequest.toString();
HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") HttpRequest request = getRequestBuilderDefault(new URI(url))
.header("Content-Type", "application/json") .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
.header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON)
.POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build(); .POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build();
HttpClient client = HttpClient.newHttpClient(); HttpClient client = HttpClient.newHttpClient();
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString()); HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
@ -637,7 +719,9 @@ public class OllamaAPI {
String jsonData = new ModelRequest(modelName).toString(); String jsonData = new ModelRequest(modelName).toString();
HttpRequest request = getRequestBuilderDefault(new URI(url)) HttpRequest request = getRequestBuilderDefault(new URI(url))
.method("DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)) .method("DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8))
.header("Accept", "application/json").header("Content-type", "application/json").build(); .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
.header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON)
.build();
HttpClient client = HttpClient.newHttpClient(); HttpClient client = HttpClient.newHttpClient();
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString()); HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode(); int statusCode = response.statusCode();
@ -683,7 +767,8 @@ public class OllamaAPI {
URI uri = URI.create(this.host + "/api/embeddings"); URI uri = URI.create(this.host + "/api/embeddings");
String jsonData = modelRequest.toString(); String jsonData = modelRequest.toString();
HttpClient httpClient = HttpClient.newHttpClient(); HttpClient httpClient = HttpClient.newHttpClient();
HttpRequest.Builder requestBuilder = getRequestBuilderDefault(uri).header("Accept", "application/json") HttpRequest.Builder requestBuilder = getRequestBuilderDefault(uri)
.header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
.POST(HttpRequest.BodyPublishers.ofString(jsonData)); .POST(HttpRequest.BodyPublishers.ofString(jsonData));
HttpRequest request = requestBuilder.build(); HttpRequest request = requestBuilder.build();
HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
@ -728,7 +813,8 @@ public class OllamaAPI {
String jsonData = Utils.getObjectMapper().writeValueAsString(modelRequest); String jsonData = Utils.getObjectMapper().writeValueAsString(modelRequest);
HttpClient httpClient = HttpClient.newHttpClient(); HttpClient httpClient = HttpClient.newHttpClient();
HttpRequest request = HttpRequest.newBuilder(uri).header("Accept", "application/json") HttpRequest request = HttpRequest.newBuilder(uri)
.header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
.POST(HttpRequest.BodyPublishers.ofString(jsonData)).build(); .POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
@ -744,33 +830,112 @@ public class OllamaAPI {
/** /**
* Generate response for a question to a model running on Ollama server. This is * Generate response for a question to a model running on Ollama server. This is
* a sync/blocking * a sync/blocking call. This API does not support "thinking" models.
* call.
* *
* @param model the ollama model to ask the question to * @param model the ollama model to ask the question to
* @param prompt the prompt/question text * @param prompt the prompt/question text
* @param options the Options object - <a * @param raw if true no formatting will be applied to the
* href= * prompt. You
* "https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values">More * may choose to use the raw parameter if you are
* details on the options</a> * specifying a full templated prompt in your
* @param streamHandler optional callback consumer that will be applied every * request to
* time a streamed response is received. If not set, the * the API
* stream parameter of the request is set to false. * @param options the Options object - <a
* href=
* "https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values">More
* details on the options</a>
* @param responseStreamHandler optional callback consumer that will be applied
* every
* time a streamed response is received. If not
* set, the
* stream parameter of the request is set to false.
* @return OllamaResult that includes response text and time taken for response * @return OllamaResult that includes response text and time taken for response
* @throws OllamaBaseException if the response indicates an error status * @throws OllamaBaseException if the response indicates an error status
* @throws IOException if an I/O error occurs during the HTTP request * @throws IOException if an I/O error occurs during the HTTP request
* @throws InterruptedException if the operation is interrupted * @throws InterruptedException if the operation is interrupted
*/ */
public OllamaResult generate(String model, String prompt, boolean raw, Options options, public OllamaResult generate(String model, String prompt, boolean raw, Options options,
OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { OllamaStreamHandler responseStreamHandler) throws OllamaBaseException, IOException, InterruptedException {
OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt); OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt);
ollamaRequestModel.setRaw(raw); ollamaRequestModel.setRaw(raw);
ollamaRequestModel.setThink(false);
ollamaRequestModel.setOptions(options.getOptionsMap()); ollamaRequestModel.setOptions(options.getOptionsMap());
return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler); return generateSyncForOllamaRequestModel(ollamaRequestModel, null, responseStreamHandler);
}
/**
* Generate thinking and response tokens for a question to a thinking model
* running on Ollama server. This is
* a sync/blocking call.
*
* @param model the ollama model to ask the question to
* @param prompt the prompt/question text
* @param raw if true no formatting will be applied to the
* prompt. You
* may choose to use the raw parameter if you are
* specifying a full templated prompt in your
* request to
* the API
* @param options the Options object - <a
* href=
* "https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values">More
* details on the options</a>
* @param responseStreamHandler optional callback consumer that will be applied
* every
* time a streamed response is received. If not
* set, the
* stream parameter of the request is set to false.
* @return OllamaResult that includes response text and time taken for response
* @throws OllamaBaseException if the response indicates an error status
* @throws IOException if an I/O error occurs during the HTTP request
* @throws InterruptedException if the operation is interrupted
*/
public OllamaResult generate(String model, String prompt, boolean raw, Options options,
OllamaStreamHandler thinkingStreamHandler, OllamaStreamHandler responseStreamHandler)
throws OllamaBaseException, IOException, InterruptedException {
OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt);
ollamaRequestModel.setRaw(raw);
ollamaRequestModel.setThink(true);
ollamaRequestModel.setOptions(options.getOptionsMap());
return generateSyncForOllamaRequestModel(ollamaRequestModel, thinkingStreamHandler, responseStreamHandler);
}
/**
* Generates response using the specified AI model and prompt (in blocking
* mode).
* <p>
* Uses
* {@link #generate(String, String, boolean, Options, OllamaStreamHandler)}
*
* @param model The name or identifier of the AI model to use for generating
* the response.
* @param prompt The input text or prompt to provide to the AI model.
* @param raw In some cases, you may wish to bypass the templating system
* and provide a full prompt. In this case, you can use the raw
* parameter to disable templating. Also note that raw mode will
* not return a context.
* @param options Additional options or configurations to use when generating
* the response.
* @param think if true the model will "think" step-by-step before
* generating the final response
* @return {@link OllamaResult}
* @throws OllamaBaseException if the response indicates an error status
* @throws IOException if an I/O error occurs during the HTTP request
* @throws InterruptedException if the operation is interrupted
*/
public OllamaResult generate(String model, String prompt, boolean raw, boolean think, Options options)
throws OllamaBaseException, IOException, InterruptedException {
if (think) {
return generate(model, prompt, raw, options, null, null);
} else {
return generate(model, prompt, raw, options, null);
}
} }
/** /**
* Generates structured output from the specified AI model and prompt. * Generates structured output from the specified AI model and prompt.
* <p>
* Note: When formatting is specified, the 'think' parameter is not allowed.
* *
* @param model The name or identifier of the AI model to use for generating * @param model The name or identifier of the AI model to use for generating
* the response. * the response.
@ -783,6 +948,7 @@ public class OllamaAPI {
* @throws IOException if an I/O error occurs during the HTTP request. * @throws IOException if an I/O error occurs during the HTTP request.
* @throws InterruptedException if the operation is interrupted. * @throws InterruptedException if the operation is interrupted.
*/ */
@SuppressWarnings("LoggingSimilarMessage")
public OllamaResult generate(String model, String prompt, Map<String, Object> format) public OllamaResult generate(String model, String prompt, Map<String, Object> format)
throws OllamaBaseException, IOException, InterruptedException { throws OllamaBaseException, IOException, InterruptedException {
URI uri = URI.create(this.host + "/api/generate"); URI uri = URI.create(this.host + "/api/generate");
@ -797,51 +963,52 @@ public class OllamaAPI {
HttpClient httpClient = HttpClient.newHttpClient(); HttpClient httpClient = HttpClient.newHttpClient();
HttpRequest request = getRequestBuilderDefault(uri) HttpRequest request = getRequestBuilderDefault(uri)
.header("Accept", "application/json") .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
.header("Content-type", "application/json") .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON)
.POST(HttpRequest.BodyPublishers.ofString(jsonData)) .POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
.build();
if (verbose) {
try {
String prettyJson = Utils.getObjectMapper().writerWithDefaultPrettyPrinter()
.writeValueAsString(Utils.getObjectMapper().readValue(jsonData, Object.class));
logger.info("Asking model:\n{}", prettyJson);
} catch (Exception e) {
logger.info("Asking model: {}", jsonData);
}
}
HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode(); int statusCode = response.statusCode();
String responseBody = response.body(); String responseBody = response.body();
if (statusCode == 200) { if (statusCode == 200) {
OllamaStructuredResult structuredResult = Utils.getObjectMapper().readValue(responseBody, OllamaStructuredResult structuredResult = Utils.getObjectMapper().readValue(responseBody,
OllamaStructuredResult.class); OllamaStructuredResult.class);
OllamaResult ollamaResult = new OllamaResult(structuredResult.getResponse(), OllamaResult ollamaResult = new OllamaResult(structuredResult.getResponse(), structuredResult.getThinking(),
structuredResult.getResponseTime(), statusCode); structuredResult.getResponseTime(), statusCode);
ollamaResult.setModel(structuredResult.getModel());
ollamaResult.setCreatedAt(structuredResult.getCreatedAt());
ollamaResult.setDone(structuredResult.isDone());
ollamaResult.setDoneReason(structuredResult.getDoneReason());
ollamaResult.setContext(structuredResult.getContext());
ollamaResult.setTotalDuration(structuredResult.getTotalDuration());
ollamaResult.setLoadDuration(structuredResult.getLoadDuration());
ollamaResult.setPromptEvalCount(structuredResult.getPromptEvalCount());
ollamaResult.setPromptEvalDuration(structuredResult.getPromptEvalDuration());
ollamaResult.setEvalCount(structuredResult.getEvalCount());
ollamaResult.setEvalDuration(structuredResult.getEvalDuration());
if (verbose) {
logger.info("Model response:\n{}", ollamaResult);
}
return ollamaResult; return ollamaResult;
} else { } else {
if (verbose) {
logger.info("Model response:\n{}",
Utils.getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(responseBody));
}
throw new OllamaBaseException(statusCode + " - " + responseBody); throw new OllamaBaseException(statusCode + " - " + responseBody);
} }
} }
/**
* Generates response using the specified AI model and prompt (in blocking
* mode).
* <p>
* Uses {@link #generate(String, String, boolean, Options, OllamaStreamHandler)}
*
* @param model The name or identifier of the AI model to use for generating
* the response.
* @param prompt The input text or prompt to provide to the AI model.
* @param raw In some cases, you may wish to bypass the templating system
* and provide a full prompt. In this case, you can use the raw
* parameter to disable templating. Also note that raw mode will
* not return a context.
* @param options Additional options or configurations to use when generating
* the response.
* @return {@link OllamaResult}
* @throws OllamaBaseException if the response indicates an error status
* @throws IOException if an I/O error occurs during the HTTP request
* @throws InterruptedException if the operation is interrupted
*/
public OllamaResult generate(String model, String prompt, boolean raw, Options options)
throws OllamaBaseException, IOException, InterruptedException {
return generate(model, prompt, raw, options, null);
}
/** /**
* Generates response using the specified AI model and prompt (in blocking * Generates response using the specified AI model and prompt (in blocking
* mode), and then invokes a set of tools * mode), and then invokes a set of tools
@ -893,8 +1060,7 @@ public class OllamaAPI {
logger.warn("Response from model does not contain any tool calls. Returning the response as is."); logger.warn("Response from model does not contain any tool calls. Returning the response as is.");
return toolResult; return toolResult;
} }
toolFunctionCallSpecs = objectMapper.readValue( toolFunctionCallSpecs = objectMapper.readValue(toolsResponse,
toolsResponse,
objectMapper.getTypeFactory().constructCollectionType(List.class, ToolFunctionCallSpec.class)); objectMapper.getTypeFactory().constructCollectionType(List.class, ToolFunctionCallSpec.class));
} }
for (ToolFunctionCallSpec toolFunctionCallSpec : toolFunctionCallSpecs) { for (ToolFunctionCallSpec toolFunctionCallSpec : toolFunctionCallSpecs) {
@ -905,19 +1071,47 @@ public class OllamaAPI {
} }
/** /**
* Generate response for a question to a model running on Ollama server and get * Asynchronously generates a response for a prompt using a model running on the
* a callback handle * Ollama server.
* that can be used to check for status and get the response from the model * <p>
* later. This would be * This method returns an {@link OllamaAsyncResultStreamer} handle that can be
* an async/non-blocking call. * used to poll for
* status and retrieve streamed "thinking" and response tokens from the model.
* The call is non-blocking.
* </p>
* *
* @param model the ollama model to ask the question to * <p>
* @param prompt the prompt/question text * <b>Example usage:</b>
* @return the ollama async result callback handle * </p>
*
* <pre>{@code
* OllamaAsyncResultStreamer resultStreamer = ollamaAPI.generateAsync("gpt-oss:20b", "Who are you", false, true);
* int pollIntervalMilliseconds = 1000;
* while (true) {
* String thinkingTokens = resultStreamer.getThinkingResponseStream().poll();
* String responseTokens = resultStreamer.getResponseStream().poll();
* System.out.print(thinkingTokens != null ? thinkingTokens.toUpperCase() : "");
* System.out.print(responseTokens != null ? responseTokens.toLowerCase() : "");
* Thread.sleep(pollIntervalMilliseconds);
* if (!resultStreamer.isAlive())
* break;
* }
* System.out.println("Complete thinking response: " + resultStreamer.getCompleteThinkingResponse());
* System.out.println("Complete response: " + resultStreamer.getCompleteResponse());
* }</pre>
*
* @param model the Ollama model to use for generating the response
* @param prompt the prompt or question text to send to the model
* @param raw if {@code true}, returns the raw response from the model
* @param think if {@code true}, streams "thinking" tokens as well as response
* tokens
* @return an {@link OllamaAsyncResultStreamer} handle for polling and
* retrieving streamed results
*/ */
public OllamaAsyncResultStreamer generateAsync(String model, String prompt, boolean raw) { public OllamaAsyncResultStreamer generateAsync(String model, String prompt, boolean raw, boolean think) {
OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt); OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt);
ollamaRequestModel.setRaw(raw); ollamaRequestModel.setRaw(raw);
ollamaRequestModel.setThink(think);
URI uri = URI.create(this.host + "/api/generate"); URI uri = URI.create(this.host + "/api/generate");
OllamaAsyncResultStreamer ollamaAsyncResultStreamer = new OllamaAsyncResultStreamer( OllamaAsyncResultStreamer ollamaAsyncResultStreamer = new OllamaAsyncResultStreamer(
getRequestBuilderDefault(uri), ollamaRequestModel, requestTimeoutSeconds); getRequestBuilderDefault(uri), ollamaRequestModel, requestTimeoutSeconds);
@ -953,7 +1147,7 @@ public class OllamaAPI {
} }
OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt, images); OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt, images);
ollamaRequestModel.setOptions(options.getOptionsMap()); ollamaRequestModel.setOptions(options.getOptionsMap());
return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler); return generateSyncForOllamaRequestModel(ollamaRequestModel, null, streamHandler);
} }
/** /**
@ -1001,7 +1195,7 @@ public class OllamaAPI {
} }
OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt, images); OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt, images);
ollamaRequestModel.setOptions(options.getOptionsMap()); ollamaRequestModel.setOptions(options.getOptionsMap());
return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler); return generateSyncForOllamaRequestModel(ollamaRequestModel, null, streamHandler);
} }
/** /**
@ -1023,38 +1217,47 @@ public class OllamaAPI {
/** /**
* Synchronously generates a response using a list of image byte arrays. * Synchronously generates a response using a list of image byte arrays.
* <p> * <p>
* This method encodes the provided byte arrays into Base64 and sends them to the Ollama server. * This method encodes the provided byte arrays into Base64 and sends them to
* the Ollama server.
* *
* @param model the Ollama model to use for generating the response * @param model the Ollama model to use for generating the response
* @param prompt the prompt or question text to send to the model * @param prompt the prompt or question text to send to the model
* @param images the list of image data as byte arrays * @param images the list of image data as byte arrays
* @param options the Options object - <a href="https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values">More details on the options</a> * @param options the Options object - <a href=
* @param streamHandler optional callback that will be invoked with each streamed response; if null, streaming is disabled * "https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values">More
* @return OllamaResult containing the response text and the time taken for the response * details on the options</a>
* @param streamHandler optional callback that will be invoked with each
* streamed response; if null, streaming is disabled
* @return OllamaResult containing the response text and the time taken for the
* response
* @throws OllamaBaseException if the response indicates an error status * @throws OllamaBaseException if the response indicates an error status
* @throws IOException if an I/O error occurs during the HTTP request * @throws IOException if an I/O error occurs during the HTTP request
* @throws InterruptedException if the operation is interrupted * @throws InterruptedException if the operation is interrupted
*/ */
public OllamaResult generateWithImages(String model, String prompt, List<byte[]> images, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { public OllamaResult generateWithImages(String model, String prompt, List<byte[]> images, Options options,
OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
List<String> encodedImages = new ArrayList<>(); List<String> encodedImages = new ArrayList<>();
for (byte[] image : images) { for (byte[] image : images) {
encodedImages.add(encodeByteArrayToBase64(image)); encodedImages.add(encodeByteArrayToBase64(image));
} }
OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt, encodedImages); OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt, encodedImages);
ollamaRequestModel.setOptions(options.getOptionsMap()); ollamaRequestModel.setOptions(options.getOptionsMap());
return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler); return generateSyncForOllamaRequestModel(ollamaRequestModel, null, streamHandler);
} }
/** /**
* Convenience method to call the Ollama API using image byte arrays without streaming responses. * Convenience method to call the Ollama API using image byte arrays without
* streaming responses.
* <p> * <p>
* Uses {@link #generateWithImages(String, String, List, Options, OllamaStreamHandler)} * Uses
* {@link #generateWithImages(String, String, List, Options, OllamaStreamHandler)}
* *
* @throws OllamaBaseException if the response indicates an error status * @throws OllamaBaseException if the response indicates an error status
* @throws IOException if an I/O error occurs during the HTTP request * @throws IOException if an I/O error occurs during the HTTP request
* @throws InterruptedException if the operation is interrupted * @throws InterruptedException if the operation is interrupted
*/ */
public OllamaResult generateWithImages(String model, String prompt, List<byte[]> images, Options options) throws OllamaBaseException, IOException, InterruptedException { public OllamaResult generateWithImages(String model, String prompt, List<byte[]> images, Options options)
throws OllamaBaseException, IOException, InterruptedException {
return generateWithImages(model, prompt, images, options, null); return generateWithImages(model, prompt, images, options, null);
} }
@ -1069,10 +1272,12 @@ public class OllamaAPI {
* history including the newly acquired assistant response. * history including the newly acquired assistant response.
* @throws OllamaBaseException any response code than 200 has been returned * @throws OllamaBaseException any response code than 200 has been returned
* @throws IOException in case the responseStream can not be read * @throws IOException in case the responseStream can not be read
* @throws InterruptedException in case the server is not reachable or network * @throws InterruptedException in case the server is not reachable or
* network
* issues happen * issues happen
* @throws OllamaBaseException if the response indicates an error status * @throws OllamaBaseException if the response indicates an error status
* @throws IOException if an I/O error occurs during the HTTP request * @throws IOException if an I/O error occurs during the HTTP
* request
* @throws InterruptedException if the operation is interrupted * @throws InterruptedException if the operation is interrupted
* @throws ToolInvocationException if the tool invocation fails * @throws ToolInvocationException if the tool invocation fails
*/ */
@ -1092,16 +1297,18 @@ public class OllamaAPI {
* @return {@link OllamaChatResult} * @return {@link OllamaChatResult}
* @throws OllamaBaseException any response code than 200 has been returned * @throws OllamaBaseException any response code than 200 has been returned
* @throws IOException in case the responseStream can not be read * @throws IOException in case the responseStream can not be read
* @throws InterruptedException in case the server is not reachable or network * @throws InterruptedException in case the server is not reachable or
* network
* issues happen * issues happen
* @throws OllamaBaseException if the response indicates an error status * @throws OllamaBaseException if the response indicates an error status
* @throws IOException if an I/O error occurs during the HTTP request * @throws IOException if an I/O error occurs during the HTTP
* request
* @throws InterruptedException if the operation is interrupted * @throws InterruptedException if the operation is interrupted
* @throws ToolInvocationException if the tool invocation fails * @throws ToolInvocationException if the tool invocation fails
*/ */
public OllamaChatResult chat(OllamaChatRequest request) public OllamaChatResult chat(OllamaChatRequest request)
throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException { throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
return chat(request, null); return chat(request, null, null);
} }
/** /**
@ -1110,23 +1317,27 @@ public class OllamaAPI {
* <p> * <p>
* Hint: the OllamaChatRequestModel#getStream() property is not implemented. * Hint: the OllamaChatRequestModel#getStream() property is not implemented.
* *
* @param request request object to be sent to the server * @param request request object to be sent to the server
* @param streamHandler callback handler to handle the last message from stream * @param responseStreamHandler callback handler to handle the last message from
* (caution: all previous tokens from stream will be * stream
* concatenated) * @param thinkingStreamHandler callback handler to handle the last thinking
* message from stream
* @return {@link OllamaChatResult} * @return {@link OllamaChatResult}
* @throws OllamaBaseException any response code than 200 has been returned * @throws OllamaBaseException any response code than 200 has been returned
* @throws IOException in case the responseStream can not be read * @throws IOException in case the responseStream can not be read
* @throws InterruptedException in case the server is not reachable or network * @throws InterruptedException in case the server is not reachable or
* network
* issues happen * issues happen
* @throws OllamaBaseException if the response indicates an error status * @throws OllamaBaseException if the response indicates an error status
* @throws IOException if an I/O error occurs during the HTTP request * @throws IOException if an I/O error occurs during the HTTP
* request
* @throws InterruptedException if the operation is interrupted * @throws InterruptedException if the operation is interrupted
* @throws ToolInvocationException if the tool invocation fails * @throws ToolInvocationException if the tool invocation fails
*/ */
public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler) public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler thinkingStreamHandler,
OllamaStreamHandler responseStreamHandler)
throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException { throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
return chatStreaming(request, new OllamaChatStreamObserver(streamHandler)); return chatStreaming(request, new OllamaChatStreamObserver(thinkingStreamHandler, responseStreamHandler));
} }
/** /**
@ -1177,8 +1388,11 @@ public class OllamaAPI {
} }
Map<String, Object> arguments = toolCall.getFunction().getArguments(); Map<String, Object> arguments = toolCall.getFunction().getArguments();
Object res = toolFunction.apply(arguments); Object res = toolFunction.apply(arguments);
String argumentKeys = arguments.keySet().stream()
.map(Object::toString)
.collect(Collectors.joining(", "));
request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL, request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL,
"[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() + ") : " + res + "[/TOOL_RESULTS]")); "[TOOL_RESULTS] " + toolName + "(" + argumentKeys + "): " + res + " [/TOOL_RESULTS]"));
} }
if (tokenHandler != null) { if (tokenHandler != null) {
@ -1224,6 +1438,17 @@ public class OllamaAPI {
} }
} }
/**
* Deregisters all tools from the tool registry.
* This method removes all registered tools, effectively clearing the registry.
*/
public void deregisterTools() {
toolRegistry.clear();
if (this.verbose) {
logger.debug("All tools have been deregistered.");
}
}
/** /**
* Registers tools based on the annotations found on the methods of the caller's * Registers tools based on the annotations found on the methods of the caller's
* class and its providers. * class and its providers.
@ -1380,10 +1605,12 @@ public class OllamaAPI {
* the request will be streamed; otherwise, a regular synchronous request will * the request will be streamed; otherwise, a regular synchronous request will
* be made. * be made.
* *
* @param ollamaRequestModel the request model containing necessary parameters * @param ollamaRequestModel the request model containing necessary
* for the Ollama API request. * parameters
* @param streamHandler the stream handler to process streaming responses, * for the Ollama API request.
* or null for non-streaming requests. * @param responseStreamHandler the stream handler to process streaming
* responses,
* or null for non-streaming requests.
* @return the result of the Ollama API request. * @return the result of the Ollama API request.
* @throws OllamaBaseException if the request fails due to an issue with the * @throws OllamaBaseException if the request fails due to an issue with the
* Ollama API. * Ollama API.
@ -1392,13 +1619,14 @@ public class OllamaAPI {
* @throws InterruptedException if the thread is interrupted during the request. * @throws InterruptedException if the thread is interrupted during the request.
*/ */
private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequest ollamaRequestModel, private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequest ollamaRequestModel,
OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { OllamaStreamHandler thinkingStreamHandler, OllamaStreamHandler responseStreamHandler)
throws OllamaBaseException, IOException, InterruptedException {
OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, auth, requestTimeoutSeconds, OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, auth, requestTimeoutSeconds,
verbose); verbose);
OllamaResult result; OllamaResult result;
if (streamHandler != null) { if (responseStreamHandler != null) {
ollamaRequestModel.setStream(true); ollamaRequestModel.setStream(true);
result = requestCaller.call(ollamaRequestModel, streamHandler); result = requestCaller.call(ollamaRequestModel, thinkingStreamHandler, responseStreamHandler);
} else { } else {
result = requestCaller.callSync(ollamaRequestModel); result = requestCaller.callSync(ollamaRequestModel);
} }
@ -1412,7 +1640,8 @@ public class OllamaAPI {
* @return HttpRequest.Builder * @return HttpRequest.Builder
*/ */
private HttpRequest.Builder getRequestBuilderDefault(URI uri) { private HttpRequest.Builder getRequestBuilderDefault(URI uri) {
HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(uri).header("Content-Type", "application/json") HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(uri)
.header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON)
.timeout(Duration.ofSeconds(requestTimeoutSeconds)); .timeout(Duration.ofSeconds(requestTimeoutSeconds));
if (isBasicAuthCredentialsSet()) { if (isBasicAuthCredentialsSet()) {
requestBuilder.header("Authorization", auth.getAuthHeaderValue()); requestBuilder.header("Authorization", auth.getAuthHeaderValue());

View File

@ -3,12 +3,8 @@ package io.github.ollama4j.impl;
import io.github.ollama4j.models.generate.OllamaStreamHandler; import io.github.ollama4j.models.generate.OllamaStreamHandler;
public class ConsoleOutputStreamHandler implements OllamaStreamHandler { public class ConsoleOutputStreamHandler implements OllamaStreamHandler {
private final StringBuffer response = new StringBuffer();
@Override @Override
public void accept(String message) { public void accept(String message) {
String substr = message.substring(response.length()); System.out.print(message);
response.append(substr);
System.out.print(substr);
} }
} }

View File

@ -1,21 +1,15 @@
package io.github.ollama4j.models.chat; package io.github.ollama4j.models.chat;
import static io.github.ollama4j.utils.Utils.getObjectMapper; import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.annotation.JsonSerialize; import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import io.github.ollama4j.utils.FileToBase64Serializer; import io.github.ollama4j.utils.FileToBase64Serializer;
import lombok.*;
import java.util.List; import java.util.List;
import lombok.AllArgsConstructor; import static io.github.ollama4j.utils.Utils.getObjectMapper;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
/** /**
* Defines a single Message to be used inside a chat request against the ollama /api/chat endpoint. * Defines a single Message to be used inside a chat request against the ollama /api/chat endpoint.
@ -35,6 +29,8 @@ public class OllamaChatMessage {
@NonNull @NonNull
private String content; private String content;
private String thinking;
private @JsonProperty("tool_calls") List<OllamaChatToolCalls> toolCalls; private @JsonProperty("tool_calls") List<OllamaChatToolCalls> toolCalls;
@JsonSerialize(using = FileToBase64Serializer.class) @JsonSerialize(using = FileToBase64Serializer.class)

View File

@ -1,43 +1,46 @@
package io.github.ollama4j.models.chat; package io.github.ollama4j.models.chat;
import java.util.List;
import io.github.ollama4j.models.request.OllamaCommonRequest; import io.github.ollama4j.models.request.OllamaCommonRequest;
import io.github.ollama4j.tools.Tools; import io.github.ollama4j.tools.Tools;
import io.github.ollama4j.utils.OllamaRequestBody; import io.github.ollama4j.utils.OllamaRequestBody;
import lombok.Getter; import lombok.Getter;
import lombok.Setter; import lombok.Setter;
import java.util.List;
/** /**
* Defines a Request to use against the ollama /api/chat endpoint. * Defines a Request to use against the ollama /api/chat endpoint.
* *
* @see <a href= * @see <a href=
* "https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion">Generate * "https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion">Generate
* Chat Completion</a> * Chat Completion</a>
*/ */
@Getter @Getter
@Setter @Setter
public class OllamaChatRequest extends OllamaCommonRequest implements OllamaRequestBody { public class OllamaChatRequest extends OllamaCommonRequest implements OllamaRequestBody {
private List<OllamaChatMessage> messages; private List<OllamaChatMessage> messages;
private List<Tools.PromptFuncDefinition> tools; private List<Tools.PromptFuncDefinition> tools;
public OllamaChatRequest() {} private boolean think;
public OllamaChatRequest(String model, List<OllamaChatMessage> messages) { public OllamaChatRequest() {
this.model = model;
this.messages = messages;
}
@Override
public boolean equals(Object o) {
if (!(o instanceof OllamaChatRequest)) {
return false;
} }
return this.toString().equals(o.toString()); public OllamaChatRequest(String model, boolean think, List<OllamaChatMessage> messages) {
} this.model = model;
this.messages = messages;
this.think = think;
}
@Override
public boolean equals(Object o) {
if (!(o instanceof OllamaChatRequest)) {
return false;
}
return this.toString().equals(o.toString());
}
} }

View File

@ -22,7 +22,7 @@ public class OllamaChatRequestBuilder {
private static final Logger LOG = LoggerFactory.getLogger(OllamaChatRequestBuilder.class); private static final Logger LOG = LoggerFactory.getLogger(OllamaChatRequestBuilder.class);
private OllamaChatRequestBuilder(String model, List<OllamaChatMessage> messages) { private OllamaChatRequestBuilder(String model, List<OllamaChatMessage> messages) {
request = new OllamaChatRequest(model, messages); request = new OllamaChatRequest(model, false, messages);
} }
private OllamaChatRequest request; private OllamaChatRequest request;
@ -36,14 +36,20 @@ public class OllamaChatRequestBuilder {
} }
public void reset() { public void reset() {
request = new OllamaChatRequest(request.getModel(), new ArrayList<>()); request = new OllamaChatRequest(request.getModel(), request.isThink(), new ArrayList<>());
} }
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content){ public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content) {
return withMessage(role,content, Collections.emptyList()); return withMessage(role, content, Collections.emptyList());
} }
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List<OllamaChatToolCalls> toolCalls,List<File> images) { public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List<OllamaChatToolCalls> toolCalls) {
List<OllamaChatMessage> messages = this.request.getMessages();
messages.add(new OllamaChatMessage(role, content, null, toolCalls, null));
return this;
}
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List<OllamaChatToolCalls> toolCalls, List<File> images) {
List<OllamaChatMessage> messages = this.request.getMessages(); List<OllamaChatMessage> messages = this.request.getMessages();
List<byte[]> binaryImages = images.stream().map(file -> { List<byte[]> binaryImages = images.stream().map(file -> {
@ -55,11 +61,11 @@ public class OllamaChatRequestBuilder {
} }
}).collect(Collectors.toList()); }).collect(Collectors.toList());
messages.add(new OllamaChatMessage(role, content,toolCalls, binaryImages)); messages.add(new OllamaChatMessage(role, content, null, toolCalls, binaryImages));
return this; return this;
} }
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content,List<OllamaChatToolCalls> toolCalls, String... imageUrls) { public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List<OllamaChatToolCalls> toolCalls, String... imageUrls) {
List<OllamaChatMessage> messages = this.request.getMessages(); List<OllamaChatMessage> messages = this.request.getMessages();
List<byte[]> binaryImages = null; List<byte[]> binaryImages = null;
if (imageUrls.length > 0) { if (imageUrls.length > 0) {
@ -75,7 +81,7 @@ public class OllamaChatRequestBuilder {
} }
} }
messages.add(new OllamaChatMessage(role, content,toolCalls, binaryImages)); messages.add(new OllamaChatMessage(role, content, null, toolCalls, binaryImages));
return this; return this;
} }
@ -108,4 +114,8 @@ public class OllamaChatRequestBuilder {
return this; return this;
} }
public OllamaChatRequestBuilder withThinking(boolean think) {
this.request.setThink(think);
return this;
}
} }

View File

@ -1,10 +1,10 @@
package io.github.ollama4j.models.chat; package io.github.ollama4j.models.chat;
import java.util.List;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import lombok.Getter; import lombok.Getter;
import java.util.List;
import static io.github.ollama4j.utils.Utils.getObjectMapper; import static io.github.ollama4j.utils.Utils.getObjectMapper;
/** /**

View File

@ -6,14 +6,46 @@ import lombok.RequiredArgsConstructor;
@RequiredArgsConstructor @RequiredArgsConstructor
public class OllamaChatStreamObserver implements OllamaTokenHandler { public class OllamaChatStreamObserver implements OllamaTokenHandler {
private final OllamaStreamHandler streamHandler; private final OllamaStreamHandler thinkingStreamHandler;
private final OllamaStreamHandler responseStreamHandler;
private String message = ""; private String message = "";
@Override @Override
public void accept(OllamaChatResponseModel token) { public void accept(OllamaChatResponseModel token) {
if (streamHandler != null) { if (responseStreamHandler == null || token == null || token.getMessage() == null) {
message += token.getMessage().getContent(); return;
streamHandler.accept(message); }
String thinking = token.getMessage().getThinking();
String content = token.getMessage().getContent();
boolean hasThinking = thinking != null && !thinking.isEmpty();
boolean hasContent = !content.isEmpty();
// if (hasThinking && !hasContent) {
//// message += thinking;
// message = thinking;
// } else {
//// message += content;
// message = content;
// }
//
// responseStreamHandler.accept(message);
if (!hasContent && hasThinking && thinkingStreamHandler != null) {
// message = message + thinking;
// use only new tokens received, instead of appending the tokens to the previous
// ones and sending the full string again
thinkingStreamHandler.accept(thinking);
} else if (hasContent && responseStreamHandler != null) {
// message = message + response;
// use only new tokens received, instead of appending the tokens to the previous
// ones and sending the full string again
responseStreamHandler.accept(content);
} }
} }
} }

View File

@ -1,9 +1,9 @@
package io.github.ollama4j.models.embeddings; package io.github.ollama4j.models.embeddings;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Data;
import java.util.List; import java.util.List;
import lombok.Data;
@SuppressWarnings("unused") @SuppressWarnings("unused")
@Data @Data

View File

@ -1,7 +1,5 @@
package io.github.ollama4j.models.embeddings; package io.github.ollama4j.models.embeddings;
import static io.github.ollama4j.utils.Utils.getObjectMapper;
import java.util.Map;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import lombok.Data; import lombok.Data;
@ -9,6 +7,10 @@ import lombok.NoArgsConstructor;
import lombok.NonNull; import lombok.NonNull;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import java.util.Map;
import static io.github.ollama4j.utils.Utils.getObjectMapper;
@Data @Data
@RequiredArgsConstructor @RequiredArgsConstructor
@NoArgsConstructor @NoArgsConstructor

View File

@ -3,12 +3,11 @@ package io.github.ollama4j.models.generate;
import io.github.ollama4j.models.request.OllamaCommonRequest; import io.github.ollama4j.models.request.OllamaCommonRequest;
import io.github.ollama4j.utils.OllamaRequestBody; import io.github.ollama4j.utils.OllamaRequestBody;
import java.util.List;
import lombok.Getter; import lombok.Getter;
import lombok.Setter; import lombok.Setter;
import java.util.List;
@Getter @Getter
@Setter @Setter
public class OllamaGenerateRequest extends OllamaCommonRequest implements OllamaRequestBody{ public class OllamaGenerateRequest extends OllamaCommonRequest implements OllamaRequestBody{
@ -19,6 +18,7 @@ public class OllamaGenerateRequest extends OllamaCommonRequest implements Ollama
private String system; private String system;
private String context; private String context;
private boolean raw; private boolean raw;
private boolean think;
public OllamaGenerateRequest() { public OllamaGenerateRequest() {
} }

View File

@ -2,9 +2,9 @@ package io.github.ollama4j.models.generate;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Data;
import java.util.List; import java.util.List;
import lombok.Data;
@Data @Data
@JsonIgnoreProperties(ignoreUnknown = true) @JsonIgnoreProperties(ignoreUnknown = true)
@ -12,12 +12,14 @@ public class OllamaGenerateResponseModel {
private String model; private String model;
private @JsonProperty("created_at") String createdAt; private @JsonProperty("created_at") String createdAt;
private String response; private String response;
private String thinking;
private boolean done; private boolean done;
private @JsonProperty("done_reason") String doneReason;
private List<Integer> context; private List<Integer> context;
private @JsonProperty("total_duration") Long totalDuration; private @JsonProperty("total_duration") Long totalDuration;
private @JsonProperty("load_duration") Long loadDuration; private @JsonProperty("load_duration") Long loadDuration;
private @JsonProperty("prompt_eval_duration") Long promptEvalDuration;
private @JsonProperty("eval_duration") Long evalDuration;
private @JsonProperty("prompt_eval_count") Integer promptEvalCount; private @JsonProperty("prompt_eval_count") Integer promptEvalCount;
private @JsonProperty("prompt_eval_duration") Long promptEvalDuration;
private @JsonProperty("eval_count") Integer evalCount; private @JsonProperty("eval_count") Integer evalCount;
private @JsonProperty("eval_duration") Long evalDuration;
} }

View File

@ -5,14 +5,16 @@ import java.util.List;
public class OllamaGenerateStreamObserver { public class OllamaGenerateStreamObserver {
private OllamaStreamHandler streamHandler; private final OllamaStreamHandler thinkingStreamHandler;
private final OllamaStreamHandler responseStreamHandler;
private List<OllamaGenerateResponseModel> responseParts = new ArrayList<>(); private final List<OllamaGenerateResponseModel> responseParts = new ArrayList<>();
private String message = ""; private String message = "";
public OllamaGenerateStreamObserver(OllamaStreamHandler streamHandler) { public OllamaGenerateStreamObserver(OllamaStreamHandler thinkingStreamHandler, OllamaStreamHandler responseStreamHandler) {
this.streamHandler = streamHandler; this.responseStreamHandler = responseStreamHandler;
this.thinkingStreamHandler = thinkingStreamHandler;
} }
public void notify(OllamaGenerateResponseModel currentResponsePart) { public void notify(OllamaGenerateResponseModel currentResponsePart) {
@ -21,9 +23,24 @@ public class OllamaGenerateStreamObserver {
} }
protected void handleCurrentResponsePart(OllamaGenerateResponseModel currentResponsePart) { protected void handleCurrentResponsePart(OllamaGenerateResponseModel currentResponsePart) {
message = message + currentResponsePart.getResponse(); String response = currentResponsePart.getResponse();
streamHandler.accept(message); String thinking = currentResponsePart.getThinking();
boolean hasResponse = response != null && !response.isEmpty();
boolean hasThinking = thinking != null && !thinking.isEmpty();
if (!hasResponse && hasThinking && thinkingStreamHandler != null) {
// message = message + thinking;
// use only new tokens received, instead of appending the tokens to the previous
// ones and sending the full string again
thinkingStreamHandler.accept(thinking);
} else if (hasResponse && responseStreamHandler != null) {
// message = message + response;
// use only new tokens received, instead of appending the tokens to the previous
// ones and sending the full string again
responseStreamHandler.accept(response);
}
} }
} }

View File

@ -1,13 +1,14 @@
package io.github.ollama4j.models.request; package io.github.ollama4j.models.request;
import java.util.Base64;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Data; import lombok.Data;
import lombok.NoArgsConstructor; import lombok.EqualsAndHashCode;
import java.util.Base64;
@Data @Data
@AllArgsConstructor @AllArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class BasicAuth extends Auth { public class BasicAuth extends Auth {
private String username; private String username;
private String password; private String password;

View File

@ -2,18 +2,20 @@ package io.github.ollama4j.models.request;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
@Data @Data
@AllArgsConstructor @AllArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class BearerAuth extends Auth { public class BearerAuth extends Auth {
private String bearerToken; private String bearerToken;
/** /**
* Get authentication header value. * Get authentication header value.
* *
* @return authentication header value with bearer token * @return authentication header value with bearer token
*/ */
public String getAuthHeaderValue() { public String getAuthHeaderValue() {
return "Bearer "+ bearerToken; return "Bearer " + bearerToken;
} }
} }

View File

@ -1,11 +1,11 @@
package io.github.ollama4j.models.request; package io.github.ollama4j.models.request;
import static io.github.ollama4j.utils.Utils.getObjectMapper;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Data; import lombok.Data;
import static io.github.ollama4j.utils.Utils.getObjectMapper;
@Data @Data
@AllArgsConstructor @AllArgsConstructor
public class CustomModelFileContentsRequest { public class CustomModelFileContentsRequest {

View File

@ -1,11 +1,11 @@
package io.github.ollama4j.models.request; package io.github.ollama4j.models.request;
import static io.github.ollama4j.utils.Utils.getObjectMapper;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Data; import lombok.Data;
import static io.github.ollama4j.utils.Utils.getObjectMapper;
@Data @Data
@AllArgsConstructor @AllArgsConstructor
public class CustomModelFilePathRequest { public class CustomModelFilePathRequest {

View File

@ -1,17 +1,15 @@
package io.github.ollama4j.models.request; package io.github.ollama4j.models.request;
import static io.github.ollama4j.utils.Utils.getObjectMapper;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.Data;
import lombok.AllArgsConstructor;
import lombok.Builder; import lombok.Builder;
import lombok.Data;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import static io.github.ollama4j.utils.Utils.getObjectMapper;
@Data @Data
@AllArgsConstructor @AllArgsConstructor

View File

@ -1,11 +1,11 @@
package io.github.ollama4j.models.request; package io.github.ollama4j.models.request;
import static io.github.ollama4j.utils.Utils.getObjectMapper;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Data; import lombok.Data;
import static io.github.ollama4j.utils.Utils.getObjectMapper;
@Data @Data
@AllArgsConstructor @AllArgsConstructor
public class ModelRequest { public class ModelRequest {

View File

@ -24,6 +24,7 @@ import java.util.List;
/** /**
* Specialization class for requests * Specialization class for requests
*/ */
@SuppressWarnings("resource")
public class OllamaChatEndpointCaller extends OllamaEndpointCaller { public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
private static final Logger LOG = LoggerFactory.getLogger(OllamaChatEndpointCaller.class); private static final Logger LOG = LoggerFactory.getLogger(OllamaChatEndpointCaller.class);
@ -46,19 +47,24 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
* in case the JSON Object cannot be parsed to a {@link OllamaChatResponseModel}. Thus, the ResponseModel should * in case the JSON Object cannot be parsed to a {@link OllamaChatResponseModel}. Thus, the ResponseModel should
* never be null. * never be null.
* *
* @param line streamed line of ollama stream response * @param line streamed line of ollama stream response
* @param responseBuffer Stringbuffer to add latest response message part to * @param responseBuffer Stringbuffer to add latest response message part to
* @return TRUE, if ollama-Response has 'done' state * @return TRUE, if ollama-Response has 'done' state
*/ */
@Override @Override
protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) { protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer, StringBuilder thinkingBuffer) {
try { try {
OllamaChatResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class); OllamaChatResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class);
// it seems that under heavy load ollama responds with an empty chat message part in the streamed response // it seems that under heavy load ollama responds with an empty chat message part in the streamed response
// thus, we null check the message and hope that the next streamed response has some message content again // thus, we null check the message and hope that the next streamed response has some message content again
OllamaChatMessage message = ollamaResponseModel.getMessage(); OllamaChatMessage message = ollamaResponseModel.getMessage();
if(message != null) { if (message != null) {
responseBuffer.append(message.getContent()); if (message.getThinking() != null) {
thinkingBuffer.append(message.getThinking());
}
else {
responseBuffer.append(message.getContent());
}
if (tokenHandler != null) { if (tokenHandler != null) {
tokenHandler.accept(ollamaResponseModel); tokenHandler.accept(ollamaResponseModel);
} }
@ -85,13 +91,14 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
.POST( .POST(
body.getBodyPublisher()); body.getBodyPublisher());
HttpRequest request = requestBuilder.build(); HttpRequest request = requestBuilder.build();
if (isVerbose()) LOG.info("Asking model: " + body); if (isVerbose()) LOG.info("Asking model: {}", body);
HttpResponse<InputStream> response = HttpResponse<InputStream> response =
httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream()); httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
int statusCode = response.statusCode(); int statusCode = response.statusCode();
InputStream responseBodyStream = response.body(); InputStream responseBodyStream = response.body();
StringBuilder responseBuffer = new StringBuilder(); StringBuilder responseBuffer = new StringBuilder();
StringBuilder thinkingBuffer = new StringBuilder();
OllamaChatResponseModel ollamaChatResponseModel = null; OllamaChatResponseModel ollamaChatResponseModel = null;
List<OllamaChatToolCalls> wantedToolsForStream = null; List<OllamaChatToolCalls> wantedToolsForStream = null;
try (BufferedReader reader = try (BufferedReader reader =
@ -115,14 +122,20 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line,
OllamaErrorResponse.class); OllamaErrorResponse.class);
responseBuffer.append(ollamaResponseModel.getError()); responseBuffer.append(ollamaResponseModel.getError());
} else if (statusCode == 500) {
LOG.warn("Status code: 500 (Internal Server Error)");
OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line,
OllamaErrorResponse.class);
responseBuffer.append(ollamaResponseModel.getError());
} else { } else {
boolean finished = parseResponseAndAddToBuffer(line, responseBuffer); boolean finished = parseResponseAndAddToBuffer(line, responseBuffer, thinkingBuffer);
ollamaChatResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class); ollamaChatResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class);
if(body.stream && ollamaChatResponseModel.getMessage().getToolCalls() != null){ if (body.stream && ollamaChatResponseModel.getMessage().getToolCalls() != null) {
wantedToolsForStream = ollamaChatResponseModel.getMessage().getToolCalls(); wantedToolsForStream = ollamaChatResponseModel.getMessage().getToolCalls();
} }
if (finished && body.stream) { if (finished && body.stream) {
ollamaChatResponseModel.getMessage().setContent(responseBuffer.toString()); ollamaChatResponseModel.getMessage().setContent(responseBuffer.toString());
ollamaChatResponseModel.getMessage().setThinking(thinkingBuffer.toString());
break; break;
} }
} }
@ -132,11 +145,11 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
LOG.error("Status code " + statusCode); LOG.error("Status code " + statusCode);
throw new OllamaBaseException(responseBuffer.toString()); throw new OllamaBaseException(responseBuffer.toString());
} else { } else {
if(wantedToolsForStream != null) { if (wantedToolsForStream != null) {
ollamaChatResponseModel.getMessage().setToolCalls(wantedToolsForStream); ollamaChatResponseModel.getMessage().setToolCalls(wantedToolsForStream);
} }
OllamaChatResult ollamaResult = OllamaChatResult ollamaResult =
new OllamaChatResult(ollamaChatResponseModel,body.getMessages()); new OllamaChatResult(ollamaChatResponseModel, body.getMessages());
if (isVerbose()) LOG.info("Model response: " + ollamaResult); if (isVerbose()) LOG.info("Model response: " + ollamaResult);
return ollamaResult; return ollamaResult;
} }

View File

@ -1,15 +1,15 @@
package io.github.ollama4j.models.request; package io.github.ollama4j.models.request;
import java.util.Map;
import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.annotation.JsonSerialize; import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import io.github.ollama4j.utils.BooleanToJsonFormatFlagSerializer; import io.github.ollama4j.utils.BooleanToJsonFormatFlagSerializer;
import io.github.ollama4j.utils.Utils; import io.github.ollama4j.utils.Utils;
import lombok.Data; import lombok.Data;
import java.util.Map;
@Data @Data
@JsonInclude(JsonInclude.Include.NON_NULL) @JsonInclude(JsonInclude.Include.NON_NULL)
public abstract class OllamaCommonRequest { public abstract class OllamaCommonRequest {

View File

@ -1,15 +1,15 @@
package io.github.ollama4j.models.request; package io.github.ollama4j.models.request;
import io.github.ollama4j.OllamaAPI;
import io.github.ollama4j.utils.Constants;
import lombok.Getter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.net.URI; import java.net.URI;
import java.net.http.HttpRequest; import java.net.http.HttpRequest;
import java.time.Duration; import java.time.Duration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import io.github.ollama4j.OllamaAPI;
import lombok.Getter;
/** /**
* Abstract helperclass to call the ollama api server. * Abstract helperclass to call the ollama api server.
*/ */
@ -32,7 +32,7 @@ public abstract class OllamaEndpointCaller {
protected abstract String getEndpointSuffix(); protected abstract String getEndpointSuffix();
protected abstract boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer); protected abstract boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer, StringBuilder thinkingBuffer);
/** /**
@ -44,7 +44,7 @@ public abstract class OllamaEndpointCaller {
protected HttpRequest.Builder getRequestBuilderDefault(URI uri) { protected HttpRequest.Builder getRequestBuilderDefault(URI uri) {
HttpRequest.Builder requestBuilder = HttpRequest.Builder requestBuilder =
HttpRequest.newBuilder(uri) HttpRequest.newBuilder(uri)
.header("Content-Type", "application/json") .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON)
.timeout(Duration.ofSeconds(this.requestTimeoutSeconds)); .timeout(Duration.ofSeconds(this.requestTimeoutSeconds));
if (isAuthCredentialsSet()) { if (isAuthCredentialsSet()) {
requestBuilder.header("Authorization", this.auth.getAuthHeaderValue()); requestBuilder.header("Authorization", this.auth.getAuthHeaderValue());

View File

@ -2,11 +2,11 @@ package io.github.ollama4j.models.request;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import io.github.ollama4j.exceptions.OllamaBaseException; import io.github.ollama4j.exceptions.OllamaBaseException;
import io.github.ollama4j.models.response.OllamaErrorResponse;
import io.github.ollama4j.models.response.OllamaResult;
import io.github.ollama4j.models.generate.OllamaGenerateResponseModel; import io.github.ollama4j.models.generate.OllamaGenerateResponseModel;
import io.github.ollama4j.models.generate.OllamaGenerateStreamObserver; import io.github.ollama4j.models.generate.OllamaGenerateStreamObserver;
import io.github.ollama4j.models.generate.OllamaStreamHandler; import io.github.ollama4j.models.generate.OllamaStreamHandler;
import io.github.ollama4j.models.response.OllamaErrorResponse;
import io.github.ollama4j.models.response.OllamaResult;
import io.github.ollama4j.utils.OllamaRequestBody; import io.github.ollama4j.utils.OllamaRequestBody;
import io.github.ollama4j.utils.Utils; import io.github.ollama4j.utils.Utils;
import org.slf4j.Logger; import org.slf4j.Logger;
@ -22,11 +22,12 @@ import java.net.http.HttpRequest;
import java.net.http.HttpResponse; import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
@SuppressWarnings("resource")
public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller { public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
private static final Logger LOG = LoggerFactory.getLogger(OllamaGenerateEndpointCaller.class); private static final Logger LOG = LoggerFactory.getLogger(OllamaGenerateEndpointCaller.class);
private OllamaGenerateStreamObserver streamObserver; private OllamaGenerateStreamObserver responseStreamObserver;
public OllamaGenerateEndpointCaller(String host, Auth basicAuth, long requestTimeoutSeconds, boolean verbose) { public OllamaGenerateEndpointCaller(String host, Auth basicAuth, long requestTimeoutSeconds, boolean verbose) {
super(host, basicAuth, requestTimeoutSeconds, verbose); super(host, basicAuth, requestTimeoutSeconds, verbose);
@ -38,12 +39,17 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
} }
@Override @Override
protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) { protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer, StringBuilder thinkingBuffer) {
try { try {
OllamaGenerateResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class); OllamaGenerateResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class);
responseBuffer.append(ollamaResponseModel.getResponse()); if (ollamaResponseModel.getResponse() != null) {
if (streamObserver != null) { responseBuffer.append(ollamaResponseModel.getResponse());
streamObserver.notify(ollamaResponseModel); }
if (ollamaResponseModel.getThinking() != null) {
thinkingBuffer.append(ollamaResponseModel.getThinking());
}
if (responseStreamObserver != null) {
responseStreamObserver.notify(ollamaResponseModel);
} }
return ollamaResponseModel.isDone(); return ollamaResponseModel.isDone();
} catch (JsonProcessingException e) { } catch (JsonProcessingException e) {
@ -52,9 +58,8 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
} }
} }
public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler) public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler thinkingStreamHandler, OllamaStreamHandler responseStreamHandler) throws OllamaBaseException, IOException, InterruptedException {
throws OllamaBaseException, IOException, InterruptedException { responseStreamObserver = new OllamaGenerateStreamObserver(thinkingStreamHandler, responseStreamHandler);
streamObserver = new OllamaGenerateStreamObserver(streamHandler);
return callSync(body); return callSync(body);
} }
@ -67,46 +72,41 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
* @throws IOException in case the responseStream can not be read * @throws IOException in case the responseStream can not be read
* @throws InterruptedException in case the server is not reachable or network issues happen * @throws InterruptedException in case the server is not reachable or network issues happen
*/ */
@SuppressWarnings("DuplicatedCode")
public OllamaResult callSync(OllamaRequestBody body) throws OllamaBaseException, IOException, InterruptedException { public OllamaResult callSync(OllamaRequestBody body) throws OllamaBaseException, IOException, InterruptedException {
// Create Request // Create Request
long startTime = System.currentTimeMillis(); long startTime = System.currentTimeMillis();
HttpClient httpClient = HttpClient.newHttpClient(); HttpClient httpClient = HttpClient.newHttpClient();
URI uri = URI.create(getHost() + getEndpointSuffix()); URI uri = URI.create(getHost() + getEndpointSuffix());
HttpRequest.Builder requestBuilder = HttpRequest.Builder requestBuilder = getRequestBuilderDefault(uri).POST(body.getBodyPublisher());
getRequestBuilderDefault(uri)
.POST(
body.getBodyPublisher());
HttpRequest request = requestBuilder.build(); HttpRequest request = requestBuilder.build();
if (isVerbose()) LOG.info("Asking model: " + body.toString()); if (isVerbose()) LOG.info("Asking model: {}", body);
HttpResponse<InputStream> response = HttpResponse<InputStream> response = httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
int statusCode = response.statusCode(); int statusCode = response.statusCode();
InputStream responseBodyStream = response.body(); InputStream responseBodyStream = response.body();
StringBuilder responseBuffer = new StringBuilder(); StringBuilder responseBuffer = new StringBuilder();
try (BufferedReader reader = StringBuilder thinkingBuffer = new StringBuilder();
new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { OllamaGenerateResponseModel ollamaGenerateResponseModel = null;
try (BufferedReader reader = new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
String line; String line;
while ((line = reader.readLine()) != null) { while ((line = reader.readLine()) != null) {
if (statusCode == 404) { if (statusCode == 404) {
LOG.warn("Status code: 404 (Not Found)"); LOG.warn("Status code: 404 (Not Found)");
OllamaErrorResponse ollamaResponseModel = OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class);
Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class);
responseBuffer.append(ollamaResponseModel.getError()); responseBuffer.append(ollamaResponseModel.getError());
} else if (statusCode == 401) { } else if (statusCode == 401) {
LOG.warn("Status code: 401 (Unauthorized)"); LOG.warn("Status code: 401 (Unauthorized)");
OllamaErrorResponse ollamaResponseModel = OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponse.class);
Utils.getObjectMapper()
.readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponse.class);
responseBuffer.append(ollamaResponseModel.getError()); responseBuffer.append(ollamaResponseModel.getError());
} else if (statusCode == 400) { } else if (statusCode == 400) {
LOG.warn("Status code: 400 (Bad Request)"); LOG.warn("Status code: 400 (Bad Request)");
OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class);
OllamaErrorResponse.class);
responseBuffer.append(ollamaResponseModel.getError()); responseBuffer.append(ollamaResponseModel.getError());
} else { } else {
boolean finished = parseResponseAndAddToBuffer(line, responseBuffer); boolean finished = parseResponseAndAddToBuffer(line, responseBuffer, thinkingBuffer);
if (finished) { if (finished) {
ollamaGenerateResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class);
break; break;
} }
} }
@ -114,13 +114,25 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
} }
if (statusCode != 200) { if (statusCode != 200) {
LOG.error("Status code " + statusCode); LOG.error("Status code: {}", statusCode);
throw new OllamaBaseException(responseBuffer.toString()); throw new OllamaBaseException(responseBuffer.toString());
} else { } else {
long endTime = System.currentTimeMillis(); long endTime = System.currentTimeMillis();
OllamaResult ollamaResult = OllamaResult ollamaResult = new OllamaResult(responseBuffer.toString(), thinkingBuffer.toString(), endTime - startTime, statusCode);
new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode);
if (isVerbose()) LOG.info("Model response: " + ollamaResult); ollamaResult.setModel(ollamaGenerateResponseModel.getModel());
ollamaResult.setCreatedAt(ollamaGenerateResponseModel.getCreatedAt());
ollamaResult.setDone(ollamaGenerateResponseModel.isDone());
ollamaResult.setDoneReason(ollamaGenerateResponseModel.getDoneReason());
ollamaResult.setContext(ollamaGenerateResponseModel.getContext());
ollamaResult.setTotalDuration(ollamaGenerateResponseModel.getTotalDuration());
ollamaResult.setLoadDuration(ollamaGenerateResponseModel.getLoadDuration());
ollamaResult.setPromptEvalCount(ollamaGenerateResponseModel.getPromptEvalCount());
ollamaResult.setPromptEvalDuration(ollamaGenerateResponseModel.getPromptEvalDuration());
ollamaResult.setEvalCount(ollamaGenerateResponseModel.getEvalCount());
ollamaResult.setEvalDuration(ollamaGenerateResponseModel.getEvalDuration());
if (isVerbose()) LOG.info("Model response: {}", ollamaResult);
return ollamaResult; return ollamaResult;
} }
} }

View File

@ -1,9 +1,10 @@
package io.github.ollama4j.models.response; package io.github.ollama4j.models.response;
import java.util.ArrayList;
import java.util.List;
import lombok.Data; import lombok.Data;
import java.util.ArrayList;
import java.util.List;
@Data @Data
public class LibraryModel { public class LibraryModel {

View File

@ -2,8 +2,6 @@ package io.github.ollama4j.models.response;
import lombok.Data; import lombok.Data;
import java.util.List;
@Data @Data
public class LibraryModelTag { public class LibraryModelTag {
private String name; private String name;

View File

@ -1,9 +1,9 @@
package io.github.ollama4j.models.response; package io.github.ollama4j.models.response;
import java.util.List;
import lombok.Data; import lombok.Data;
import java.util.List;
@Data @Data
public class ListModelsResponse { public class ListModelsResponse {
private List<Model> models; private List<Model> models;

View File

@ -1,13 +1,13 @@
package io.github.ollama4j.models.response; package io.github.ollama4j.models.response;
import java.time.OffsetDateTime;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import io.github.ollama4j.utils.Utils; import io.github.ollama4j.utils.Utils;
import lombok.Data; import lombok.Data;
import java.time.OffsetDateTime;
@Data @Data
@JsonIgnoreProperties(ignoreUnknown = true) @JsonIgnoreProperties(ignoreUnknown = true)
public class Model { public class Model {

View File

@ -3,6 +3,7 @@ package io.github.ollama4j.models.response;
import io.github.ollama4j.exceptions.OllamaBaseException; import io.github.ollama4j.exceptions.OllamaBaseException;
import io.github.ollama4j.models.generate.OllamaGenerateRequest; import io.github.ollama4j.models.generate.OllamaGenerateRequest;
import io.github.ollama4j.models.generate.OllamaGenerateResponseModel; import io.github.ollama4j.models.generate.OllamaGenerateResponseModel;
import io.github.ollama4j.utils.Constants;
import io.github.ollama4j.utils.Utils; import io.github.ollama4j.utils.Utils;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
@ -25,8 +26,10 @@ import java.time.Duration;
public class OllamaAsyncResultStreamer extends Thread { public class OllamaAsyncResultStreamer extends Thread {
private final HttpRequest.Builder requestBuilder; private final HttpRequest.Builder requestBuilder;
private final OllamaGenerateRequest ollamaRequestModel; private final OllamaGenerateRequest ollamaRequestModel;
private final OllamaResultStream stream = new OllamaResultStream(); private final OllamaResultStream thinkingResponseStream = new OllamaResultStream();
private final OllamaResultStream responseStream = new OllamaResultStream();
private String completeResponse; private String completeResponse;
private String completeThinkingResponse;
/** /**
@ -53,14 +56,11 @@ public class OllamaAsyncResultStreamer extends Thread {
@Getter @Getter
private long responseTime = 0; private long responseTime = 0;
public OllamaAsyncResultStreamer( public OllamaAsyncResultStreamer(HttpRequest.Builder requestBuilder, OllamaGenerateRequest ollamaRequestModel, long requestTimeoutSeconds) {
HttpRequest.Builder requestBuilder,
OllamaGenerateRequest ollamaRequestModel,
long requestTimeoutSeconds) {
this.requestBuilder = requestBuilder; this.requestBuilder = requestBuilder;
this.ollamaRequestModel = ollamaRequestModel; this.ollamaRequestModel = ollamaRequestModel;
this.completeResponse = ""; this.completeResponse = "";
this.stream.add(""); this.responseStream.add("");
this.requestTimeoutSeconds = requestTimeoutSeconds; this.requestTimeoutSeconds = requestTimeoutSeconds;
} }
@ -68,47 +68,63 @@ public class OllamaAsyncResultStreamer extends Thread {
public void run() { public void run() {
ollamaRequestModel.setStream(true); ollamaRequestModel.setStream(true);
HttpClient httpClient = HttpClient.newHttpClient(); HttpClient httpClient = HttpClient.newHttpClient();
long startTime = System.currentTimeMillis();
try { try {
long startTime = System.currentTimeMillis(); HttpRequest request = requestBuilder.POST(HttpRequest.BodyPublishers.ofString(Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))).header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON).timeout(Duration.ofSeconds(requestTimeoutSeconds)).build();
HttpRequest request = HttpResponse<InputStream> response = httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
requestBuilder
.POST(
HttpRequest.BodyPublishers.ofString(
Utils.getObjectMapper().writeValueAsString(ollamaRequestModel)))
.header("Content-Type", "application/json")
.timeout(Duration.ofSeconds(requestTimeoutSeconds))
.build();
HttpResponse<InputStream> response =
httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
int statusCode = response.statusCode(); int statusCode = response.statusCode();
this.httpStatusCode = statusCode; this.httpStatusCode = statusCode;
InputStream responseBodyStream = response.body(); InputStream responseBodyStream = response.body();
try (BufferedReader reader = BufferedReader reader = null;
new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { try {
reader = new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8));
String line; String line;
StringBuilder thinkingBuffer = new StringBuilder();
StringBuilder responseBuffer = new StringBuilder(); StringBuilder responseBuffer = new StringBuilder();
while ((line = reader.readLine()) != null) { while ((line = reader.readLine()) != null) {
if (statusCode == 404) { if (statusCode == 404) {
OllamaErrorResponse ollamaResponseModel = OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class);
Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class); responseStream.add(ollamaResponseModel.getError());
stream.add(ollamaResponseModel.getError());
responseBuffer.append(ollamaResponseModel.getError()); responseBuffer.append(ollamaResponseModel.getError());
} else { } else {
OllamaGenerateResponseModel ollamaResponseModel = OllamaGenerateResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class);
Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class); String thinkingTokens = ollamaResponseModel.getThinking();
String res = ollamaResponseModel.getResponse(); String responseTokens = ollamaResponseModel.getResponse();
stream.add(res); if (thinkingTokens == null) {
thinkingTokens = "";
}
if (responseTokens == null) {
responseTokens = "";
}
thinkingResponseStream.add(thinkingTokens);
responseStream.add(responseTokens);
if (!ollamaResponseModel.isDone()) { if (!ollamaResponseModel.isDone()) {
responseBuffer.append(res); responseBuffer.append(responseTokens);
thinkingBuffer.append(thinkingTokens);
} }
} }
} }
this.succeeded = true; this.succeeded = true;
this.completeThinkingResponse = thinkingBuffer.toString();
this.completeResponse = responseBuffer.toString(); this.completeResponse = responseBuffer.toString();
long endTime = System.currentTimeMillis(); long endTime = System.currentTimeMillis();
responseTime = endTime - startTime; responseTime = endTime - startTime;
} finally {
if (reader != null) {
try {
reader.close();
} catch (IOException e) {
// Optionally log or handle
}
}
if (responseBodyStream != null) {
try {
responseBodyStream.close();
} catch (IOException e) {
// Optionally log or handle
}
}
} }
if (statusCode != 200) { if (statusCode != 200) {
throw new OllamaBaseException(this.completeResponse); throw new OllamaBaseException(this.completeResponse);

View File

@ -3,116 +3,136 @@ package io.github.ollama4j.models.response;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.core.type.TypeReference;
import lombok.Data; import lombok.Data;
import lombok.Getter; import lombok.Getter;
import static io.github.ollama4j.utils.Utils.getObjectMapper;
import java.util.HashMap; import java.util.HashMap;
import java.util.List;
import java.util.Map; import java.util.Map;
/** The type Ollama result. */ import static io.github.ollama4j.utils.Utils.getObjectMapper;
/**
* The type Ollama result.
*/
@Getter @Getter
@SuppressWarnings("unused") @SuppressWarnings("unused")
@Data @Data
@JsonIgnoreProperties(ignoreUnknown = true) @JsonIgnoreProperties(ignoreUnknown = true)
public class OllamaResult { public class OllamaResult {
/** /**
* -- GETTER -- * Get the completion/response text
* Get the completion/response text */
* private final String response;
* @return String completion/response text /**
*/ * Get the thinking text (if available)
private final String response; */
private final String thinking;
/**
* Get the response status code.
*/
private int httpStatusCode;
/**
* Get the response time in milliseconds.
*/
private long responseTime = 0;
/** private String model;
* -- GETTER -- private String createdAt;
* Get the response status code. private boolean done;
* private String doneReason;
* @return int - response status code private List<Integer> context;
*/ private Long totalDuration;
private int httpStatusCode; private Long loadDuration;
private Integer promptEvalCount;
private Long promptEvalDuration;
private Integer evalCount;
private Long evalDuration;
/** public OllamaResult(String response, String thinking, long responseTime, int httpStatusCode) {
* -- GETTER -- this.response = response;
* Get the response time in milliseconds. this.thinking = thinking;
* this.responseTime = responseTime;
* @return long - response time in milliseconds this.httpStatusCode = httpStatusCode;
*/
private long responseTime = 0;
public OllamaResult(String response, long responseTime, int httpStatusCode) {
this.response = response;
this.responseTime = responseTime;
this.httpStatusCode = httpStatusCode;
}
@Override
public String toString() {
try {
Map<String, Object> responseMap = new HashMap<>();
responseMap.put("response", this.response);
responseMap.put("httpStatusCode", this.httpStatusCode);
responseMap.put("responseTime", this.responseTime);
return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(responseMap);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
/**
* Get the structured response if the response is a JSON object.
*
* @return Map - structured response
* @throws IllegalArgumentException if the response is not a valid JSON object
*/
public Map<String, Object> getStructuredResponse() {
String responseStr = this.getResponse();
if (responseStr == null || responseStr.trim().isEmpty()) {
throw new IllegalArgumentException("Response is empty or null");
} }
try { @Override
// Check if the response is a valid JSON public String toString() {
if ((!responseStr.trim().startsWith("{") && !responseStr.trim().startsWith("[")) || try {
(!responseStr.trim().endsWith("}") && !responseStr.trim().endsWith("]"))) { Map<String, Object> responseMap = new HashMap<>();
throw new IllegalArgumentException("Response is not a valid JSON object"); responseMap.put("response", this.response);
} responseMap.put("thinking", this.thinking);
responseMap.put("httpStatusCode", this.httpStatusCode);
Map<String, Object> response = getObjectMapper().readValue(responseStr, responseMap.put("responseTime", this.responseTime);
new TypeReference<Map<String, Object>>() { responseMap.put("model", this.model);
}); responseMap.put("createdAt", this.createdAt);
return response; responseMap.put("done", this.done);
} catch (JsonProcessingException e) { responseMap.put("doneReason", this.doneReason);
throw new IllegalArgumentException("Failed to parse response as JSON: " + e.getMessage(), e); responseMap.put("context", this.context);
} responseMap.put("totalDuration", this.totalDuration);
} responseMap.put("loadDuration", this.loadDuration);
responseMap.put("promptEvalCount", this.promptEvalCount);
/** responseMap.put("promptEvalDuration", this.promptEvalDuration);
* Get the structured response mapped to a specific class type. responseMap.put("evalCount", this.evalCount);
* responseMap.put("evalDuration", this.evalDuration);
* @param <T> The type of class to map the response to return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(responseMap);
* @param clazz The class to map the response to } catch (JsonProcessingException e) {
* @return An instance of the specified class with the response data throw new RuntimeException(e);
* @throws IllegalArgumentException if the response is not a valid JSON or is empty }
* @throws RuntimeException if there is an error mapping the response
*/
public <T> T as(Class<T> clazz) {
String responseStr = this.getResponse();
if (responseStr == null || responseStr.trim().isEmpty()) {
throw new IllegalArgumentException("Response is empty or null");
} }
try { /**
// Check if the response is a valid JSON * Get the structured response if the response is a JSON object.
if ((!responseStr.trim().startsWith("{") && !responseStr.trim().startsWith("[")) || *
(!responseStr.trim().endsWith("}") && !responseStr.trim().endsWith("]"))) { * @return Map - structured response
throw new IllegalArgumentException("Response is not a valid JSON object"); * @throws IllegalArgumentException if the response is not a valid JSON object
} */
return getObjectMapper().readValue(responseStr, clazz); public Map<String, Object> getStructuredResponse() {
} catch (JsonProcessingException e) { String responseStr = this.getResponse();
throw new IllegalArgumentException("Failed to parse response as JSON: " + e.getMessage(), e); if (responseStr == null || responseStr.trim().isEmpty()) {
throw new IllegalArgumentException("Response is empty or null");
}
try {
// Check if the response is a valid JSON
if ((!responseStr.trim().startsWith("{") && !responseStr.trim().startsWith("[")) ||
(!responseStr.trim().endsWith("}") && !responseStr.trim().endsWith("]"))) {
throw new IllegalArgumentException("Response is not a valid JSON object");
}
Map<String, Object> response = getObjectMapper().readValue(responseStr,
new TypeReference<Map<String, Object>>() {
});
return response;
} catch (JsonProcessingException e) {
throw new IllegalArgumentException("Failed to parse response as JSON: " + e.getMessage(), e);
}
}
/**
* Get the structured response mapped to a specific class type.
*
* @param <T> The type of class to map the response to
* @param clazz The class to map the response to
* @return An instance of the specified class with the response data
* @throws IllegalArgumentException if the response is not a valid JSON or is empty
* @throws RuntimeException if there is an error mapping the response
*/
public <T> T as(Class<T> clazz) {
String responseStr = this.getResponse();
if (responseStr == null || responseStr.trim().isEmpty()) {
throw new IllegalArgumentException("Response is empty or null");
}
try {
// Check if the response is a valid JSON
if ((!responseStr.trim().startsWith("{") && !responseStr.trim().startsWith("[")) ||
(!responseStr.trim().endsWith("}") && !responseStr.trim().endsWith("]"))) {
throw new IllegalArgumentException("Response is not a valid JSON object");
}
return getObjectMapper().readValue(responseStr, clazz);
} catch (JsonProcessingException e) {
throw new IllegalArgumentException("Failed to parse response as JSON: " + e.getMessage(), e);
}
} }
}
} }

View File

@ -1,19 +1,18 @@
package io.github.ollama4j.models.response; package io.github.ollama4j.models.response;
import static io.github.ollama4j.utils.Utils.getObjectMapper;
import java.util.Map;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.core.type.TypeReference;
import lombok.Data; import lombok.Data;
import lombok.Getter; import lombok.Getter;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import java.util.List;
import java.util.Map;
import static io.github.ollama4j.utils.Utils.getObjectMapper;
@Getter @Getter
@SuppressWarnings("unused") @SuppressWarnings("unused")
@Data @Data
@ -21,13 +20,22 @@ import lombok.NoArgsConstructor;
@JsonIgnoreProperties(ignoreUnknown = true) @JsonIgnoreProperties(ignoreUnknown = true)
public class OllamaStructuredResult { public class OllamaStructuredResult {
private String response; private String response;
private String thinking;
private int httpStatusCode; private int httpStatusCode;
private long responseTime = 0; private long responseTime = 0;
private String model; private String model;
private @JsonProperty("created_at") String createdAt;
private boolean done;
private @JsonProperty("done_reason") String doneReason;
private List<Integer> context;
private @JsonProperty("total_duration") Long totalDuration;
private @JsonProperty("load_duration") Long loadDuration;
private @JsonProperty("prompt_eval_count") Integer promptEvalCount;
private @JsonProperty("prompt_eval_duration") Long promptEvalDuration;
private @JsonProperty("eval_count") Integer evalCount;
private @JsonProperty("eval_duration") Long evalDuration;
public OllamaStructuredResult(String response, long responseTime, int httpStatusCode) { public OllamaStructuredResult(String response, long responseTime, int httpStatusCode) {
this.response = response; this.response = response;
this.responseTime = responseTime; this.responseTime = responseTime;

View File

@ -2,8 +2,6 @@ package io.github.ollama4j.models.response;
import lombok.Data; import lombok.Data;
import java.util.List;
@Data @Data
public class OllamaVersion { public class OllamaVersion {
private String version; private String version;

View File

@ -9,14 +9,21 @@ public class ToolRegistry {
public ToolFunction getToolFunction(String name) { public ToolFunction getToolFunction(String name) {
final Tools.ToolSpecification toolSpecification = tools.get(name); final Tools.ToolSpecification toolSpecification = tools.get(name);
return toolSpecification !=null ? toolSpecification.getToolFunction() : null ; return toolSpecification != null ? toolSpecification.getToolFunction() : null;
} }
public void addTool (String name, Tools.ToolSpecification specification) { public void addTool(String name, Tools.ToolSpecification specification) {
tools.put(name, specification); tools.put(name, specification);
} }
public Collection<Tools.ToolSpecification> getRegisteredSpecs(){ public Collection<Tools.ToolSpecification> getRegisteredSpecs() {
return tools.values(); return tools.values();
} }
/**
* Removes all registered tools from the registry.
*/
public void clear() {
tools.clear();
}
} }

View File

@ -1,88 +1,54 @@
package io.github.ollama4j.tools.sampletools; package io.github.ollama4j.tools.sampletools;
import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.util.Map;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.github.ollama4j.tools.Tools; import io.github.ollama4j.tools.Tools;
import java.util.Map;
@SuppressWarnings("resource")
public class WeatherTool { public class WeatherTool {
private String openWeatherMapAPIKey = null; private String paramCityName = "cityName";
public WeatherTool(String openWeatherMapAPIKey) { public WeatherTool() {
this.openWeatherMapAPIKey = openWeatherMapAPIKey; }
}
public String getCurrentWeather(Map<String, Object> arguments) { public String getCurrentWeather(Map<String, Object> arguments) {
String city = (String) arguments.get("cityName"); String city = (String) arguments.get(paramCityName);
System.out.println("Finding weather for city: " + city); return "It is sunny in " + city;
}
String url = String.format("https://api.openweathermap.org/data/2.5/weather?q=%s&appid=%s&units=metric", public Tools.ToolSpecification getSpecification() {
city, return Tools.ToolSpecification.builder()
this.openWeatherMapAPIKey); .functionName("weather-reporter")
.functionDescription(
HttpClient client = HttpClient.newHttpClient(); "You are a tool who simply finds the city name from the user's message input/query about weather.")
HttpRequest request = HttpRequest.newBuilder() .toolFunction(this::getCurrentWeather)
.uri(URI.create(url)) .toolPrompt(
.build(); Tools.PromptFuncDefinition.builder()
try { .type("prompt")
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString()); .function(
if (response.statusCode() == 200) { Tools.PromptFuncDefinition.PromptFuncSpec
ObjectMapper mapper = new ObjectMapper(); .builder()
JsonNode root = mapper.readTree(response.body()); .name("get-city-name")
JsonNode main = root.path("main"); .description("Get the city name")
double temperature = main.path("temp").asDouble(); .parameters(
String description = root.path("weather").get(0).path("description").asText(); Tools.PromptFuncDefinition.Parameters
return String.format("Weather in %s: %.1f°C, %s", city, temperature, description); .builder()
} else { .type("object")
return "Could not retrieve weather data for " + city + ". Status code: " .properties(
+ response.statusCode(); Map.of(
} paramCityName,
} catch (IOException | InterruptedException e) { Tools.PromptFuncDefinition.Property
e.printStackTrace(); .builder()
return "Error retrieving weather data: " + e.getMessage(); .type("string")
} .description(
} "The name of the city. e.g. Bengaluru")
.required(true)
public Tools.ToolSpecification getSpecification() { .build()))
return Tools.ToolSpecification.builder() .required(java.util.List
.functionName("weather-reporter") .of(paramCityName))
.functionDescription(
"You are a tool who simply finds the city name from the user's message input/query about weather.")
.toolFunction(this::getCurrentWeather)
.toolPrompt(
Tools.PromptFuncDefinition.builder()
.type("prompt")
.function(
Tools.PromptFuncDefinition.PromptFuncSpec
.builder()
.name("get-city-name")
.description("Get the city name")
.parameters(
Tools.PromptFuncDefinition.Parameters
.builder()
.type("object")
.properties(
Map.of(
"cityName",
Tools.PromptFuncDefinition.Property
.builder()
.type("string")
.description(
"The name of the city. e.g. Bengaluru")
.required(true)
.build()))
.required(java.util.List
.of("cityName"))
.build())
.build())
.build()) .build())
.build(); .build())
} .build())
.build();
}
} }

View File

@ -1,11 +1,11 @@
package io.github.ollama4j.utils; package io.github.ollama4j.utils;
import java.io.IOException;
import com.fasterxml.jackson.core.JsonGenerator; import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.databind.JsonSerializer; import com.fasterxml.jackson.databind.JsonSerializer;
import com.fasterxml.jackson.databind.SerializerProvider; import com.fasterxml.jackson.databind.SerializerProvider;
import java.io.IOException;
public class BooleanToJsonFormatFlagSerializer extends JsonSerializer<Boolean>{ public class BooleanToJsonFormatFlagSerializer extends JsonSerializer<Boolean>{
@Override @Override

View File

@ -0,0 +1,14 @@
package io.github.ollama4j.utils;
public final class Constants {
public static final class HttpConstants {
private HttpConstants() {
}
public static final String APPLICATION_JSON = "application/json";
public static final String APPLICATION_XML = "application/xml";
public static final String TEXT_PLAIN = "text/plain";
public static final String HEADER_KEY_CONTENT_TYPE = "Content-Type";
public static final String HEADER_KEY_ACCEPT = "Accept";
}
}

View File

@ -1,13 +1,13 @@
package io.github.ollama4j.utils; package io.github.ollama4j.utils;
import java.io.IOException;
import java.util.Base64;
import java.util.Collection;
import com.fasterxml.jackson.core.JsonGenerator; import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.databind.JsonSerializer; import com.fasterxml.jackson.databind.JsonSerializer;
import com.fasterxml.jackson.databind.SerializerProvider; import com.fasterxml.jackson.databind.SerializerProvider;
import java.io.IOException;
import java.util.Base64;
import java.util.Collection;
public class FileToBase64Serializer extends JsonSerializer<Collection<byte[]>> { public class FileToBase64Serializer extends JsonSerializer<Collection<byte[]>> {
@Override @Override

View File

@ -1,11 +1,11 @@
package io.github.ollama4j.utils; package io.github.ollama4j.utils;
import java.net.http.HttpRequest.BodyPublisher;
import java.net.http.HttpRequest.BodyPublishers;
import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import java.net.http.HttpRequest.BodyPublisher;
import java.net.http.HttpRequest.BodyPublishers;
/** /**
* Interface to represent a OllamaRequest as HTTP-Request Body via {@link BodyPublishers}. * Interface to represent a OllamaRequest as HTTP-Request Body via {@link BodyPublishers}.
*/ */

View File

@ -1,8 +1,9 @@
package io.github.ollama4j.utils; package io.github.ollama4j.utils;
import java.util.Map;
import lombok.Data; import lombok.Data;
import java.util.Map;
/** Class for options for Ollama model. */ /** Class for options for Ollama model. */
@Data @Data
public class Options { public class Options {

View File

@ -1,6 +1,5 @@
package io.github.ollama4j.utils; package io.github.ollama4j.utils;
import java.io.IOException;
import java.util.HashMap; import java.util.HashMap;
/** Builder class for creating options for Ollama model. */ /** Builder class for creating options for Ollama model. */

View File

@ -1,25 +0,0 @@
package io.github.ollama4j.utils;
import io.github.ollama4j.OllamaAPI;
import java.io.InputStream;
import java.util.Scanner;
public class SamplePrompts {
public static String getSampleDatabasePromptWithQuestion(String question) throws Exception {
ClassLoader classLoader = OllamaAPI.class.getClassLoader();
InputStream inputStream = classLoader.getResourceAsStream("sample-db-prompt-template.txt");
if (inputStream != null) {
Scanner scanner = new Scanner(inputStream);
StringBuilder stringBuffer = new StringBuilder();
while (scanner.hasNextLine()) {
stringBuffer.append(scanner.nextLine()).append("\n");
}
scanner.close();
return stringBuffer.toString().replaceAll("<question>", question);
} else {
throw new Exception("Sample database question file not found.");
}
}
}

View File

@ -1,38 +1,45 @@
package io.github.ollama4j.utils; package io.github.ollama4j.utils;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.net.URI; import java.net.URI;
import java.net.URISyntaxException; import java.net.URISyntaxException;
import java.net.URL; import java.net.URL;
import java.util.Objects;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
public class Utils { public class Utils {
private static ObjectMapper objectMapper; private static ObjectMapper objectMapper;
public static ObjectMapper getObjectMapper() { public static ObjectMapper getObjectMapper() {
if(objectMapper == null) { if (objectMapper == null) {
objectMapper = new ObjectMapper(); objectMapper = new ObjectMapper();
objectMapper.registerModule(new JavaTimeModule()); objectMapper.registerModule(new JavaTimeModule());
}
return objectMapper;
} }
return objectMapper;
}
public static byte[] loadImageBytesFromUrl(String imageUrl) public static byte[] loadImageBytesFromUrl(String imageUrl)
throws IOException, URISyntaxException { throws IOException, URISyntaxException {
URL url = new URI(imageUrl).toURL(); URL url = new URI(imageUrl).toURL();
try (InputStream in = url.openStream(); try (InputStream in = url.openStream();
ByteArrayOutputStream out = new ByteArrayOutputStream()) { ByteArrayOutputStream out = new ByteArrayOutputStream()) {
byte[] buffer = new byte[1024]; byte[] buffer = new byte[1024];
int bytesRead; int bytesRead;
while ((bytesRead = in.read(buffer)) != -1) { while ((bytesRead = in.read(buffer)) != -1) {
out.write(buffer, 0, bytesRead); out.write(buffer, 0, bytesRead);
} }
return out.toByteArray(); return out.toByteArray();
}
}
public static File getFileFromClasspath(String fileName) {
ClassLoader classLoader = Utils.class.getClassLoader();
return new File(Objects.requireNonNull(classLoader.getResource(fileName)).getFile());
} }
}
} }

View File

@ -1,61 +0,0 @@
"""
Following is the database schema.
DROP TABLE IF EXISTS product_categories;
CREATE TABLE IF NOT EXISTS product_categories
(
category_id INTEGER PRIMARY KEY, -- Unique ID for each category
name VARCHAR(50), -- Name of the category
parent INTEGER NULL, -- Parent category - for hierarchical categories
FOREIGN KEY (parent) REFERENCES product_categories (category_id)
);
DROP TABLE IF EXISTS products;
CREATE TABLE IF NOT EXISTS products
(
product_id INTEGER PRIMARY KEY, -- Unique ID for each product
name VARCHAR(50), -- Name of the product
price DECIMAL(10, 2), -- Price of each unit of the product
quantity INTEGER, -- Current quantity in stock
category_id INTEGER, -- Unique ID for each product
FOREIGN KEY (category_id) REFERENCES product_categories (category_id)
);
DROP TABLE IF EXISTS customers;
CREATE TABLE IF NOT EXISTS customers
(
customer_id INTEGER PRIMARY KEY, -- Unique ID for each customer
name VARCHAR(50), -- Name of the customer
address VARCHAR(100) -- Mailing address of the customer
);
DROP TABLE IF EXISTS salespeople;
CREATE TABLE IF NOT EXISTS salespeople
(
salesperson_id INTEGER PRIMARY KEY, -- Unique ID for each salesperson
name VARCHAR(50), -- Name of the salesperson
region VARCHAR(50) -- Geographic sales region
);
DROP TABLE IF EXISTS sales;
CREATE TABLE IF NOT EXISTS sales
(
sale_id INTEGER PRIMARY KEY, -- Unique ID for each sale
product_id INTEGER, -- ID of product sold
customer_id INTEGER, -- ID of customer who made the purchase
salesperson_id INTEGER, -- ID of salesperson who made the sale
sale_date DATE, -- Date the sale occurred
quantity INTEGER, -- Quantity of product sold
FOREIGN KEY (product_id) REFERENCES products (product_id),
FOREIGN KEY (customer_id) REFERENCES customers (customer_id),
FOREIGN KEY (salesperson_id) REFERENCES salespeople (salesperson_id)
);
DROP TABLE IF EXISTS product_suppliers;
CREATE TABLE IF NOT EXISTS product_suppliers
(
supplier_id INTEGER PRIMARY KEY, -- Unique ID for each supplier
product_id INTEGER, -- Product ID supplied
supply_price DECIMAL(10, 2), -- Unit price charged by supplier
FOREIGN KEY (product_id) REFERENCES products (product_id)
);
Generate only a valid (syntactically correct) executable Postgres SQL query (without any explanation of the query) for the following question:
`<question>`:
"""

View File

@ -1,6 +1,5 @@
package io.github.ollama4j.integrationtests; package io.github.ollama4j.integrationtests;
import com.fasterxml.jackson.annotation.JsonProperty;
import io.github.ollama4j.OllamaAPI; import io.github.ollama4j.OllamaAPI;
import io.github.ollama4j.exceptions.OllamaBaseException; import io.github.ollama4j.exceptions.OllamaBaseException;
import io.github.ollama4j.exceptions.ToolInvocationException; import io.github.ollama4j.exceptions.ToolInvocationException;
@ -16,9 +15,6 @@ import io.github.ollama4j.tools.ToolFunction;
import io.github.ollama4j.tools.Tools; import io.github.ollama4j.tools.Tools;
import io.github.ollama4j.tools.annotations.OllamaToolService; import io.github.ollama4j.tools.annotations.OllamaToolService;
import io.github.ollama4j.utils.OptionsBuilder; import io.github.ollama4j.utils.OptionsBuilder;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.MethodOrderer.OrderAnnotation; import org.junit.jupiter.api.MethodOrderer.OrderAnnotation;
import org.junit.jupiter.api.Order; import org.junit.jupiter.api.Order;
@ -40,24 +36,24 @@ import static org.junit.jupiter.api.Assertions.*;
@TestMethodOrder(OrderAnnotation.class) @TestMethodOrder(OrderAnnotation.class)
@SuppressWarnings({"HttpUrlsUsage", "SpellCheckingInspection"}) @SuppressWarnings({"HttpUrlsUsage", "SpellCheckingInspection"})
public class OllamaAPIIntegrationTest { class OllamaAPIIntegrationTest {
private static final Logger LOG = LoggerFactory.getLogger(OllamaAPIIntegrationTest.class); private static final Logger LOG = LoggerFactory.getLogger(OllamaAPIIntegrationTest.class);
private static OllamaContainer ollama; private static OllamaContainer ollama;
private static OllamaAPI api; private static OllamaAPI api;
private static final String EMBEDDING_MODEL_MINILM = "all-minilm"; private static final String EMBEDDING_MODEL = "all-minilm";
private static final String CHAT_MODEL_QWEN_SMALL = "qwen2.5:0.5b"; private static final String VISION_MODEL = "moondream:1.8b";
private static final String CHAT_MODEL_INSTRUCT = "qwen2.5:0.5b-instruct"; private static final String THINKING_TOOL_MODEL = "gpt-oss:20b";
private static final String CHAT_MODEL_SYSTEM_PROMPT = "llama3.2:1b"; private static final String GENERAL_PURPOSE_MODEL = "gemma3:270m";
private static final String CHAT_MODEL_LLAMA3 = "llama3"; private static final String TOOLS_MODEL = "mistral:7b";
private static final String IMAGE_MODEL_LLAVA = "llava";
@BeforeAll @BeforeAll
public static void setUp() { static void setUp() {
try { try {
boolean useExternalOllamaHost = Boolean.parseBoolean(System.getenv("USE_EXTERNAL_OLLAMA_HOST")); boolean useExternalOllamaHost = Boolean.parseBoolean(System.getenv("USE_EXTERNAL_OLLAMA_HOST"));
String ollamaHost = System.getenv("OLLAMA_HOST"); String ollamaHost = System.getenv("OLLAMA_HOST");
if (useExternalOllamaHost) { if (useExternalOllamaHost) {
LOG.info("Using external Ollama host..."); LOG.info("Using external Ollama host...");
api = new OllamaAPI(ollamaHost); api = new OllamaAPI(ollamaHost);
@ -80,7 +76,7 @@ public class OllamaAPIIntegrationTest {
} }
api.setRequestTimeoutSeconds(120); api.setRequestTimeoutSeconds(120);
api.setVerbose(true); api.setVerbose(true);
api.setNumberOfRetriesForModelPull(3); api.setNumberOfRetriesForModelPull(5);
} }
@Test @Test
@ -92,7 +88,7 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(1) @Order(1)
public void testVersionAPI() throws URISyntaxException, IOException, OllamaBaseException, InterruptedException { void testVersionAPI() throws URISyntaxException, IOException, OllamaBaseException, InterruptedException {
// String expectedVersion = ollama.getDockerImageName().split(":")[1]; // String expectedVersion = ollama.getDockerImageName().split(":")[1];
String actualVersion = api.getVersion(); String actualVersion = api.getVersion();
assertNotNull(actualVersion); assertNotNull(actualVersion);
@ -100,17 +96,22 @@ public class OllamaAPIIntegrationTest {
// image version"); // image version");
} }
@Test
@Order(1)
void testPing() throws URISyntaxException, IOException, OllamaBaseException, InterruptedException {
boolean pingResponse = api.ping();
assertTrue(pingResponse, "Ping should return true");
}
@Test @Test
@Order(2) @Order(2)
public void testListModelsAPI() void testListModelsAPI() throws URISyntaxException, IOException, OllamaBaseException, InterruptedException {
throws URISyntaxException, IOException, OllamaBaseException, InterruptedException {
api.pullModel(EMBEDDING_MODEL_MINILM);
// Fetch the list of models // Fetch the list of models
List<Model> models = api.listModels(); List<Model> models = api.listModels();
// Assert that the models list is not null // Assert that the models list is not null
assertNotNull(models, "Models should not be null"); assertNotNull(models, "Models should not be null");
// Assert that models list is either empty or contains more than 0 models // Assert that models list is either empty or contains more than 0 models
assertFalse(models.isEmpty(), "Models list should not be empty"); assertTrue(models.size() >= 0, "Models list should not be empty");
} }
@Test @Test
@ -124,9 +125,8 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(3) @Order(3)
public void testPullModelAPI() void testPullModelAPI() throws URISyntaxException, IOException, OllamaBaseException, InterruptedException {
throws URISyntaxException, IOException, OllamaBaseException, InterruptedException { api.pullModel(EMBEDDING_MODEL);
api.pullModel(EMBEDDING_MODEL_MINILM);
List<Model> models = api.listModels(); List<Model> models = api.listModels();
assertNotNull(models, "Models should not be null"); assertNotNull(models, "Models should not be null");
assertFalse(models.isEmpty(), "Models list should contain elements"); assertFalse(models.isEmpty(), "Models list should contain elements");
@ -135,17 +135,17 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(4) @Order(4)
void testListModelDetails() throws IOException, OllamaBaseException, URISyntaxException, InterruptedException { void testListModelDetails() throws IOException, OllamaBaseException, URISyntaxException, InterruptedException {
api.pullModel(EMBEDDING_MODEL_MINILM); api.pullModel(EMBEDDING_MODEL);
ModelDetail modelDetails = api.getModelDetails(EMBEDDING_MODEL_MINILM); ModelDetail modelDetails = api.getModelDetails(EMBEDDING_MODEL);
assertNotNull(modelDetails); assertNotNull(modelDetails);
assertTrue(modelDetails.getModelFile().contains(EMBEDDING_MODEL_MINILM)); assertTrue(modelDetails.getModelFile().contains(EMBEDDING_MODEL));
} }
@Test @Test
@Order(5) @Order(5)
public void testEmbeddings() throws Exception { void testEmbeddings() throws Exception {
api.pullModel(EMBEDDING_MODEL_MINILM); api.pullModel(EMBEDDING_MODEL);
OllamaEmbedResponseModel embeddings = api.embed(EMBEDDING_MODEL_MINILM, OllamaEmbedResponseModel embeddings = api.embed(EMBEDDING_MODEL,
Arrays.asList("Why is the sky blue?", "Why is the grass green?")); Arrays.asList("Why is the sky blue?", "Why is the grass green?"));
assertNotNull(embeddings, "Embeddings should not be null"); assertNotNull(embeddings, "Embeddings should not be null");
assertFalse(embeddings.getEmbeddings().isEmpty(), "Embeddings should not be empty"); assertFalse(embeddings.getEmbeddings().isEmpty(), "Embeddings should not be empty");
@ -153,58 +153,44 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(6) @Order(6)
void testAskModelWithStructuredOutput() void testGenerateWithStructuredOutput()
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
api.pullModel(CHAT_MODEL_LLAMA3); api.pullModel(TOOLS_MODEL);
int timeHour = 6; String prompt = "The sun is shining brightly and is directly overhead at the zenith, casting my shadow over my foot, so it must be noon.";
boolean isNightTime = false;
String prompt = "The Sun is shining, and its " + timeHour + ". Its daytime.";
Map<String, Object> format = new HashMap<>(); Map<String, Object> format = new HashMap<>();
format.put("type", "object"); format.put("type", "object");
format.put("properties", new HashMap<String, Object>() { format.put("properties", new HashMap<String, Object>() {
{ {
put("timeHour", new HashMap<String, Object>() { put("isNoon", new HashMap<String, Object>() {
{
put("type", "integer");
}
});
put("isNightTime", new HashMap<String, Object>() {
{ {
put("type", "boolean"); put("type", "boolean");
} }
}); });
} }
}); });
format.put("required", Arrays.asList("timeHour", "isNightTime")); format.put("required", List.of("isNoon"));
OllamaResult result = api.generate(CHAT_MODEL_LLAMA3, prompt, format); OllamaResult result = api.generate(TOOLS_MODEL, prompt, format);
assertNotNull(result); assertNotNull(result);
assertNotNull(result.getResponse()); assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty()); assertFalse(result.getResponse().isEmpty());
assertEquals(timeHour, assertEquals(true, result.getStructuredResponse().get("isNoon"));
result.getStructuredResponse().get("timeHour"));
assertEquals(isNightTime,
result.getStructuredResponse().get("isNightTime"));
TimeOfDay timeOfDay = result.as(TimeOfDay.class);
assertEquals(timeHour, timeOfDay.getTimeHour());
assertEquals(isNightTime, timeOfDay.isNightTime());
} }
@Test @Test
@Order(6) @Order(6)
void testAskModelWithDefaultOptions() void testGennerateModelWithDefaultOptions()
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
api.pullModel(CHAT_MODEL_QWEN_SMALL); api.pullModel(GENERAL_PURPOSE_MODEL);
OllamaResult result = api.generate(CHAT_MODEL_QWEN_SMALL, boolean raw = false;
"What is the capital of France? And what's France's connection with Mona Lisa?", false, boolean thinking = false;
new OptionsBuilder().build()); OllamaResult result = api.generate(GENERAL_PURPOSE_MODEL,
"What is the capital of France? And what's France's connection with Mona Lisa?", raw,
thinking, new OptionsBuilder().build());
assertNotNull(result); assertNotNull(result);
assertNotNull(result.getResponse()); assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty()); assertFalse(result.getResponse().isEmpty());
@ -212,32 +198,31 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(7) @Order(7)
void testAskModelWithDefaultOptionsStreamed() void testGenerateWithDefaultOptionsStreamed()
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
api.pullModel(CHAT_MODEL_QWEN_SMALL); api.pullModel(GENERAL_PURPOSE_MODEL);
boolean raw = false;
StringBuffer sb = new StringBuffer(); StringBuffer sb = new StringBuffer();
OllamaResult result = api.generate(CHAT_MODEL_QWEN_SMALL, OllamaResult result = api.generate(GENERAL_PURPOSE_MODEL,
"What is the capital of France? And what's France's connection with Mona Lisa?", false, "What is the capital of France? And what's France's connection with Mona Lisa?", raw,
new OptionsBuilder().build(), (s) -> { new OptionsBuilder().build(), (s) -> {
LOG.info(s); LOG.info(s);
String substring = s.substring(sb.toString().length(), s.length()); sb.append(s);
LOG.info(substring);
sb.append(substring);
}); });
assertNotNull(result); assertNotNull(result);
assertNotNull(result.getResponse()); assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty()); assertFalse(result.getResponse().isEmpty());
assertEquals(sb.toString().trim(), result.getResponse().trim()); assertEquals(sb.toString(), result.getResponse());
} }
@Test @Test
@Order(8) @Order(8)
void testAskModelWithOptions() void testGenerateWithOptions() throws OllamaBaseException, IOException, URISyntaxException,
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException { InterruptedException, ToolInvocationException {
api.pullModel(CHAT_MODEL_INSTRUCT); api.pullModel(GENERAL_PURPOSE_MODEL);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_INSTRUCT); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(GENERAL_PURPOSE_MODEL);
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM,
"You are a helpful assistant who can generate random person's first and last names in the format [First name, Last name].") "You are a helpful assistant who can generate random person's first and last names in the format [First name, Last name].")
.build(); .build();
@ -253,29 +238,32 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(9) @Order(9)
void testChatWithSystemPrompt() void testChatWithSystemPrompt() throws OllamaBaseException, IOException, URISyntaxException,
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException { InterruptedException, ToolInvocationException {
api.pullModel(CHAT_MODEL_SYSTEM_PROMPT); api.pullModel(GENERAL_PURPOSE_MODEL);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT);
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, String expectedResponse = "Bhai";
"You are a silent bot that only says 'Shush'. Do not say anything else under any circumstances!")
.withMessage(OllamaChatMessageRole.USER, "What's something that's brown and sticky?") OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(GENERAL_PURPOSE_MODEL);
.withOptions(new OptionsBuilder().setTemperature(0.8f).build()).build(); OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, String.format(
"[INSTRUCTION-START] You are an obidient and helpful bot named %s. You always answer with only one word and that word is your name. [INSTRUCTION-END]",
expectedResponse)).withMessage(OllamaChatMessageRole.USER, "Who are you?")
.withOptions(new OptionsBuilder().setTemperature(0.0f).build()).build();
OllamaChatResult chatResult = api.chat(requestModel); OllamaChatResult chatResult = api.chat(requestModel);
assertNotNull(chatResult); assertNotNull(chatResult);
assertNotNull(chatResult.getResponseModel()); assertNotNull(chatResult.getResponseModel());
assertNotNull(chatResult.getResponseModel().getMessage()); assertNotNull(chatResult.getResponseModel().getMessage());
assertFalse(chatResult.getResponseModel().getMessage().getContent().isBlank()); assertFalse(chatResult.getResponseModel().getMessage().getContent().isBlank());
assertTrue(chatResult.getResponseModel().getMessage().getContent().contains("Shush")); assertTrue(chatResult.getResponseModel().getMessage().getContent().contains(expectedResponse));
assertEquals(3, chatResult.getChatHistory().size()); assertEquals(3, chatResult.getChatHistory().size());
} }
@Test @Test
@Order(10) @Order(10)
public void testChat() throws Exception { void testChat() throws Exception {
api.pullModel(CHAT_MODEL_LLAMA3); api.pullModel(THINKING_TOOL_MODEL);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_LLAMA3); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(THINKING_TOOL_MODEL);
// Create the initial user question // Create the initial user question
OllamaChatRequest requestModel = builder OllamaChatRequest requestModel = builder
@ -288,7 +276,6 @@ public class OllamaAPIIntegrationTest {
assertTrue(chatResult.getChatHistory().stream().anyMatch(chat -> chat.getContent().contains("2")), assertTrue(chatResult.getChatHistory().stream().anyMatch(chat -> chat.getContent().contains("2")),
"Expected chat history to contain '2'"); "Expected chat history to contain '2'");
// Create the next user question: second largest city
requestModel = builder.withMessages(chatResult.getChatHistory()) requestModel = builder.withMessages(chatResult.getChatHistory())
.withMessage(OllamaChatMessageRole.USER, "And what is its squared value?").build(); .withMessage(OllamaChatMessageRole.USER, "And what is its squared value?").build();
@ -299,10 +286,8 @@ public class OllamaAPIIntegrationTest {
"Expected chat history to contain '4'"); "Expected chat history to contain '4'");
// Create the next user question: the third question // Create the next user question: the third question
requestModel = builder.withMessages(chatResult.getChatHistory()) requestModel = builder.withMessages(chatResult.getChatHistory()).withMessage(OllamaChatMessageRole.USER,
.withMessage(OllamaChatMessageRole.USER, "What is the largest value between 2, 4 and 6?").build();
"What is the largest value between 2, 4 and 6?")
.build();
// Continue conversation with the model for the third question // Continue conversation with the model for the third question
chatResult = api.chat(requestModel); chatResult = api.chat(requestModel);
@ -312,143 +297,103 @@ public class OllamaAPIIntegrationTest {
assertTrue(chatResult.getChatHistory().size() > 2, assertTrue(chatResult.getChatHistory().size() > 2,
"Chat history should contain more than two messages"); "Chat history should contain more than two messages");
assertTrue(chatResult.getChatHistory().get(chatResult.getChatHistory().size() - 1).getContent() assertTrue(chatResult.getChatHistory().get(chatResult.getChatHistory().size() - 1).getContent()
.contains("6"), .contains("6"), "Response should contain '6'");
"Response should contain '6'");
}
@Test
@Order(10)
void testChatWithImageFromURL()
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException, ToolInvocationException {
api.pullModel(IMAGE_MODEL_LLAVA);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(IMAGE_MODEL_LLAVA);
OllamaChatRequest requestModel = builder
.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
Collections.emptyList(),
"https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg")
.build();
api.registerAnnotatedTools(new OllamaAPIIntegrationTest());
OllamaChatResult chatResult = api.chat(requestModel);
assertNotNull(chatResult);
}
@Test
@Order(10)
void testChatWithImageFromFileWithHistoryRecognition()
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException {
api.pullModel(IMAGE_MODEL_LLAVA);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(IMAGE_MODEL_LLAVA);
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER,
"What's in the picture?",
Collections.emptyList(), List.of(getImageFileFromClasspath("emoji-smile.jpeg")))
.build();
OllamaChatResult chatResult = api.chat(requestModel);
assertNotNull(chatResult);
assertNotNull(chatResult.getResponseModel());
builder.reset();
requestModel = builder.withMessages(chatResult.getChatHistory())
.withMessage(OllamaChatMessageRole.USER, "What's the color?").build();
chatResult = api.chat(requestModel);
assertNotNull(chatResult);
assertNotNull(chatResult.getResponseModel());
} }
@Test @Test
@Order(11) @Order(11)
void testChatWithExplicitToolDefinition() void testChatWithExplicitToolDefinition() throws OllamaBaseException, IOException, URISyntaxException,
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException { InterruptedException, ToolInvocationException {
api.pullModel(CHAT_MODEL_SYSTEM_PROMPT); String theToolModel = TOOLS_MODEL;
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT); api.pullModel(theToolModel);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(theToolModel);
final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder() api.registerTool(employeeFinderTool());
.functionName("get-employee-details")
.functionDescription("Get employee details from the database")
.toolPrompt(Tools.PromptFuncDefinition.builder().type("function")
.function(Tools.PromptFuncDefinition.PromptFuncSpec.builder()
.name("get-employee-details")
.description("Get employee details from the database")
.parameters(Tools.PromptFuncDefinition.Parameters
.builder().type("object")
.properties(new Tools.PropsBuilder()
.withProperty("employee-name",
Tools.PromptFuncDefinition.Property
.builder()
.type("string")
.description("The name of the employee, e.g. John Doe")
.required(true)
.build())
.withProperty("employee-address",
Tools.PromptFuncDefinition.Property
.builder()
.type("string")
.description(
"The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India")
.required(true)
.build())
.withProperty("employee-phone",
Tools.PromptFuncDefinition.Property
.builder()
.type("string")
.description(
"The phone number of the employee. Always return a random value. e.g. 9911002233")
.required(true)
.build())
.build())
.required(List.of("employee-name"))
.build())
.build())
.build())
.toolFunction(arguments -> {
// perform DB operations here
return String.format(
"Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}",
UUID.randomUUID(), arguments.get("employee-name"),
arguments.get("employee-address"),
arguments.get("employee-phone"));
}).build();
api.registerTool(databaseQueryToolSpecification); OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER,
"Give me the ID and address of the employee Rahul Kumar.").build();
OllamaChatRequest requestModel = builder requestModel.setOptions(new OptionsBuilder().setTemperature(0.9f).build().getOptionsMap());
.withMessage(OllamaChatMessageRole.USER,
"Give me the ID of the employee named 'Rahul Kumar'?")
.build();
OllamaChatResult chatResult = api.chat(requestModel); OllamaChatResult chatResult = api.chat(requestModel);
assertNotNull(chatResult);
assertNotNull(chatResult.getResponseModel()); assertNotNull(chatResult, "chatResult should not be null");
assertNotNull(chatResult.getResponseModel().getMessage()); assertNotNull(chatResult.getResponseModel(), "Response model should not be null");
assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(), assertNotNull(chatResult.getResponseModel().getMessage(), "Response message should not be null");
chatResult.getResponseModel().getMessage().getRole().getRoleName()); assertEquals(
OllamaChatMessageRole.ASSISTANT.getRoleName(),
chatResult.getResponseModel().getMessage().getRole().getRoleName(),
"Role of the response message should be ASSISTANT"
);
List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls(); List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls();
assertEquals(1, toolCalls.size()); assertEquals(1, toolCalls.size(), "There should be exactly one tool call in the second chat history message");
OllamaToolCallsFunction function = toolCalls.get(0).getFunction(); OllamaToolCallsFunction function = toolCalls.get(0).getFunction();
assertEquals("get-employee-details", function.getName()); assertEquals("get-employee-details", function.getName(), "Tool function name should be 'get-employee-details'");
assert !function.getArguments().isEmpty(); assertFalse(function.getArguments().isEmpty(), "Tool function arguments should not be empty");
Object employeeName = function.getArguments().get("employee-name"); Object employeeName = function.getArguments().get("employee-name");
assertNotNull(employeeName); assertNotNull(employeeName, "Employee name argument should not be null");
assertEquals("Rahul Kumar", employeeName); assertEquals("Rahul Kumar", employeeName, "Employee name argument should be 'Rahul Kumar'");
assertTrue(chatResult.getChatHistory().size() > 2); assertTrue(chatResult.getChatHistory().size() > 2, "Chat history should have more than 2 messages");
List<OllamaChatToolCalls> finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls(); List<OllamaChatToolCalls> finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls();
assertNull(finalToolCalls); assertNull(finalToolCalls, "Final tool calls in the response message should be null");
}
@Test
@Order(14)
void testChatWithToolsAndStream() throws OllamaBaseException, IOException, URISyntaxException,
InterruptedException, ToolInvocationException {
String theToolModel = TOOLS_MODEL;
api.pullModel(theToolModel);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(theToolModel);
api.registerTool(employeeFinderTool());
OllamaChatRequest requestModel = builder
.withMessage(OllamaChatMessageRole.USER, "Give me the ID and address of employee Rahul Kumar")
.withKeepAlive("0m").withOptions(new OptionsBuilder().setTemperature(0.9f).build())
.build();
OllamaChatResult chatResult = api.chat(requestModel, (s) -> {
LOG.info(s.toUpperCase());
}, (s) -> {
LOG.info(s.toLowerCase());
});
assertNotNull(chatResult, "chatResult should not be null");
assertNotNull(chatResult.getResponseModel(), "Response model should not be null");
assertNotNull(chatResult.getResponseModel().getMessage(), "Response message should not be null");
assertEquals(
OllamaChatMessageRole.ASSISTANT.getRoleName(),
chatResult.getResponseModel().getMessage().getRole().getRoleName(),
"Role of the response message should be ASSISTANT"
);
List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls();
assertEquals(1, toolCalls.size(), "There should be exactly one tool call in the second chat history message");
OllamaToolCallsFunction function = toolCalls.get(0).getFunction();
assertEquals("get-employee-details", function.getName(), "Tool function name should be 'get-employee-details'");
assertFalse(function.getArguments().isEmpty(), "Tool function arguments should not be empty");
Object employeeName = function.getArguments().get("employee-name");
assertNotNull(employeeName, "Employee name argument should not be null");
assertEquals("Rahul Kumar", employeeName, "Employee name argument should be 'Rahul Kumar'");
assertTrue(chatResult.getChatHistory().size() > 2, "Chat history should have more than 2 messages");
List<OllamaChatToolCalls> finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls();
assertNull(finalToolCalls, "Final tool calls in the response message should be null");
} }
@Test @Test
@Order(12) @Order(12)
void testChatWithAnnotatedToolsAndSingleParam() void testChatWithAnnotatedToolsAndSingleParam() throws OllamaBaseException, IOException, InterruptedException,
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException, ToolInvocationException { URISyntaxException, ToolInvocationException {
api.pullModel(CHAT_MODEL_SYSTEM_PROMPT); String theToolModel = TOOLS_MODEL;
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT); api.pullModel(theToolModel);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(theToolModel);
api.registerAnnotatedTools(); api.registerAnnotatedTools();
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, OllamaChatRequest requestModel = builder
"Compute the most important constant in the world using 5 digits").build(); .withMessage(OllamaChatMessageRole.USER,
"Compute the most important constant in the world using 5 digits")
.build();
OllamaChatResult chatResult = api.chat(requestModel); OllamaChatResult chatResult = api.chat(requestModel);
assertNotNull(chatResult); assertNotNull(chatResult);
@ -471,17 +416,16 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(13) @Order(13)
void testChatWithAnnotatedToolsAndMultipleParams() void testChatWithAnnotatedToolsAndMultipleParams() throws OllamaBaseException, IOException, URISyntaxException,
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException { InterruptedException, ToolInvocationException {
api.pullModel(CHAT_MODEL_SYSTEM_PROMPT); String theToolModel = TOOLS_MODEL;
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT); api.pullModel(theToolModel);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(theToolModel);
api.registerAnnotatedTools(new AnnotatedTool()); api.registerAnnotatedTools(new AnnotatedTool());
OllamaChatRequest requestModel = builder OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER,
.withMessage(OllamaChatMessageRole.USER, "Greet Rahul with a lot of hearts and respond to me with count of emojis that have been in used in the greeting")
"Greet Pedro with a lot of hearts and respond to me, "
+ "and state how many emojis have been in your greeting")
.build(); .build();
OllamaChatResult chatResult = api.chat(requestModel); OllamaChatResult chatResult = api.chat(requestModel);
@ -497,28 +441,220 @@ public class OllamaAPIIntegrationTest {
assertEquals(2, function.getArguments().size()); assertEquals(2, function.getArguments().size());
Object name = function.getArguments().get("name"); Object name = function.getArguments().get("name");
assertNotNull(name); assertNotNull(name);
assertEquals("Pedro", name); assertEquals("Rahul", name);
Object amountOfHearts = function.getArguments().get("amountOfHearts"); Object numberOfHearts = function.getArguments().get("numberOfHearts");
assertNotNull(amountOfHearts); assertNotNull(numberOfHearts);
assertTrue(Integer.parseInt(amountOfHearts.toString()) > 1); assertTrue(Integer.parseInt(numberOfHearts.toString()) > 1);
assertTrue(chatResult.getChatHistory().size() > 2); assertTrue(chatResult.getChatHistory().size() > 2);
List<OllamaChatToolCalls> finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls(); List<OllamaChatToolCalls> finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls();
assertNull(finalToolCalls); assertNull(finalToolCalls);
} }
@Test @Test
@Order(14) @Order(15)
void testChatWithToolsAndStream() void testChatWithStream() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException,
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException { ToolInvocationException {
api.pullModel(CHAT_MODEL_SYSTEM_PROMPT); api.deregisterTools();
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT); api.pullModel(GENERAL_PURPOSE_MODEL);
final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder() OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(GENERAL_PURPOSE_MODEL);
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER,
"What is the capital of France? And what's France's connection with Mona Lisa?")
.build();
requestModel.setThink(false);
StringBuffer sb = new StringBuffer();
OllamaChatResult chatResult = api.chat(requestModel, (s) -> {
LOG.info(s.toUpperCase());
sb.append(s);
}, (s) -> {
LOG.info(s.toLowerCase());
sb.append(s);
});
assertNotNull(chatResult);
assertNotNull(chatResult.getResponseModel());
assertNotNull(chatResult.getResponseModel().getMessage());
assertNotNull(chatResult.getResponseModel().getMessage().getContent());
assertEquals(sb.toString(), chatResult.getResponseModel().getMessage().getContent());
}
@Test
@Order(15)
void testChatWithThinkingAndStream() throws OllamaBaseException, IOException, URISyntaxException,
InterruptedException, ToolInvocationException {
api.pullModel(THINKING_TOOL_MODEL);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(THINKING_TOOL_MODEL);
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER,
"What is the capital of France? And what's France's connection with Mona Lisa?")
.withThinking(true).withKeepAlive("0m").build();
StringBuffer sb = new StringBuffer();
OllamaChatResult chatResult = api.chat(requestModel, (s) -> {
sb.append(s);
LOG.info(s.toUpperCase());
}, (s) -> {
sb.append(s);
LOG.info(s.toLowerCase());
});
assertNotNull(chatResult);
assertNotNull(chatResult.getResponseModel());
assertNotNull(chatResult.getResponseModel().getMessage());
assertNotNull(chatResult.getResponseModel().getMessage().getContent());
assertEquals(sb.toString(), chatResult.getResponseModel().getMessage().getThinking()
+ chatResult.getResponseModel().getMessage().getContent());
}
@Test
@Order(10)
void testChatWithImageFromURL() throws OllamaBaseException, IOException, InterruptedException,
URISyntaxException, ToolInvocationException {
api.pullModel(VISION_MODEL);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(VISION_MODEL);
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER,
"What's in the picture?", Collections.emptyList(),
"https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg")
.build();
api.registerAnnotatedTools(new OllamaAPIIntegrationTest());
OllamaChatResult chatResult = api.chat(requestModel);
assertNotNull(chatResult);
}
@Test
@Order(10)
void testChatWithImageFromFileWithHistoryRecognition() throws OllamaBaseException, IOException,
URISyntaxException, InterruptedException, ToolInvocationException {
api.pullModel(VISION_MODEL);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(VISION_MODEL);
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER,
"What's in the picture?", Collections.emptyList(),
List.of(getImageFileFromClasspath("emoji-smile.jpeg"))).build();
OllamaChatResult chatResult = api.chat(requestModel);
assertNotNull(chatResult);
assertNotNull(chatResult.getResponseModel());
builder.reset();
requestModel = builder.withMessages(chatResult.getChatHistory())
.withMessage(OllamaChatMessageRole.USER, "What's the color?").build();
chatResult = api.chat(requestModel);
assertNotNull(chatResult);
assertNotNull(chatResult.getResponseModel());
}
@Test
@Order(17)
void testGenerateWithOptionsAndImageURLs()
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
api.pullModel(VISION_MODEL);
OllamaResult result = api.generateWithImageURLs(VISION_MODEL, "What is in this image?",
List.of("https://i.pinimg.com/736x/f9/4e/cb/f94ecba040696a3a20b484d2e15159ec.jpg"),
new OptionsBuilder().build());
assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
}
@Test
@Order(18)
void testGenerateWithOptionsAndImageFiles()
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
api.pullModel(VISION_MODEL);
File imageFile = getImageFileFromClasspath("roses.jpg");
try {
OllamaResult result = api.generateWithImageFiles(VISION_MODEL, "What is in this image?",
List.of(imageFile), new OptionsBuilder().build());
assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
}
}
@Test
@Order(20)
void testGenerateWithOptionsAndImageFilesStreamed()
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
api.pullModel(VISION_MODEL);
File imageFile = getImageFileFromClasspath("roses.jpg");
StringBuffer sb = new StringBuffer();
OllamaResult result = api.generateWithImageFiles(VISION_MODEL, "What is in this image?",
List.of(imageFile), new OptionsBuilder().build(), (s) -> {
LOG.info(s);
sb.append(s);
});
assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
assertEquals(sb.toString(), result.getResponse());
}
@Test
@Order(20)
void testGenerateWithThinking()
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
api.pullModel(THINKING_TOOL_MODEL);
boolean raw = false;
boolean think = true;
OllamaResult result = api.generate(THINKING_TOOL_MODEL, "Who are you?", raw, think,
new OptionsBuilder().build());
assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
assertNotNull(result.getThinking());
assertFalse(result.getThinking().isEmpty());
}
@Test
@Order(20)
void testGenerateWithThinkingAndStreamHandler()
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
api.pullModel(THINKING_TOOL_MODEL);
boolean raw = false;
StringBuffer sb = new StringBuffer();
OllamaResult result = api.generate(THINKING_TOOL_MODEL, "Who are you?", raw,
new OptionsBuilder().build(),
(thinkingToken) -> {
sb.append(thinkingToken);
LOG.info(thinkingToken);
},
(resToken) -> {
sb.append(resToken);
LOG.info(resToken);
}
);
assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
assertNotNull(result.getThinking());
assertFalse(result.getThinking().isEmpty());
assertEquals(sb.toString(), result.getThinking() + result.getResponse());
}
private File getImageFileFromClasspath(String fileName) {
ClassLoader classLoader = getClass().getClassLoader();
return new File(Objects.requireNonNull(classLoader.getResource(fileName)).getFile());
}
private Tools.ToolSpecification employeeFinderTool() {
return Tools.ToolSpecification.builder()
.functionName("get-employee-details") .functionName("get-employee-details")
.functionDescription("Get employee details from the database") .functionDescription("Get details for a person or an employee")
.toolPrompt(Tools.PromptFuncDefinition.builder().type("function") .toolPrompt(Tools.PromptFuncDefinition.builder().type("function")
.function(Tools.PromptFuncDefinition.PromptFuncSpec.builder() .function(Tools.PromptFuncDefinition.PromptFuncSpec.builder()
.name("get-employee-details") .name("get-employee-details")
.description("Get employee details from the database") .description("Get details for a person or an employee")
.parameters(Tools.PromptFuncDefinition.Parameters .parameters(Tools.PromptFuncDefinition.Parameters
.builder().type("object") .builder().type("object")
.properties(new Tools.PropsBuilder() .properties(new Tools.PropsBuilder()
@ -533,16 +669,14 @@ public class OllamaAPIIntegrationTest {
Tools.PromptFuncDefinition.Property Tools.PromptFuncDefinition.Property
.builder() .builder()
.type("string") .type("string")
.description( .description("The address of the employee, Always eturns a random address. For example, Church St, Bengaluru, India")
"The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India")
.required(true) .required(true)
.build()) .build())
.withProperty("employee-phone", .withProperty("employee-phone",
Tools.PromptFuncDefinition.Property Tools.PromptFuncDefinition.Property
.builder() .builder()
.type("string") .type("string")
.description( .description("The phone number of the employee. Always returns a random phone number. For example, 9911002233")
"The phone number of the employee. Always return a random value. e.g. 9911002233")
.required(true) .required(true)
.build()) .build())
.build()) .build())
@ -553,129 +687,22 @@ public class OllamaAPIIntegrationTest {
.toolFunction(new ToolFunction() { .toolFunction(new ToolFunction() {
@Override @Override
public Object apply(Map<String, Object> arguments) { public Object apply(Map<String, Object> arguments) {
LOG.info("Invoking employee finder tool with arguments: {}", arguments);
String employeeName = arguments.get("employee-name").toString();
String address = null;
String phone = null;
if (employeeName.equalsIgnoreCase("Rahul Kumar")) {
address = "Pune, Maharashtra, India";
phone = "9911223344";
} else {
address = "Karol Bagh, Delhi, India";
phone = "9911002233";
}
// perform DB operations here // perform DB operations here
return String.format( return String.format(
"Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", "Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}",
UUID.randomUUID(), arguments.get("employee-name"), UUID.randomUUID(), employeeName, address, phone);
arguments.get("employee-address"),
arguments.get("employee-phone"));
} }
}).build(); }).build();
api.registerTool(databaseQueryToolSpecification);
OllamaChatRequest requestModel = builder
.withMessage(OllamaChatMessageRole.USER,
"Give me the ID of the employee named 'Rahul Kumar'?")
.build();
StringBuffer sb = new StringBuffer();
OllamaChatResult chatResult = api.chat(requestModel, (s) -> {
LOG.info(s);
String substring = s.substring(sb.toString().length());
LOG.info(substring);
sb.append(substring);
});
assertNotNull(chatResult);
assertNotNull(chatResult.getResponseModel());
assertNotNull(chatResult.getResponseModel().getMessage());
assertNotNull(chatResult.getResponseModel().getMessage().getContent());
assertEquals(sb.toString().trim(), chatResult.getResponseModel().getMessage().getContent().trim());
}
@Test
@Order(15)
void testChatWithStream() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException {
api.pullModel(CHAT_MODEL_SYSTEM_PROMPT);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT);
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER,
"What is the capital of France? And what's France's connection with Mona Lisa?")
.build();
StringBuffer sb = new StringBuffer();
OllamaChatResult chatResult = api.chat(requestModel, (s) -> {
LOG.info(s);
String substring = s.substring(sb.toString().length(), s.length());
LOG.info(substring);
sb.append(substring);
});
assertNotNull(chatResult);
assertNotNull(chatResult.getResponseModel());
assertNotNull(chatResult.getResponseModel().getMessage());
assertNotNull(chatResult.getResponseModel().getMessage().getContent());
assertEquals(sb.toString().trim(), chatResult.getResponseModel().getMessage().getContent().trim());
}
@Test
@Order(17)
void testAskModelWithOptionsAndImageURLs()
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
api.pullModel(IMAGE_MODEL_LLAVA);
OllamaResult result = api.generateWithImageURLs(IMAGE_MODEL_LLAVA, "What is in this image?",
List.of("https://upload.wikimedia.org/wikipedia/commons/thumb/a/aa/Noto_Emoji_v2.034_1f642.svg/360px-Noto_Emoji_v2.034_1f642.svg.png"),
new OptionsBuilder().build());
assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
}
@Test
@Order(18)
void testAskModelWithOptionsAndImageFiles()
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
api.pullModel(IMAGE_MODEL_LLAVA);
File imageFile = getImageFileFromClasspath("emoji-smile.jpeg");
try {
OllamaResult result = api.generateWithImageFiles(IMAGE_MODEL_LLAVA, "What is in this image?",
List.of(imageFile),
new OptionsBuilder().build());
assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
}
}
@Test
@Order(20)
void testAskModelWithOptionsAndImageFilesStreamed()
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
api.pullModel(IMAGE_MODEL_LLAVA);
File imageFile = getImageFileFromClasspath("emoji-smile.jpeg");
StringBuffer sb = new StringBuffer();
OllamaResult result = api.generateWithImageFiles(IMAGE_MODEL_LLAVA, "What is in this image?",
List.of(imageFile),
new OptionsBuilder().build(), (s) -> {
LOG.info(s);
String substring = s.substring(sb.toString().length(), s.length());
LOG.info(substring);
sb.append(substring);
});
assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
assertEquals(sb.toString().trim(), result.getResponse().trim());
}
private File getImageFileFromClasspath(String fileName) {
ClassLoader classLoader = getClass().getClassLoader();
return new File(Objects.requireNonNull(classLoader.getResource(fileName)).getFile());
} }
} }
@Data
@AllArgsConstructor
@NoArgsConstructor
class TimeOfDay {
@JsonProperty("timeHour")
private int timeHour;
@JsonProperty("isNightTime")
private boolean nightTime;
}

View File

@ -24,8 +24,8 @@ import java.io.FileWriter;
import java.io.IOException; import java.io.IOException;
import java.net.URISyntaxException; import java.net.URISyntaxException;
import java.time.Duration; import java.time.Duration;
import java.util.Arrays;
import java.util.HashMap; import java.util.HashMap;
import java.util.List;
import java.util.Map; import java.util.Map;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
@ -41,7 +41,8 @@ public class WithAuth {
private static final String OLLAMA_VERSION = "0.6.1"; private static final String OLLAMA_VERSION = "0.6.1";
private static final String NGINX_VERSION = "nginx:1.23.4-alpine"; private static final String NGINX_VERSION = "nginx:1.23.4-alpine";
private static final String BEARER_AUTH_TOKEN = "secret-token"; private static final String BEARER_AUTH_TOKEN = "secret-token";
private static final String CHAT_MODEL_LLAMA3 = "llama3"; private static final String GENERAL_PURPOSE_MODEL = "gemma3:270m";
// private static final String THINKING_MODEL = "gpt-oss:20b";
private static OllamaContainer ollama; private static OllamaContainer ollama;
@ -49,7 +50,7 @@ public class WithAuth {
private static OllamaAPI api; private static OllamaAPI api;
@BeforeAll @BeforeAll
public static void setUp() { static void setUp() {
ollama = createOllamaContainer(); ollama = createOllamaContainer();
ollama.start(); ollama.start();
@ -68,7 +69,7 @@ public class WithAuth {
LOG.info( LOG.info(
"The Ollama service is now accessible via the Nginx proxy with bearer-auth authentication mode.\n" + "The Ollama service is now accessible via the Nginx proxy with bearer-auth authentication mode.\n" +
"→ Ollama URL: {}\n" + "→ Ollama URL: {}\n" +
"→ Proxy URL: {}}", "→ Proxy URL: {}",
ollamaUrl, nginxUrl ollamaUrl, nginxUrl
); );
LOG.info("OllamaAPI initialized with bearer auth token: {}", BEARER_AUTH_TOKEN); LOG.info("OllamaAPI initialized with bearer auth token: {}", BEARER_AUTH_TOKEN);
@ -132,14 +133,14 @@ public class WithAuth {
@Test @Test
@Order(1) @Order(1)
void testOllamaBehindProxy() throws InterruptedException { void testOllamaBehindProxy() {
api.setBearerAuth(BEARER_AUTH_TOKEN); api.setBearerAuth(BEARER_AUTH_TOKEN);
assertTrue(api.ping(), "Expected OllamaAPI to successfully ping through NGINX with valid auth token."); assertTrue(api.ping(), "Expected OllamaAPI to successfully ping through NGINX with valid auth token.");
} }
@Test @Test
@Order(1) @Order(1)
void testWithWrongToken() throws InterruptedException { void testWithWrongToken() {
api.setBearerAuth("wrong-token"); api.setBearerAuth("wrong-token");
assertFalse(api.ping(), "Expected OllamaAPI ping to fail through NGINX with an invalid auth token."); assertFalse(api.ping(), "Expected OllamaAPI ping to fail through NGINX with an invalid auth token.");
} }
@ -149,46 +150,30 @@ public class WithAuth {
void testAskModelWithStructuredOutput() void testAskModelWithStructuredOutput()
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
api.setBearerAuth(BEARER_AUTH_TOKEN); api.setBearerAuth(BEARER_AUTH_TOKEN);
String model = GENERAL_PURPOSE_MODEL;
api.pullModel(model);
api.pullModel(CHAT_MODEL_LLAMA3); String prompt = "The sun is shining brightly and is directly overhead at the zenith, casting my shadow over my foot, so it must be noon.";
int timeHour = 6;
boolean isNightTime = false;
String prompt = "The Sun is shining, and its " + timeHour + ". Its daytime.";
Map<String, Object> format = new HashMap<>(); Map<String, Object> format = new HashMap<>();
format.put("type", "object"); format.put("type", "object");
format.put("properties", new HashMap<String, Object>() { format.put("properties", new HashMap<String, Object>() {
{ {
put("timeHour", new HashMap<String, Object>() { put("isNoon", new HashMap<String, Object>() {
{
put("type", "integer");
}
});
put("isNightTime", new HashMap<String, Object>() {
{ {
put("type", "boolean"); put("type", "boolean");
} }
}); });
} }
}); });
format.put("required", Arrays.asList("timeHour", "isNightTime")); format.put("required", List.of("isNoon"));
OllamaResult result = api.generate(CHAT_MODEL_LLAMA3, prompt, format); OllamaResult result = api.generate(model, prompt, format);
assertNotNull(result); assertNotNull(result);
assertNotNull(result.getResponse()); assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty()); assertFalse(result.getResponse().isEmpty());
assertEquals(timeHour, assertEquals(true, result.getStructuredResponse().get("isNoon"));
result.getStructuredResponse().get("timeHour"));
assertEquals(isNightTime,
result.getStructuredResponse().get("isNightTime"));
TimeOfDay timeOfDay = result.as(TimeOfDay.class);
assertEquals(timeHour, timeOfDay.getTimeHour());
assertEquals(isNightTime, timeOfDay.isNightTime());
} }
} }

View File

@ -8,14 +8,14 @@ import java.math.BigDecimal;
public class AnnotatedTool { public class AnnotatedTool {
@ToolSpec(desc = "Computes the most important constant all around the globe!") @ToolSpec(desc = "Computes the most important constant all around the globe!")
public String computeImportantConstant(@ToolProperty(name = "noOfDigits",desc = "Number of digits that shall be returned") Integer noOfDigits ){ public String computeImportantConstant(@ToolProperty(name = "noOfDigits", desc = "Number of digits that shall be returned") Integer noOfDigits) {
return BigDecimal.valueOf((long)(Math.random()*1000000L),noOfDigits).toString(); return BigDecimal.valueOf((long) (Math.random() * 1000000L), noOfDigits).toString();
} }
@ToolSpec(desc = "Says hello to a friend!") @ToolSpec(desc = "Says hello to a friend!")
public String sayHello(@ToolProperty(name = "name",desc = "Name of the friend") String name, Integer someRandomProperty, @ToolProperty(name="amountOfHearts",desc = "amount of heart emojis that should be used", required = false) Integer amountOfHearts) { public String sayHello(@ToolProperty(name = "name", desc = "Name of the friend") String name, @ToolProperty(name = "numberOfHearts", desc = "number of heart emojis that should be used", required = false) Integer numberOfHearts) {
String hearts = amountOfHearts!=null ? "".repeat(amountOfHearts) : ""; String hearts = numberOfHearts != null ? "".repeat(numberOfHearts) : "";
return "Hello " + name +" ("+someRandomProperty+") " + hearts; return "Hello, " + name + "! " + hearts;
} }
} }

View File

@ -21,7 +21,8 @@ import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
import static org.mockito.Mockito.*; import static org.mockito.Mockito.*;
class TestMockedAPIs { class TestMockedAPIs {
@ -138,10 +139,10 @@ class TestMockedAPIs {
String prompt = "some prompt text"; String prompt = "some prompt text";
OptionsBuilder optionsBuilder = new OptionsBuilder(); OptionsBuilder optionsBuilder = new OptionsBuilder();
try { try {
when(ollamaAPI.generate(model, prompt, false, optionsBuilder.build())) when(ollamaAPI.generate(model, prompt, false, false, optionsBuilder.build()))
.thenReturn(new OllamaResult("", 0, 200)); .thenReturn(new OllamaResult("", "", 0, 200));
ollamaAPI.generate(model, prompt, false, optionsBuilder.build()); ollamaAPI.generate(model, prompt, false, false, optionsBuilder.build());
verify(ollamaAPI, times(1)).generate(model, prompt, false, optionsBuilder.build()); verify(ollamaAPI, times(1)).generate(model, prompt, false, false, optionsBuilder.build());
} catch (IOException | OllamaBaseException | InterruptedException e) { } catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
@ -155,7 +156,7 @@ class TestMockedAPIs {
try { try {
when(ollamaAPI.generateWithImageFiles( when(ollamaAPI.generateWithImageFiles(
model, prompt, Collections.emptyList(), new OptionsBuilder().build())) model, prompt, Collections.emptyList(), new OptionsBuilder().build()))
.thenReturn(new OllamaResult("", 0, 200)); .thenReturn(new OllamaResult("", "", 0, 200));
ollamaAPI.generateWithImageFiles( ollamaAPI.generateWithImageFiles(
model, prompt, Collections.emptyList(), new OptionsBuilder().build()); model, prompt, Collections.emptyList(), new OptionsBuilder().build());
verify(ollamaAPI, times(1)) verify(ollamaAPI, times(1))
@ -174,7 +175,7 @@ class TestMockedAPIs {
try { try {
when(ollamaAPI.generateWithImageURLs( when(ollamaAPI.generateWithImageURLs(
model, prompt, Collections.emptyList(), new OptionsBuilder().build())) model, prompt, Collections.emptyList(), new OptionsBuilder().build()))
.thenReturn(new OllamaResult("", 0, 200)); .thenReturn(new OllamaResult("", "", 0, 200));
ollamaAPI.generateWithImageURLs( ollamaAPI.generateWithImageURLs(
model, prompt, Collections.emptyList(), new OptionsBuilder().build()); model, prompt, Collections.emptyList(), new OptionsBuilder().build());
verify(ollamaAPI, times(1)) verify(ollamaAPI, times(1))
@ -190,10 +191,10 @@ class TestMockedAPIs {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2; String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text"; String prompt = "some prompt text";
when(ollamaAPI.generateAsync(model, prompt, false)) when(ollamaAPI.generateAsync(model, prompt, false, false))
.thenReturn(new OllamaAsyncResultStreamer(null, null, 3)); .thenReturn(new OllamaAsyncResultStreamer(null, null, 3));
ollamaAPI.generateAsync(model, prompt, false); ollamaAPI.generateAsync(model, prompt, false, false);
verify(ollamaAPI, times(1)).generateAsync(model, prompt, false); verify(ollamaAPI, times(1)).generateAsync(model, prompt, false, false);
} }
@Test @Test

View File

@ -1,11 +1,12 @@
package io.github.ollama4j.unittests.jackson; package io.github.ollama4j.unittests.jackson;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import io.github.ollama4j.utils.Utils; import io.github.ollama4j.utils.Utils;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
public abstract class AbstractSerializationTest<T> { public abstract class AbstractSerializationTest<T> {
protected ObjectMapper mapper = Utils.getObjectMapper(); protected ObjectMapper mapper = Utils.getObjectMapper();

View File

@ -1,20 +1,19 @@
package io.github.ollama4j.unittests.jackson; package io.github.ollama4j.unittests.jackson;
import static org.junit.jupiter.api.Assertions.assertEquals; import io.github.ollama4j.models.chat.OllamaChatMessageRole;
import static org.junit.jupiter.api.Assertions.assertThrowsExactly; import io.github.ollama4j.models.chat.OllamaChatRequest;
import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
import io.github.ollama4j.utils.OptionsBuilder;
import org.json.JSONObject;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import java.io.File; import java.io.File;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import io.github.ollama4j.models.chat.OllamaChatRequest; import static org.junit.jupiter.api.Assertions.assertEquals;
import org.json.JSONObject; import static org.junit.jupiter.api.Assertions.assertThrowsExactly;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
import io.github.ollama4j.utils.OptionsBuilder;
public class TestChatRequestSerialization extends AbstractSerializationTest<OllamaChatRequest> { public class TestChatRequestSerialization extends AbstractSerializationTest<OllamaChatRequest> {

View File

@ -1,12 +1,12 @@
package io.github.ollama4j.unittests.jackson; package io.github.ollama4j.unittests.jackson;
import static org.junit.jupiter.api.Assertions.assertEquals;
import io.github.ollama4j.models.embeddings.OllamaEmbedRequestBuilder; import io.github.ollama4j.models.embeddings.OllamaEmbedRequestBuilder;
import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel; import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel;
import io.github.ollama4j.utils.OptionsBuilder;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import io.github.ollama4j.utils.OptionsBuilder;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestEmbedRequestSerialization extends AbstractSerializationTest<OllamaEmbedRequestModel> { public class TestEmbedRequestSerialization extends AbstractSerializationTest<OllamaEmbedRequestModel> {

View File

@ -1,15 +1,13 @@
package io.github.ollama4j.unittests.jackson; package io.github.ollama4j.unittests.jackson;
import static org.junit.jupiter.api.Assertions.assertEquals;
import io.github.ollama4j.models.generate.OllamaGenerateRequest; import io.github.ollama4j.models.generate.OllamaGenerateRequest;
import io.github.ollama4j.models.generate.OllamaGenerateRequestBuilder;
import io.github.ollama4j.utils.OptionsBuilder;
import org.json.JSONObject; import org.json.JSONObject;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals;
import io.github.ollama4j.models.generate.OllamaGenerateRequestBuilder;
import io.github.ollama4j.utils.OptionsBuilder;
public class TestGenerateRequestSerialization extends AbstractSerializationTest<OllamaGenerateRequest> { public class TestGenerateRequestSerialization extends AbstractSerializationTest<OllamaGenerateRequest> {

View File

@ -3,40 +3,66 @@ package io.github.ollama4j.unittests.jackson;
import io.github.ollama4j.models.response.Model; import io.github.ollama4j.models.response.Model;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
public class TestModelRequestSerialization extends AbstractSerializationTest<Model> { public class TestModelRequestSerialization extends AbstractSerializationTest<Model> {
@Test @Test
public void testDeserializationOfModelResponseWithOffsetTime(){ public void testDeserializationOfModelResponseWithOffsetTime() {
String serializedTestStringWithOffsetTime = "{\n" String serializedTestStringWithOffsetTime = "{\n" +
+ "\"name\": \"codellama:13b\",\n" " \"name\": \"codellama:13b\",\n" +
+ "\"modified_at\": \"2023-11-04T14:56:49.277302595-07:00\",\n" " \"modified_at\": \"2023-11-04T14:56:49.277302595-07:00\",\n" +
+ "\"size\": 7365960935,\n" " \"size\": 7365960935,\n" +
+ "\"digest\": \"9f438cb9cd581fc025612d27f7c1a6669ff83a8bb0ed86c94fcf4c5440555697\",\n" " \"digest\": \"9f438cb9cd581fc025612d27f7c1a6669ff83a8bb0ed86c94fcf4c5440555697\",\n" +
+ "\"details\": {\n" " \"details\": {\n" +
+ "\"format\": \"gguf\",\n" " \"format\": \"gguf\",\n" +
+ "\"family\": \"llama\",\n" " \"family\": \"llama\",\n" +
+ "\"families\": null,\n" " \"families\": null,\n" +
+ "\"parameter_size\": \"13B\",\n" " \"parameter_size\": \"13B\",\n" +
+ "\"quantization_level\": \"Q4_0\"\n" " \"quantization_level\": \"Q4_0\"\n" +
+ "}}"; " }\n" +
deserialize(serializedTestStringWithOffsetTime,Model.class); "}";
Model model = deserialize(serializedTestStringWithOffsetTime, Model.class);
assertNotNull(model);
assertEquals("codellama:13b", model.getName());
assertEquals("2023-11-04T21:56:49.277302595Z", model.getModifiedAt().toString());
assertEquals(7365960935L, model.getSize());
assertEquals("9f438cb9cd581fc025612d27f7c1a6669ff83a8bb0ed86c94fcf4c5440555697", model.getDigest());
assertNotNull(model.getModelMeta());
assertEquals("gguf", model.getModelMeta().getFormat());
assertEquals("llama", model.getModelMeta().getFamily());
assertNull(model.getModelMeta().getFamilies());
assertEquals("13B", model.getModelMeta().getParameterSize());
assertEquals("Q4_0", model.getModelMeta().getQuantizationLevel());
} }
@Test @Test
public void testDeserializationOfModelResponseWithZuluTime(){ public void testDeserializationOfModelResponseWithZuluTime() {
String serializedTestStringWithZuluTimezone = "{\n" String serializedTestStringWithZuluTimezone = "{\n" +
+ "\"name\": \"codellama:13b\",\n" " \"name\": \"codellama:13b\",\n" +
+ "\"modified_at\": \"2023-11-04T14:56:49.277302595Z\",\n" " \"modified_at\": \"2023-11-04T14:56:49.277302595Z\",\n" +
+ "\"size\": 7365960935,\n" " \"size\": 7365960935,\n" +
+ "\"digest\": \"9f438cb9cd581fc025612d27f7c1a6669ff83a8bb0ed86c94fcf4c5440555697\",\n" " \"digest\": \"9f438cb9cd581fc025612d27f7c1a6669ff83a8bb0ed86c94fcf4c5440555697\",\n" +
+ "\"details\": {\n" " \"details\": {\n" +
+ "\"format\": \"gguf\",\n" " \"format\": \"gguf\",\n" +
+ "\"family\": \"llama\",\n" " \"family\": \"llama\",\n" +
+ "\"families\": null,\n" " \"families\": null,\n" +
+ "\"parameter_size\": \"13B\",\n" " \"parameter_size\": \"13B\",\n" +
+ "\"quantization_level\": \"Q4_0\"\n" " \"quantization_level\": \"Q4_0\"\n" +
+ "}}"; " }\n" +
deserialize(serializedTestStringWithZuluTimezone,Model.class); "}";
Model model = deserialize(serializedTestStringWithZuluTimezone, Model.class);
assertNotNull(model);
assertEquals("codellama:13b", model.getName());
assertEquals("2023-11-04T14:56:49.277302595Z", model.getModifiedAt().toString());
assertEquals(7365960935L, model.getSize());
assertEquals("9f438cb9cd581fc025612d27f7c1a6669ff83a8bb0ed86c94fcf4c5440555697", model.getDigest());
assertNotNull(model.getModelMeta());
assertEquals("gguf", model.getModelMeta().getFormat());
assertEquals("llama", model.getModelMeta().getFamily());
assertNull(model.getModelMeta().getFamilies());
assertEquals("13B", model.getModelMeta().getParameterSize());
assertEquals("Q4_0", model.getModelMeta().getQuantizationLevel());
} }
} }

Binary file not shown.

After

Width:  |  Height:  |  Size: 63 KiB