mirror of
https://github.com/amithkoujalgi/ollama4j.git
synced 2025-09-16 03:39:05 +02:00
Merge pull request #145 from ollama4j/thinking-support
Thinking support
This commit is contained in:
commit
931d5dd520
35
.github/workflows/build-on-pull-request.yml
vendored
35
.github/workflows/build-on-pull-request.yml
vendored
@ -1,20 +1,21 @@
|
||||
name: Run Tests
|
||||
name: Build and Test on Pull Request
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
# types: [opened, reopened, synchronize, edited]
|
||||
branches: [ "main" ]
|
||||
types: [opened, reopened, synchronize]
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- 'src/**' # Run if changes occur in the 'src' folder
|
||||
- 'pom.xml' # Run if changes occur in the 'pom.xml' file
|
||||
- 'src/**'
|
||||
- 'pom.xml'
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
run-tests:
|
||||
|
||||
build:
|
||||
name: Build Java Project
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
@ -26,18 +27,26 @@ jobs:
|
||||
with:
|
||||
java-version: '11'
|
||||
distribution: 'adopt-hotspot'
|
||||
server-id: github # Value of the distributionManagement/repository/id field of the pom.xml
|
||||
settings-path: ${{ github.workspace }} # location for the settings.xml file
|
||||
server-id: github
|
||||
settings-path: ${{ github.workspace }}
|
||||
|
||||
- name: Build with Maven
|
||||
run: mvn --file pom.xml -U clean package
|
||||
|
||||
- name: Run unit tests
|
||||
run: mvn --file pom.xml -U clean test -Punit-tests
|
||||
run-tests:
|
||||
name: Run Unit and Integration Tests
|
||||
needs: build
|
||||
uses: ./.github/workflows/run-tests.yml
|
||||
with:
|
||||
branch: ${{ github.head_ref || github.ref_name }}
|
||||
|
||||
- name: Run integration tests
|
||||
run: mvn --file pom.xml -U clean verify -Pintegration-tests
|
||||
build-docs:
|
||||
name: Build Documentation
|
||||
needs: [build, run-tests]
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Use Node.js
|
||||
uses: actions/setup-node@v3
|
||||
with:
|
||||
|
24
.github/workflows/run-tests.yml
vendored
24
.github/workflows/run-tests.yml
vendored
@ -1,18 +1,29 @@
|
||||
name: Run Unit and Integration Tests
|
||||
name: Run Tests
|
||||
|
||||
on:
|
||||
# push:
|
||||
# branches:
|
||||
# - main
|
||||
|
||||
workflow_call:
|
||||
inputs:
|
||||
branch:
|
||||
description: 'Branch name to run the tests on'
|
||||
required: true
|
||||
default: 'main'
|
||||
type: string
|
||||
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
branch:
|
||||
description: 'Branch name to run the tests on'
|
||||
required: true
|
||||
default: 'main'
|
||||
type: string
|
||||
|
||||
jobs:
|
||||
run-tests:
|
||||
name: Unit and Integration Tests
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
@ -21,17 +32,6 @@ jobs:
|
||||
with:
|
||||
ref: ${{ github.event.inputs.branch }}
|
||||
|
||||
- name: Use workflow from checked out branch
|
||||
run: |
|
||||
if [ -f .github/workflows/run-tests.yml ]; then
|
||||
echo "Using workflow from checked out branch."
|
||||
cp .github/workflows/run-tests.yml /tmp/run-tests.yml
|
||||
exit 0
|
||||
else
|
||||
echo "Workflow file not found in checked out branch."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Set up Ollama
|
||||
run: |
|
||||
curl -fsSL https://ollama.com/install.sh | sh
|
||||
|
@ -29,7 +29,7 @@ You will get a response similar to:
|
||||
|
||||
### Try asking a question, receiving the answer streamed
|
||||
|
||||
<CodeEmbed src="https://raw.githubusercontent.com/ollama4j/ollama4j-examples/refs/heads/main/src/main/java/io/github/ollama4j/examples/GenerateStreamingWithTokenConcatenation.java" />
|
||||
<CodeEmbed src="https://raw.githubusercontent.com/ollama4j/ollama4j-examples/refs/heads/main/src/main/java/io/github/ollama4j/examples/GenerateStreaming.java" />
|
||||
|
||||
You will get a response similar to:
|
||||
|
||||
|
18
pom.xml
18
pom.xml
@ -14,11 +14,12 @@
|
||||
|
||||
<properties>
|
||||
<maven.compiler.release>11</maven.compiler.release>
|
||||
<project.build.outputTimestamp>${git.commit.time}</project.build.outputTimestamp><!-- populated via git-commit-id-plugin -->
|
||||
<project.build.outputTimestamp>${git.commit.time}
|
||||
</project.build.outputTimestamp><!-- populated via git-commit-id-plugin -->
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
<maven-surefire-plugin.version>3.0.0-M5</maven-surefire-plugin.version>
|
||||
<maven-failsafe-plugin.version>3.0.0-M5</maven-failsafe-plugin.version>
|
||||
<lombok.version>1.18.30</lombok.version>
|
||||
<lombok.version>1.18.38</lombok.version>
|
||||
</properties>
|
||||
|
||||
<developers>
|
||||
@ -46,6 +47,19 @@
|
||||
|
||||
<build>
|
||||
<plugins>
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-compiler-plugin</artifactId>
|
||||
<configuration>
|
||||
<annotationProcessorPaths>
|
||||
<path>
|
||||
<groupId>org.projectlombok</groupId>
|
||||
<artifactId>lombok</artifactId>
|
||||
<version>${lombok.version}</version>
|
||||
</path>
|
||||
</annotationProcessorPaths>
|
||||
</configuration>
|
||||
</plugin>
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-source-plugin</artifactId>
|
||||
|
@ -22,6 +22,7 @@ import io.github.ollama4j.tools.*;
|
||||
import io.github.ollama4j.tools.annotations.OllamaToolService;
|
||||
import io.github.ollama4j.tools.annotations.ToolProperty;
|
||||
import io.github.ollama4j.tools.annotations.ToolSpec;
|
||||
import io.github.ollama4j.utils.Constants;
|
||||
import io.github.ollama4j.utils.Options;
|
||||
import io.github.ollama4j.utils.Utils;
|
||||
import lombok.Setter;
|
||||
@ -55,33 +56,54 @@ import java.util.stream.Collectors;
|
||||
public class OllamaAPI {
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(OllamaAPI.class);
|
||||
|
||||
private final String host;
|
||||
private Auth auth;
|
||||
private final ToolRegistry toolRegistry = new ToolRegistry();
|
||||
|
||||
/**
|
||||
* -- SETTER --
|
||||
* Set request timeout in seconds. Default is 3 seconds.
|
||||
* The request timeout in seconds for API calls.
|
||||
* <p>
|
||||
* Default is 10 seconds. This value determines how long the client will wait
|
||||
* for a response
|
||||
* from the Ollama server before timing out.
|
||||
*/
|
||||
@Setter
|
||||
private long requestTimeoutSeconds = 10;
|
||||
|
||||
/**
|
||||
* -- SETTER --
|
||||
* Set/unset logging of responses
|
||||
* Enables or disables verbose logging of responses.
|
||||
* <p>
|
||||
* If set to {@code true}, the API will log detailed information about requests
|
||||
* and responses.
|
||||
* Default is {@code true}.
|
||||
*/
|
||||
@Setter
|
||||
private boolean verbose = true;
|
||||
|
||||
/**
|
||||
* The maximum number of retries for tool calls during chat interactions.
|
||||
* <p>
|
||||
* This value controls how many times the API will attempt to call a tool in the
|
||||
* event of a failure.
|
||||
* Default is 3.
|
||||
*/
|
||||
@Setter
|
||||
private int maxChatToolCallRetries = 3;
|
||||
|
||||
private Auth auth;
|
||||
|
||||
/**
|
||||
* The number of retries to attempt when pulling a model from the Ollama server.
|
||||
* <p>
|
||||
* If set to 0, no retries will be performed. If greater than 0, the API will
|
||||
* retry pulling the model
|
||||
* up to the specified number of times in case of failure.
|
||||
* <p>
|
||||
* Default is 0 (no retries).
|
||||
*/
|
||||
@Setter
|
||||
@SuppressWarnings({"FieldMayBeFinal", "FieldCanBeLocal"})
|
||||
private int numberOfRetriesForModelPull = 0;
|
||||
|
||||
public void setNumberOfRetriesForModelPull(int numberOfRetriesForModelPull) {
|
||||
this.numberOfRetriesForModelPull = numberOfRetriesForModelPull;
|
||||
}
|
||||
|
||||
private final ToolRegistry toolRegistry = new ToolRegistry();
|
||||
|
||||
/**
|
||||
* Instantiates the Ollama API with default Ollama host:
|
||||
* <a href="http://localhost:11434">http://localhost:11434</a>
|
||||
@ -102,7 +124,7 @@ public class OllamaAPI {
|
||||
this.host = host;
|
||||
}
|
||||
if (this.verbose) {
|
||||
logger.info("Ollama API initialized with host: " + this.host);
|
||||
logger.info("Ollama API initialized with host: {}", this.host);
|
||||
}
|
||||
}
|
||||
|
||||
@ -135,14 +157,17 @@ public class OllamaAPI {
|
||||
public boolean ping() {
|
||||
String url = this.host + "/api/tags";
|
||||
HttpClient httpClient = HttpClient.newHttpClient();
|
||||
HttpRequest httpRequest = null;
|
||||
HttpRequest httpRequest;
|
||||
try {
|
||||
httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json")
|
||||
.header("Content-type", "application/json").GET().build();
|
||||
httpRequest = getRequestBuilderDefault(new URI(url))
|
||||
.header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
|
||||
.header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON)
|
||||
.GET()
|
||||
.build();
|
||||
} catch (URISyntaxException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
HttpResponse<String> response = null;
|
||||
HttpResponse<String> response;
|
||||
try {
|
||||
response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
|
||||
} catch (HttpConnectTimeoutException e) {
|
||||
@ -168,8 +193,10 @@ public class OllamaAPI {
|
||||
HttpClient httpClient = HttpClient.newHttpClient();
|
||||
HttpRequest httpRequest = null;
|
||||
try {
|
||||
httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json")
|
||||
.header("Content-type", "application/json").GET().build();
|
||||
httpRequest = getRequestBuilderDefault(new URI(url))
|
||||
.header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
|
||||
.header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON)
|
||||
.GET().build();
|
||||
} catch (URISyntaxException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
@ -196,8 +223,10 @@ public class OllamaAPI {
|
||||
public List<Model> listModels() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
|
||||
String url = this.host + "/api/tags";
|
||||
HttpClient httpClient = HttpClient.newHttpClient();
|
||||
HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json")
|
||||
.header("Content-type", "application/json").GET().build();
|
||||
HttpRequest httpRequest = getRequestBuilderDefault(new URI(url))
|
||||
.header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
|
||||
.header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON).GET()
|
||||
.build();
|
||||
HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
|
||||
int statusCode = response.statusCode();
|
||||
String responseString = response.body();
|
||||
@ -229,8 +258,10 @@ public class OllamaAPI {
|
||||
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
|
||||
String url = "https://ollama.com/library";
|
||||
HttpClient httpClient = HttpClient.newHttpClient();
|
||||
HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json")
|
||||
.header("Content-type", "application/json").GET().build();
|
||||
HttpRequest httpRequest = getRequestBuilderDefault(new URI(url))
|
||||
.header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
|
||||
.header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON).GET()
|
||||
.build();
|
||||
HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
|
||||
int statusCode = response.statusCode();
|
||||
String responseString = response.body();
|
||||
@ -296,8 +327,10 @@ public class OllamaAPI {
|
||||
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
|
||||
String url = String.format("https://ollama.com/library/%s/tags", libraryModel.getName());
|
||||
HttpClient httpClient = HttpClient.newHttpClient();
|
||||
HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json")
|
||||
.header("Content-type", "application/json").GET().build();
|
||||
HttpRequest httpRequest = getRequestBuilderDefault(new URI(url))
|
||||
.header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
|
||||
.header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON).GET()
|
||||
.build();
|
||||
HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
|
||||
int statusCode = response.statusCode();
|
||||
String responseString = response.body();
|
||||
@ -338,6 +371,14 @@ public class OllamaAPI {
|
||||
/**
|
||||
* Finds a specific model using model name and tag from Ollama library.
|
||||
* <p>
|
||||
* <b>Deprecated:</b> This method relies on the HTML structure of the Ollama
|
||||
* website,
|
||||
* which is subject to change at any time. As a result, it is difficult to keep
|
||||
* this API
|
||||
* method consistently updated and reliable. Therefore, this method is
|
||||
* deprecated and
|
||||
* may be removed in future releases.
|
||||
* <p>
|
||||
* This method retrieves the model from the Ollama library by its name, then
|
||||
* fetches its tags.
|
||||
* It searches through the tags of the model to find one that matches the
|
||||
@ -355,7 +396,11 @@ public class OllamaAPI {
|
||||
* @throws URISyntaxException If there is an error with the URI syntax.
|
||||
* @throws InterruptedException If the operation is interrupted.
|
||||
* @throws NoSuchElementException If the model or the tag is not found.
|
||||
* @deprecated This method relies on the HTML structure of the Ollama website,
|
||||
* which can change at any time and break this API. It is deprecated
|
||||
* and may be removed in the future.
|
||||
*/
|
||||
@Deprecated
|
||||
public LibraryModelTag findModelTagFromLibrary(String modelName, String tag)
|
||||
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
|
||||
List<LibraryModel> libraryModels = this.listModelsFromLibrary();
|
||||
@ -363,51 +408,81 @@ public class OllamaAPI {
|
||||
.findFirst().orElseThrow(
|
||||
() -> new NoSuchElementException(String.format("Model by name '%s' not found", modelName)));
|
||||
LibraryModelDetail libraryModelDetail = this.getLibraryModelDetails(libraryModel);
|
||||
LibraryModelTag libraryModelTag = libraryModelDetail.getTags().stream()
|
||||
.filter(tagName -> tagName.getTag().equals(tag)).findFirst()
|
||||
return libraryModelDetail.getTags().stream().filter(tagName -> tagName.getTag().equals(tag)).findFirst()
|
||||
.orElseThrow(() -> new NoSuchElementException(
|
||||
String.format("Tag '%s' for model '%s' not found", tag, modelName)));
|
||||
return libraryModelTag;
|
||||
}
|
||||
|
||||
/**
|
||||
* Pull a model on the Ollama server from the list of <a
|
||||
* href="https://ollama.ai/library">available models</a>.
|
||||
* <p>
|
||||
* If {@code numberOfRetriesForModelPull} is greater than 0, this method will
|
||||
* retry pulling the model
|
||||
* up to the specified number of times if an {@link OllamaBaseException} occurs,
|
||||
* using exponential backoff
|
||||
* between retries (delay doubles after each failed attempt, starting at 3
|
||||
* seconds).
|
||||
* <p>
|
||||
* The backoff is only applied between retries, not after the final attempt.
|
||||
*
|
||||
* @param modelName the name of the model
|
||||
* @throws OllamaBaseException if the response indicates an error status
|
||||
* @throws OllamaBaseException if the response indicates an error status or all
|
||||
* retries fail
|
||||
* @throws IOException if an I/O error occurs during the HTTP request
|
||||
* @throws InterruptedException if the operation is interrupted
|
||||
* @throws InterruptedException if the operation is interrupted or the thread is
|
||||
* interrupted during backoff
|
||||
* @throws URISyntaxException if the URI for the request is malformed
|
||||
*/
|
||||
public void pullModel(String modelName)
|
||||
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
|
||||
if (numberOfRetriesForModelPull == 0) {
|
||||
this.doPullModel(modelName);
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
int numberOfRetries = 0;
|
||||
long baseDelayMillis = 3000L; // 3 second base delay
|
||||
while (numberOfRetries < numberOfRetriesForModelPull) {
|
||||
try {
|
||||
this.doPullModel(modelName);
|
||||
return;
|
||||
} catch (OllamaBaseException e) {
|
||||
logger.error("Failed to pull model " + modelName + ", retrying...");
|
||||
handlePullRetry(modelName, numberOfRetries, numberOfRetriesForModelPull, baseDelayMillis);
|
||||
numberOfRetries++;
|
||||
}
|
||||
}
|
||||
throw new OllamaBaseException(
|
||||
"Failed to pull model " + modelName + " after " + numberOfRetriesForModelPull + " retries");
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles retry backoff for pullModel.
|
||||
*/
|
||||
private void handlePullRetry(String modelName, int currentRetry, int maxRetries, long baseDelayMillis)
|
||||
throws InterruptedException {
|
||||
int attempt = currentRetry + 1;
|
||||
if (attempt < maxRetries) {
|
||||
long backoffMillis = baseDelayMillis * (1L << currentRetry);
|
||||
logger.error("Failed to pull model {}, retrying in {}s... (attempt {}/{})",
|
||||
modelName, backoffMillis / 1000, attempt, maxRetries);
|
||||
try {
|
||||
Thread.sleep(backoffMillis);
|
||||
} catch (InterruptedException ie) {
|
||||
Thread.currentThread().interrupt();
|
||||
throw ie;
|
||||
}
|
||||
} else {
|
||||
logger.error("Failed to pull model {} after {} attempts, no more retries.", modelName, maxRetries);
|
||||
}
|
||||
}
|
||||
|
||||
private void doPullModel(String modelName)
|
||||
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
|
||||
String url = this.host + "/api/pull";
|
||||
String jsonData = new ModelRequest(modelName).toString();
|
||||
HttpRequest request = getRequestBuilderDefault(new URI(url))
|
||||
.POST(HttpRequest.BodyPublishers.ofString(jsonData))
|
||||
.header("Accept", "application/json")
|
||||
.header("Content-type", "application/json")
|
||||
HttpRequest request = getRequestBuilderDefault(new URI(url)).POST(HttpRequest.BodyPublishers.ofString(jsonData))
|
||||
.header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
|
||||
.header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON)
|
||||
.build();
|
||||
HttpClient client = HttpClient.newHttpClient();
|
||||
HttpResponse<InputStream> response = client.send(request, HttpResponse.BodyHandlers.ofInputStream());
|
||||
@ -428,7 +503,7 @@ public class OllamaAPI {
|
||||
|
||||
if (modelPullResponse.getStatus() != null) {
|
||||
if (verbose) {
|
||||
logger.info(modelName + ": " + modelPullResponse.getStatus());
|
||||
logger.info("{}: {}", modelName, modelPullResponse.getStatus());
|
||||
}
|
||||
// Check if status is "success" and set success flag to true.
|
||||
if ("success".equalsIgnoreCase(modelPullResponse.getStatus())) {
|
||||
@ -452,8 +527,10 @@ public class OllamaAPI {
|
||||
public String getVersion() throws URISyntaxException, IOException, InterruptedException, OllamaBaseException {
|
||||
String url = this.host + "/api/version";
|
||||
HttpClient httpClient = HttpClient.newHttpClient();
|
||||
HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json")
|
||||
.header("Content-type", "application/json").GET().build();
|
||||
HttpRequest httpRequest = getRequestBuilderDefault(new URI(url))
|
||||
.header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
|
||||
.header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON).GET()
|
||||
.build();
|
||||
HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
|
||||
int statusCode = response.statusCode();
|
||||
String responseString = response.body();
|
||||
@ -498,8 +575,10 @@ public class OllamaAPI {
|
||||
throws IOException, OllamaBaseException, InterruptedException, URISyntaxException {
|
||||
String url = this.host + "/api/show";
|
||||
String jsonData = new ModelRequest(modelName).toString();
|
||||
HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json")
|
||||
.header("Content-type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
|
||||
HttpRequest request = getRequestBuilderDefault(new URI(url))
|
||||
.header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
|
||||
.header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON)
|
||||
.POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
|
||||
HttpClient client = HttpClient.newHttpClient();
|
||||
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
|
||||
int statusCode = response.statusCode();
|
||||
@ -529,8 +608,9 @@ public class OllamaAPI {
|
||||
throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
|
||||
String url = this.host + "/api/create";
|
||||
String jsonData = new CustomModelFilePathRequest(modelName, modelFilePath).toString();
|
||||
HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json")
|
||||
.header("Content-Type", "application/json")
|
||||
HttpRequest request = getRequestBuilderDefault(new URI(url))
|
||||
.header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
|
||||
.header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON)
|
||||
.POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build();
|
||||
HttpClient client = HttpClient.newHttpClient();
|
||||
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
|
||||
@ -569,8 +649,9 @@ public class OllamaAPI {
|
||||
throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
|
||||
String url = this.host + "/api/create";
|
||||
String jsonData = new CustomModelFileContentsRequest(modelName, modelFileContents).toString();
|
||||
HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json")
|
||||
.header("Content-Type", "application/json")
|
||||
HttpRequest request = getRequestBuilderDefault(new URI(url))
|
||||
.header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
|
||||
.header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON)
|
||||
.POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build();
|
||||
HttpClient client = HttpClient.newHttpClient();
|
||||
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
|
||||
@ -602,8 +683,9 @@ public class OllamaAPI {
|
||||
throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
|
||||
String url = this.host + "/api/create";
|
||||
String jsonData = customModelRequest.toString();
|
||||
HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json")
|
||||
.header("Content-Type", "application/json")
|
||||
HttpRequest request = getRequestBuilderDefault(new URI(url))
|
||||
.header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
|
||||
.header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON)
|
||||
.POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build();
|
||||
HttpClient client = HttpClient.newHttpClient();
|
||||
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
|
||||
@ -637,7 +719,9 @@ public class OllamaAPI {
|
||||
String jsonData = new ModelRequest(modelName).toString();
|
||||
HttpRequest request = getRequestBuilderDefault(new URI(url))
|
||||
.method("DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8))
|
||||
.header("Accept", "application/json").header("Content-type", "application/json").build();
|
||||
.header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
|
||||
.header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON)
|
||||
.build();
|
||||
HttpClient client = HttpClient.newHttpClient();
|
||||
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
|
||||
int statusCode = response.statusCode();
|
||||
@ -683,7 +767,8 @@ public class OllamaAPI {
|
||||
URI uri = URI.create(this.host + "/api/embeddings");
|
||||
String jsonData = modelRequest.toString();
|
||||
HttpClient httpClient = HttpClient.newHttpClient();
|
||||
HttpRequest.Builder requestBuilder = getRequestBuilderDefault(uri).header("Accept", "application/json")
|
||||
HttpRequest.Builder requestBuilder = getRequestBuilderDefault(uri)
|
||||
.header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
|
||||
.POST(HttpRequest.BodyPublishers.ofString(jsonData));
|
||||
HttpRequest request = requestBuilder.build();
|
||||
HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
|
||||
@ -728,7 +813,8 @@ public class OllamaAPI {
|
||||
String jsonData = Utils.getObjectMapper().writeValueAsString(modelRequest);
|
||||
HttpClient httpClient = HttpClient.newHttpClient();
|
||||
|
||||
HttpRequest request = HttpRequest.newBuilder(uri).header("Accept", "application/json")
|
||||
HttpRequest request = HttpRequest.newBuilder(uri)
|
||||
.header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
|
||||
.POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
|
||||
|
||||
HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
|
||||
@ -744,17 +830,24 @@ public class OllamaAPI {
|
||||
|
||||
/**
|
||||
* Generate response for a question to a model running on Ollama server. This is
|
||||
* a sync/blocking
|
||||
* call.
|
||||
* a sync/blocking call. This API does not support "thinking" models.
|
||||
*
|
||||
* @param model the ollama model to ask the question to
|
||||
* @param prompt the prompt/question text
|
||||
* @param raw if true no formatting will be applied to the
|
||||
* prompt. You
|
||||
* may choose to use the raw parameter if you are
|
||||
* specifying a full templated prompt in your
|
||||
* request to
|
||||
* the API
|
||||
* @param options the Options object - <a
|
||||
* href=
|
||||
* "https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values">More
|
||||
* details on the options</a>
|
||||
* @param streamHandler optional callback consumer that will be applied every
|
||||
* time a streamed response is received. If not set, the
|
||||
* @param responseStreamHandler optional callback consumer that will be applied
|
||||
* every
|
||||
* time a streamed response is received. If not
|
||||
* set, the
|
||||
* stream parameter of the request is set to false.
|
||||
* @return OllamaResult that includes response text and time taken for response
|
||||
* @throws OllamaBaseException if the response indicates an error status
|
||||
@ -762,15 +855,87 @@ public class OllamaAPI {
|
||||
* @throws InterruptedException if the operation is interrupted
|
||||
*/
|
||||
public OllamaResult generate(String model, String prompt, boolean raw, Options options,
|
||||
OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
|
||||
OllamaStreamHandler responseStreamHandler) throws OllamaBaseException, IOException, InterruptedException {
|
||||
OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt);
|
||||
ollamaRequestModel.setRaw(raw);
|
||||
ollamaRequestModel.setThink(false);
|
||||
ollamaRequestModel.setOptions(options.getOptionsMap());
|
||||
return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler);
|
||||
return generateSyncForOllamaRequestModel(ollamaRequestModel, null, responseStreamHandler);
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate thinking and response tokens for a question to a thinking model
|
||||
* running on Ollama server. This is
|
||||
* a sync/blocking call.
|
||||
*
|
||||
* @param model the ollama model to ask the question to
|
||||
* @param prompt the prompt/question text
|
||||
* @param raw if true no formatting will be applied to the
|
||||
* prompt. You
|
||||
* may choose to use the raw parameter if you are
|
||||
* specifying a full templated prompt in your
|
||||
* request to
|
||||
* the API
|
||||
* @param options the Options object - <a
|
||||
* href=
|
||||
* "https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values">More
|
||||
* details on the options</a>
|
||||
* @param responseStreamHandler optional callback consumer that will be applied
|
||||
* every
|
||||
* time a streamed response is received. If not
|
||||
* set, the
|
||||
* stream parameter of the request is set to false.
|
||||
* @return OllamaResult that includes response text and time taken for response
|
||||
* @throws OllamaBaseException if the response indicates an error status
|
||||
* @throws IOException if an I/O error occurs during the HTTP request
|
||||
* @throws InterruptedException if the operation is interrupted
|
||||
*/
|
||||
public OllamaResult generate(String model, String prompt, boolean raw, Options options,
|
||||
OllamaStreamHandler thinkingStreamHandler, OllamaStreamHandler responseStreamHandler)
|
||||
throws OllamaBaseException, IOException, InterruptedException {
|
||||
OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt);
|
||||
ollamaRequestModel.setRaw(raw);
|
||||
ollamaRequestModel.setThink(true);
|
||||
ollamaRequestModel.setOptions(options.getOptionsMap());
|
||||
return generateSyncForOllamaRequestModel(ollamaRequestModel, thinkingStreamHandler, responseStreamHandler);
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates response using the specified AI model and prompt (in blocking
|
||||
* mode).
|
||||
* <p>
|
||||
* Uses
|
||||
* {@link #generate(String, String, boolean, Options, OllamaStreamHandler)}
|
||||
*
|
||||
* @param model The name or identifier of the AI model to use for generating
|
||||
* the response.
|
||||
* @param prompt The input text or prompt to provide to the AI model.
|
||||
* @param raw In some cases, you may wish to bypass the templating system
|
||||
* and provide a full prompt. In this case, you can use the raw
|
||||
* parameter to disable templating. Also note that raw mode will
|
||||
* not return a context.
|
||||
* @param options Additional options or configurations to use when generating
|
||||
* the response.
|
||||
* @param think if true the model will "think" step-by-step before
|
||||
* generating the final response
|
||||
* @return {@link OllamaResult}
|
||||
* @throws OllamaBaseException if the response indicates an error status
|
||||
* @throws IOException if an I/O error occurs during the HTTP request
|
||||
* @throws InterruptedException if the operation is interrupted
|
||||
*/
|
||||
public OllamaResult generate(String model, String prompt, boolean raw, boolean think, Options options)
|
||||
throws OllamaBaseException, IOException, InterruptedException {
|
||||
if (think) {
|
||||
return generate(model, prompt, raw, options, null, null);
|
||||
} else {
|
||||
return generate(model, prompt, raw, options, null);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates structured output from the specified AI model and prompt.
|
||||
* <p>
|
||||
* Note: When formatting is specified, the 'think' parameter is not allowed.
|
||||
*
|
||||
* @param model The name or identifier of the AI model to use for generating
|
||||
* the response.
|
||||
@ -783,6 +948,7 @@ public class OllamaAPI {
|
||||
* @throws IOException if an I/O error occurs during the HTTP request.
|
||||
* @throws InterruptedException if the operation is interrupted.
|
||||
*/
|
||||
@SuppressWarnings("LoggingSimilarMessage")
|
||||
public OllamaResult generate(String model, String prompt, Map<String, Object> format)
|
||||
throws OllamaBaseException, IOException, InterruptedException {
|
||||
URI uri = URI.create(this.host + "/api/generate");
|
||||
@ -797,51 +963,52 @@ public class OllamaAPI {
|
||||
HttpClient httpClient = HttpClient.newHttpClient();
|
||||
|
||||
HttpRequest request = getRequestBuilderDefault(uri)
|
||||
.header("Accept", "application/json")
|
||||
.header("Content-type", "application/json")
|
||||
.POST(HttpRequest.BodyPublishers.ofString(jsonData))
|
||||
.build();
|
||||
.header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
|
||||
.header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON)
|
||||
.POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
|
||||
|
||||
if (verbose) {
|
||||
try {
|
||||
String prettyJson = Utils.getObjectMapper().writerWithDefaultPrettyPrinter()
|
||||
.writeValueAsString(Utils.getObjectMapper().readValue(jsonData, Object.class));
|
||||
logger.info("Asking model:\n{}", prettyJson);
|
||||
} catch (Exception e) {
|
||||
logger.info("Asking model: {}", jsonData);
|
||||
}
|
||||
}
|
||||
HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
|
||||
int statusCode = response.statusCode();
|
||||
String responseBody = response.body();
|
||||
|
||||
if (statusCode == 200) {
|
||||
OllamaStructuredResult structuredResult = Utils.getObjectMapper().readValue(responseBody,
|
||||
OllamaStructuredResult.class);
|
||||
OllamaResult ollamaResult = new OllamaResult(structuredResult.getResponse(),
|
||||
OllamaResult ollamaResult = new OllamaResult(structuredResult.getResponse(), structuredResult.getThinking(),
|
||||
structuredResult.getResponseTime(), statusCode);
|
||||
|
||||
ollamaResult.setModel(structuredResult.getModel());
|
||||
ollamaResult.setCreatedAt(structuredResult.getCreatedAt());
|
||||
ollamaResult.setDone(structuredResult.isDone());
|
||||
ollamaResult.setDoneReason(structuredResult.getDoneReason());
|
||||
ollamaResult.setContext(structuredResult.getContext());
|
||||
ollamaResult.setTotalDuration(structuredResult.getTotalDuration());
|
||||
ollamaResult.setLoadDuration(structuredResult.getLoadDuration());
|
||||
ollamaResult.setPromptEvalCount(structuredResult.getPromptEvalCount());
|
||||
ollamaResult.setPromptEvalDuration(structuredResult.getPromptEvalDuration());
|
||||
ollamaResult.setEvalCount(structuredResult.getEvalCount());
|
||||
ollamaResult.setEvalDuration(structuredResult.getEvalDuration());
|
||||
if (verbose) {
|
||||
logger.info("Model response:\n{}", ollamaResult);
|
||||
}
|
||||
return ollamaResult;
|
||||
} else {
|
||||
if (verbose) {
|
||||
logger.info("Model response:\n{}",
|
||||
Utils.getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(responseBody));
|
||||
}
|
||||
throw new OllamaBaseException(statusCode + " - " + responseBody);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates response using the specified AI model and prompt (in blocking
|
||||
* mode).
|
||||
* <p>
|
||||
* Uses {@link #generate(String, String, boolean, Options, OllamaStreamHandler)}
|
||||
*
|
||||
* @param model The name or identifier of the AI model to use for generating
|
||||
* the response.
|
||||
* @param prompt The input text or prompt to provide to the AI model.
|
||||
* @param raw In some cases, you may wish to bypass the templating system
|
||||
* and provide a full prompt. In this case, you can use the raw
|
||||
* parameter to disable templating. Also note that raw mode will
|
||||
* not return a context.
|
||||
* @param options Additional options or configurations to use when generating
|
||||
* the response.
|
||||
* @return {@link OllamaResult}
|
||||
* @throws OllamaBaseException if the response indicates an error status
|
||||
* @throws IOException if an I/O error occurs during the HTTP request
|
||||
* @throws InterruptedException if the operation is interrupted
|
||||
*/
|
||||
public OllamaResult generate(String model, String prompt, boolean raw, Options options)
|
||||
throws OllamaBaseException, IOException, InterruptedException {
|
||||
return generate(model, prompt, raw, options, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates response using the specified AI model and prompt (in blocking
|
||||
* mode), and then invokes a set of tools
|
||||
@ -893,8 +1060,7 @@ public class OllamaAPI {
|
||||
logger.warn("Response from model does not contain any tool calls. Returning the response as is.");
|
||||
return toolResult;
|
||||
}
|
||||
toolFunctionCallSpecs = objectMapper.readValue(
|
||||
toolsResponse,
|
||||
toolFunctionCallSpecs = objectMapper.readValue(toolsResponse,
|
||||
objectMapper.getTypeFactory().constructCollectionType(List.class, ToolFunctionCallSpec.class));
|
||||
}
|
||||
for (ToolFunctionCallSpec toolFunctionCallSpec : toolFunctionCallSpecs) {
|
||||
@ -905,19 +1071,47 @@ public class OllamaAPI {
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate response for a question to a model running on Ollama server and get
|
||||
* a callback handle
|
||||
* that can be used to check for status and get the response from the model
|
||||
* later. This would be
|
||||
* an async/non-blocking call.
|
||||
* Asynchronously generates a response for a prompt using a model running on the
|
||||
* Ollama server.
|
||||
* <p>
|
||||
* This method returns an {@link OllamaAsyncResultStreamer} handle that can be
|
||||
* used to poll for
|
||||
* status and retrieve streamed "thinking" and response tokens from the model.
|
||||
* The call is non-blocking.
|
||||
* </p>
|
||||
*
|
||||
* @param model the ollama model to ask the question to
|
||||
* @param prompt the prompt/question text
|
||||
* @return the ollama async result callback handle
|
||||
* <p>
|
||||
* <b>Example usage:</b>
|
||||
* </p>
|
||||
*
|
||||
* <pre>{@code
|
||||
* OllamaAsyncResultStreamer resultStreamer = ollamaAPI.generateAsync("gpt-oss:20b", "Who are you", false, true);
|
||||
* int pollIntervalMilliseconds = 1000;
|
||||
* while (true) {
|
||||
* String thinkingTokens = resultStreamer.getThinkingResponseStream().poll();
|
||||
* String responseTokens = resultStreamer.getResponseStream().poll();
|
||||
* System.out.print(thinkingTokens != null ? thinkingTokens.toUpperCase() : "");
|
||||
* System.out.print(responseTokens != null ? responseTokens.toLowerCase() : "");
|
||||
* Thread.sleep(pollIntervalMilliseconds);
|
||||
* if (!resultStreamer.isAlive())
|
||||
* break;
|
||||
* }
|
||||
* System.out.println("Complete thinking response: " + resultStreamer.getCompleteThinkingResponse());
|
||||
* System.out.println("Complete response: " + resultStreamer.getCompleteResponse());
|
||||
* }</pre>
|
||||
*
|
||||
* @param model the Ollama model to use for generating the response
|
||||
* @param prompt the prompt or question text to send to the model
|
||||
* @param raw if {@code true}, returns the raw response from the model
|
||||
* @param think if {@code true}, streams "thinking" tokens as well as response
|
||||
* tokens
|
||||
* @return an {@link OllamaAsyncResultStreamer} handle for polling and
|
||||
* retrieving streamed results
|
||||
*/
|
||||
public OllamaAsyncResultStreamer generateAsync(String model, String prompt, boolean raw) {
|
||||
public OllamaAsyncResultStreamer generateAsync(String model, String prompt, boolean raw, boolean think) {
|
||||
OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt);
|
||||
ollamaRequestModel.setRaw(raw);
|
||||
ollamaRequestModel.setThink(think);
|
||||
URI uri = URI.create(this.host + "/api/generate");
|
||||
OllamaAsyncResultStreamer ollamaAsyncResultStreamer = new OllamaAsyncResultStreamer(
|
||||
getRequestBuilderDefault(uri), ollamaRequestModel, requestTimeoutSeconds);
|
||||
@ -953,7 +1147,7 @@ public class OllamaAPI {
|
||||
}
|
||||
OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt, images);
|
||||
ollamaRequestModel.setOptions(options.getOptionsMap());
|
||||
return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler);
|
||||
return generateSyncForOllamaRequestModel(ollamaRequestModel, null, streamHandler);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -1001,7 +1195,7 @@ public class OllamaAPI {
|
||||
}
|
||||
OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt, images);
|
||||
ollamaRequestModel.setOptions(options.getOptionsMap());
|
||||
return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler);
|
||||
return generateSyncForOllamaRequestModel(ollamaRequestModel, null, streamHandler);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -1023,38 +1217,47 @@ public class OllamaAPI {
|
||||
/**
|
||||
* Synchronously generates a response using a list of image byte arrays.
|
||||
* <p>
|
||||
* This method encodes the provided byte arrays into Base64 and sends them to the Ollama server.
|
||||
* This method encodes the provided byte arrays into Base64 and sends them to
|
||||
* the Ollama server.
|
||||
*
|
||||
* @param model the Ollama model to use for generating the response
|
||||
* @param prompt the prompt or question text to send to the model
|
||||
* @param images the list of image data as byte arrays
|
||||
* @param options the Options object - <a href="https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values">More details on the options</a>
|
||||
* @param streamHandler optional callback that will be invoked with each streamed response; if null, streaming is disabled
|
||||
* @return OllamaResult containing the response text and the time taken for the response
|
||||
* @param options the Options object - <a href=
|
||||
* "https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values">More
|
||||
* details on the options</a>
|
||||
* @param streamHandler optional callback that will be invoked with each
|
||||
* streamed response; if null, streaming is disabled
|
||||
* @return OllamaResult containing the response text and the time taken for the
|
||||
* response
|
||||
* @throws OllamaBaseException if the response indicates an error status
|
||||
* @throws IOException if an I/O error occurs during the HTTP request
|
||||
* @throws InterruptedException if the operation is interrupted
|
||||
*/
|
||||
public OllamaResult generateWithImages(String model, String prompt, List<byte[]> images, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
|
||||
public OllamaResult generateWithImages(String model, String prompt, List<byte[]> images, Options options,
|
||||
OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
|
||||
List<String> encodedImages = new ArrayList<>();
|
||||
for (byte[] image : images) {
|
||||
encodedImages.add(encodeByteArrayToBase64(image));
|
||||
}
|
||||
OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt, encodedImages);
|
||||
ollamaRequestModel.setOptions(options.getOptionsMap());
|
||||
return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler);
|
||||
return generateSyncForOllamaRequestModel(ollamaRequestModel, null, streamHandler);
|
||||
}
|
||||
|
||||
/**
|
||||
* Convenience method to call the Ollama API using image byte arrays without streaming responses.
|
||||
* Convenience method to call the Ollama API using image byte arrays without
|
||||
* streaming responses.
|
||||
* <p>
|
||||
* Uses {@link #generateWithImages(String, String, List, Options, OllamaStreamHandler)}
|
||||
* Uses
|
||||
* {@link #generateWithImages(String, String, List, Options, OllamaStreamHandler)}
|
||||
*
|
||||
* @throws OllamaBaseException if the response indicates an error status
|
||||
* @throws IOException if an I/O error occurs during the HTTP request
|
||||
* @throws InterruptedException if the operation is interrupted
|
||||
*/
|
||||
public OllamaResult generateWithImages(String model, String prompt, List<byte[]> images, Options options) throws OllamaBaseException, IOException, InterruptedException {
|
||||
public OllamaResult generateWithImages(String model, String prompt, List<byte[]> images, Options options)
|
||||
throws OllamaBaseException, IOException, InterruptedException {
|
||||
return generateWithImages(model, prompt, images, options, null);
|
||||
}
|
||||
|
||||
@ -1069,10 +1272,12 @@ public class OllamaAPI {
|
||||
* history including the newly acquired assistant response.
|
||||
* @throws OllamaBaseException if a response code other than 200 has been returned
|
||||
* @throws IOException in case the responseStream can not be read
|
||||
* @throws InterruptedException in case the server is not reachable or network
|
||||
* @throws InterruptedException in case the server is not reachable or
|
||||
* network
|
||||
* issues happen
|
||||
* @throws OllamaBaseException if the response indicates an error status
|
||||
* @throws IOException if an I/O error occurs during the HTTP request
|
||||
* @throws IOException if an I/O error occurs during the HTTP
|
||||
* request
|
||||
* @throws InterruptedException if the operation is interrupted
|
||||
* @throws ToolInvocationException if the tool invocation fails
|
||||
*/
|
||||
@ -1092,16 +1297,18 @@ public class OllamaAPI {
|
||||
* @return {@link OllamaChatResult}
|
||||
* @throws OllamaBaseException if a response code other than 200 has been returned
|
||||
* @throws IOException in case the responseStream can not be read
|
||||
* @throws InterruptedException in case the server is not reachable or network
|
||||
* @throws InterruptedException in case the server is not reachable or
|
||||
* network
|
||||
* issues happen
|
||||
* @throws OllamaBaseException if the response indicates an error status
|
||||
* @throws IOException if an I/O error occurs during the HTTP request
|
||||
* @throws IOException if an I/O error occurs during the HTTP
|
||||
* request
|
||||
* @throws InterruptedException if the operation is interrupted
|
||||
* @throws ToolInvocationException if the tool invocation fails
|
||||
*/
|
||||
public OllamaChatResult chat(OllamaChatRequest request)
|
||||
throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
|
||||
return chat(request, null);
|
||||
return chat(request, null, null);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -1111,22 +1318,26 @@ public class OllamaAPI {
|
||||
* Hint: the OllamaChatRequestModel#getStream() property is not implemented.
|
||||
*
|
||||
* @param request request object to be sent to the server
|
||||
* @param streamHandler callback handler to handle the last message from stream
|
||||
* (caution: all previous tokens from stream will be
|
||||
* concatenated)
|
||||
* @param responseStreamHandler callback handler to handle the last message from
|
||||
* stream
|
||||
* @param thinkingStreamHandler callback handler to handle the last thinking
|
||||
* message from stream
|
||||
* @return {@link OllamaChatResult}
|
||||
* @throws OllamaBaseException if a response code other than 200 has been returned
|
||||
* @throws IOException in case the responseStream can not be read
|
||||
* @throws InterruptedException in case the server is not reachable or network
|
||||
* @throws InterruptedException in case the server is not reachable or
|
||||
* network
|
||||
* issues happen
|
||||
* @throws OllamaBaseException if the response indicates an error status
|
||||
* @throws IOException if an I/O error occurs during the HTTP request
|
||||
* @throws IOException if an I/O error occurs during the HTTP
|
||||
* request
|
||||
* @throws InterruptedException if the operation is interrupted
|
||||
* @throws ToolInvocationException if the tool invocation fails
|
||||
*/
|
||||
public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler)
|
||||
public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler thinkingStreamHandler,
|
||||
OllamaStreamHandler responseStreamHandler)
|
||||
throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
|
||||
return chatStreaming(request, new OllamaChatStreamObserver(streamHandler));
|
||||
return chatStreaming(request, new OllamaChatStreamObserver(thinkingStreamHandler, responseStreamHandler));
|
||||
}
|
||||
|
||||
/**
|
||||
@ -1177,8 +1388,11 @@ public class OllamaAPI {
|
||||
}
|
||||
Map<String, Object> arguments = toolCall.getFunction().getArguments();
|
||||
Object res = toolFunction.apply(arguments);
|
||||
String argumentKeys = arguments.keySet().stream()
|
||||
.map(Object::toString)
|
||||
.collect(Collectors.joining(", "));
|
||||
request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL,
|
||||
"[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() + ") : " + res + "[/TOOL_RESULTS]"));
|
||||
"[TOOL_RESULTS] " + toolName + "(" + argumentKeys + "): " + res + " [/TOOL_RESULTS]"));
|
||||
}
|
||||
|
||||
if (tokenHandler != null) {
|
||||
@ -1224,6 +1438,17 @@ public class OllamaAPI {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Deregisters all tools from the tool registry.
|
||||
* This method removes all registered tools, effectively clearing the registry.
|
||||
*/
|
||||
public void deregisterTools() {
|
||||
toolRegistry.clear();
|
||||
if (this.verbose) {
|
||||
logger.debug("All tools have been deregistered.");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Registers tools based on the annotations found on the methods of the caller's
|
||||
* class and its providers.
|
||||
@ -1380,9 +1605,11 @@ public class OllamaAPI {
|
||||
* the request will be streamed; otherwise, a regular synchronous request will
|
||||
* be made.
|
||||
*
|
||||
* @param ollamaRequestModel the request model containing necessary parameters
|
||||
* @param ollamaRequestModel the request model containing necessary
|
||||
* parameters
|
||||
* for the Ollama API request.
|
||||
* @param streamHandler the stream handler to process streaming responses,
|
||||
* @param responseStreamHandler the stream handler to process streaming
|
||||
* responses,
|
||||
* or null for non-streaming requests.
|
||||
* @return the result of the Ollama API request.
|
||||
* @throws OllamaBaseException if the request fails due to an issue with the
|
||||
@ -1392,13 +1619,14 @@ public class OllamaAPI {
|
||||
* @throws InterruptedException if the thread is interrupted during the request.
|
||||
*/
|
||||
private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequest ollamaRequestModel,
|
||||
OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
|
||||
OllamaStreamHandler thinkingStreamHandler, OllamaStreamHandler responseStreamHandler)
|
||||
throws OllamaBaseException, IOException, InterruptedException {
|
||||
OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, auth, requestTimeoutSeconds,
|
||||
verbose);
|
||||
OllamaResult result;
|
||||
if (streamHandler != null) {
|
||||
if (responseStreamHandler != null) {
|
||||
ollamaRequestModel.setStream(true);
|
||||
result = requestCaller.call(ollamaRequestModel, streamHandler);
|
||||
result = requestCaller.call(ollamaRequestModel, thinkingStreamHandler, responseStreamHandler);
|
||||
} else {
|
||||
result = requestCaller.callSync(ollamaRequestModel);
|
||||
}
|
||||
@ -1412,7 +1640,8 @@ public class OllamaAPI {
|
||||
* @return HttpRequest.Builder
|
||||
*/
|
||||
private HttpRequest.Builder getRequestBuilderDefault(URI uri) {
|
||||
HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(uri).header("Content-Type", "application/json")
|
||||
HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(uri)
|
||||
.header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON)
|
||||
.timeout(Duration.ofSeconds(requestTimeoutSeconds));
|
||||
if (isBasicAuthCredentialsSet()) {
|
||||
requestBuilder.header("Authorization", auth.getAuthHeaderValue());
|
||||
|
@ -3,12 +3,8 @@ package io.github.ollama4j.impl;
|
||||
import io.github.ollama4j.models.generate.OllamaStreamHandler;
|
||||
|
||||
public class ConsoleOutputStreamHandler implements OllamaStreamHandler {
|
||||
private final StringBuffer response = new StringBuffer();
|
||||
|
||||
@Override
|
||||
public void accept(String message) {
|
||||
String substr = message.substring(response.length());
|
||||
response.append(substr);
|
||||
System.out.print(substr);
|
||||
System.out.print(message);
|
||||
}
|
||||
}
|
||||
|
@ -1,21 +1,15 @@
|
||||
package io.github.ollama4j.models.chat;
|
||||
|
||||
import static io.github.ollama4j.utils.Utils.getObjectMapper;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
|
||||
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
||||
|
||||
import io.github.ollama4j.utils.FileToBase64Serializer;
|
||||
import lombok.*;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import static io.github.ollama4j.utils.Utils.getObjectMapper;
|
||||
|
||||
/**
|
||||
* Defines a single Message to be used inside a chat request against the ollama /api/chat endpoint.
|
||||
@ -35,6 +29,8 @@ public class OllamaChatMessage {
|
||||
@NonNull
|
||||
private String content;
|
||||
|
||||
private String thinking;
|
||||
|
||||
private @JsonProperty("tool_calls") List<OllamaChatToolCalls> toolCalls;
|
||||
|
||||
@JsonSerialize(using = FileToBase64Serializer.class)
|
||||
|
@ -1,14 +1,13 @@
|
||||
package io.github.ollama4j.models.chat;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import io.github.ollama4j.models.request.OllamaCommonRequest;
|
||||
import io.github.ollama4j.tools.Tools;
|
||||
import io.github.ollama4j.utils.OllamaRequestBody;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Defines a Request to use against the ollama /api/chat endpoint.
|
||||
*
|
||||
@ -24,11 +23,15 @@ public class OllamaChatRequest extends OllamaCommonRequest implements OllamaRequ
|
||||
|
||||
private List<Tools.PromptFuncDefinition> tools;
|
||||
|
||||
public OllamaChatRequest() {}
|
||||
private boolean think;
|
||||
|
||||
public OllamaChatRequest(String model, List<OllamaChatMessage> messages) {
|
||||
public OllamaChatRequest() {
|
||||
}
|
||||
|
||||
public OllamaChatRequest(String model, boolean think, List<OllamaChatMessage> messages) {
|
||||
this.model = model;
|
||||
this.messages = messages;
|
||||
this.think = think;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -22,7 +22,7 @@ public class OllamaChatRequestBuilder {
|
||||
private static final Logger LOG = LoggerFactory.getLogger(OllamaChatRequestBuilder.class);
|
||||
|
||||
private OllamaChatRequestBuilder(String model, List<OllamaChatMessage> messages) {
    // Thinking is off by default; callers opt in through the builder.
    request = new OllamaChatRequest(model, false, messages);
}
|
||||
|
||||
private OllamaChatRequest request;
|
||||
@ -36,13 +36,19 @@ public class OllamaChatRequestBuilder {
|
||||
}
|
||||
|
||||
public void reset() {
|
||||
request = new OllamaChatRequest(request.getModel(), new ArrayList<>());
|
||||
request = new OllamaChatRequest(request.getModel(), request.isThink(), new ArrayList<>());
|
||||
}
|
||||
|
||||
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content) {
|
||||
return withMessage(role, content, Collections.emptyList());
|
||||
}
|
||||
|
||||
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List<OllamaChatToolCalls> toolCalls) {
|
||||
List<OllamaChatMessage> messages = this.request.getMessages();
|
||||
messages.add(new OllamaChatMessage(role, content, null, toolCalls, null));
|
||||
return this;
|
||||
}
|
||||
|
||||
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List<OllamaChatToolCalls> toolCalls, List<File> images) {
|
||||
List<OllamaChatMessage> messages = this.request.getMessages();
|
||||
|
||||
@ -55,7 +61,7 @@ public class OllamaChatRequestBuilder {
|
||||
}
|
||||
}).collect(Collectors.toList());
|
||||
|
||||
messages.add(new OllamaChatMessage(role, content,toolCalls, binaryImages));
|
||||
messages.add(new OllamaChatMessage(role, content, null, toolCalls, binaryImages));
|
||||
return this;
|
||||
}
|
||||
|
||||
@ -75,7 +81,7 @@ public class OllamaChatRequestBuilder {
|
||||
}
|
||||
}
|
||||
|
||||
messages.add(new OllamaChatMessage(role, content,toolCalls, binaryImages));
|
||||
messages.add(new OllamaChatMessage(role, content, null, toolCalls, binaryImages));
|
||||
return this;
|
||||
}
|
||||
|
||||
@ -108,4 +114,8 @@ public class OllamaChatRequestBuilder {
|
||||
return this;
|
||||
}
|
||||
|
||||
public OllamaChatRequestBuilder withThinking(boolean think) {
|
||||
this.request.setThink(think);
|
||||
return this;
|
||||
}
|
||||
}
|
||||
|
@ -1,10 +1,10 @@
|
||||
package io.github.ollama4j.models.chat;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import lombok.Getter;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import static io.github.ollama4j.utils.Utils.getObjectMapper;
|
||||
|
||||
/**
|
||||
|
@ -6,14 +6,46 @@ import lombok.RequiredArgsConstructor;
|
||||
|
||||
@RequiredArgsConstructor
|
||||
public class OllamaChatStreamObserver implements OllamaTokenHandler {
|
||||
private final OllamaStreamHandler streamHandler;
|
||||
private final OllamaStreamHandler thinkingStreamHandler;
|
||||
private final OllamaStreamHandler responseStreamHandler;
|
||||
|
||||
private String message = "";
|
||||
|
||||
@Override
|
||||
public void accept(OllamaChatResponseModel token) {
|
||||
if (streamHandler != null) {
|
||||
message += token.getMessage().getContent();
|
||||
streamHandler.accept(message);
|
||||
if (responseStreamHandler == null || token == null || token.getMessage() == null) {
|
||||
return;
|
||||
}
|
||||
|
||||
String thinking = token.getMessage().getThinking();
|
||||
String content = token.getMessage().getContent();
|
||||
|
||||
boolean hasThinking = thinking != null && !thinking.isEmpty();
|
||||
boolean hasContent = !content.isEmpty();
|
||||
|
||||
// if (hasThinking && !hasContent) {
|
||||
//// message += thinking;
|
||||
// message = thinking;
|
||||
// } else {
|
||||
//// message += content;
|
||||
// message = content;
|
||||
// }
|
||||
//
|
||||
// responseStreamHandler.accept(message);
|
||||
|
||||
|
||||
if (!hasContent && hasThinking && thinkingStreamHandler != null) {
|
||||
// message = message + thinking;
|
||||
|
||||
// use only new tokens received, instead of appending the tokens to the previous
|
||||
// ones and sending the full string again
|
||||
thinkingStreamHandler.accept(thinking);
|
||||
} else if (hasContent && responseStreamHandler != null) {
|
||||
// message = message + response;
|
||||
|
||||
// use only new tokens received, instead of appending the tokens to the previous
|
||||
// ones and sending the full string again
|
||||
responseStreamHandler.accept(content);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,9 +1,9 @@
|
||||
package io.github.ollama4j.models.embeddings;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
import lombok.Data;
|
||||
|
||||
@SuppressWarnings("unused")
|
||||
@Data
|
||||
|
@ -1,7 +1,5 @@
|
||||
package io.github.ollama4j.models.embeddings;
|
||||
|
||||
import static io.github.ollama4j.utils.Utils.getObjectMapper;
|
||||
import java.util.Map;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import lombok.Data;
|
||||
@ -9,6 +7,10 @@ import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
import static io.github.ollama4j.utils.Utils.getObjectMapper;
|
||||
|
||||
@Data
|
||||
@RequiredArgsConstructor
|
||||
@NoArgsConstructor
|
||||
|
@ -3,12 +3,11 @@ package io.github.ollama4j.models.generate;
|
||||
|
||||
import io.github.ollama4j.models.request.OllamaCommonRequest;
|
||||
import io.github.ollama4j.utils.OllamaRequestBody;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
public class OllamaGenerateRequest extends OllamaCommonRequest implements OllamaRequestBody{
|
||||
@ -19,6 +18,7 @@ public class OllamaGenerateRequest extends OllamaCommonRequest implements Ollama
|
||||
private String system;
|
||||
private String context;
|
||||
private boolean raw;
|
||||
private boolean think;
|
||||
|
||||
public OllamaGenerateRequest() {
|
||||
}
|
||||
|
@ -2,9 +2,9 @@ package io.github.ollama4j.models.generate;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
@JsonIgnoreProperties(ignoreUnknown = true)
|
||||
@ -12,12 +12,14 @@ public class OllamaGenerateResponseModel {
|
||||
private String model;
|
||||
private @JsonProperty("created_at") String createdAt;
|
||||
private String response;
|
||||
private String thinking;
|
||||
private boolean done;
|
||||
private @JsonProperty("done_reason") String doneReason;
|
||||
private List<Integer> context;
|
||||
private @JsonProperty("total_duration") Long totalDuration;
|
||||
private @JsonProperty("load_duration") Long loadDuration;
|
||||
private @JsonProperty("prompt_eval_duration") Long promptEvalDuration;
|
||||
private @JsonProperty("eval_duration") Long evalDuration;
|
||||
private @JsonProperty("prompt_eval_count") Integer promptEvalCount;
|
||||
private @JsonProperty("prompt_eval_duration") Long promptEvalDuration;
|
||||
private @JsonProperty("eval_count") Integer evalCount;
|
||||
private @JsonProperty("eval_duration") Long evalDuration;
|
||||
}
|
||||
|
@ -5,14 +5,16 @@ import java.util.List;
|
||||
|
||||
public class OllamaGenerateStreamObserver {
|
||||
|
||||
private OllamaStreamHandler streamHandler;
|
||||
private final OllamaStreamHandler thinkingStreamHandler;
|
||||
private final OllamaStreamHandler responseStreamHandler;
|
||||
|
||||
private List<OllamaGenerateResponseModel> responseParts = new ArrayList<>();
|
||||
private final List<OllamaGenerateResponseModel> responseParts = new ArrayList<>();
|
||||
|
||||
private String message = "";
|
||||
|
||||
public OllamaGenerateStreamObserver(OllamaStreamHandler streamHandler) {
|
||||
this.streamHandler = streamHandler;
|
||||
public OllamaGenerateStreamObserver(OllamaStreamHandler thinkingStreamHandler, OllamaStreamHandler responseStreamHandler) {
|
||||
this.responseStreamHandler = responseStreamHandler;
|
||||
this.thinkingStreamHandler = thinkingStreamHandler;
|
||||
}
|
||||
|
||||
public void notify(OllamaGenerateResponseModel currentResponsePart) {
|
||||
@ -21,9 +23,24 @@ public class OllamaGenerateStreamObserver {
|
||||
}
|
||||
|
||||
protected void handleCurrentResponsePart(OllamaGenerateResponseModel currentResponsePart) {
|
||||
message = message + currentResponsePart.getResponse();
|
||||
streamHandler.accept(message);
|
||||
}
|
||||
String response = currentResponsePart.getResponse();
|
||||
String thinking = currentResponsePart.getThinking();
|
||||
|
||||
boolean hasResponse = response != null && !response.isEmpty();
|
||||
boolean hasThinking = thinking != null && !thinking.isEmpty();
|
||||
|
||||
if (!hasResponse && hasThinking && thinkingStreamHandler != null) {
|
||||
// message = message + thinking;
|
||||
|
||||
// use only new tokens received, instead of appending the tokens to the previous
|
||||
// ones and sending the full string again
|
||||
thinkingStreamHandler.accept(thinking);
|
||||
} else if (hasResponse && responseStreamHandler != null) {
|
||||
// message = message + response;
|
||||
|
||||
// use only new tokens received, instead of appending the tokens to the previous
|
||||
// ones and sending the full string again
|
||||
responseStreamHandler.accept(response);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,13 +1,14 @@
|
||||
package io.github.ollama4j.models.request;
|
||||
|
||||
import java.util.Base64;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.EqualsAndHashCode;
|
||||
|
||||
import java.util.Base64;
|
||||
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
public class BasicAuth extends Auth {
|
||||
private String username;
|
||||
private String password;
|
||||
|
@ -2,9 +2,11 @@ package io.github.ollama4j.models.request;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
public class BearerAuth extends Auth {
|
||||
private String bearerToken;
|
||||
|
||||
|
@ -1,11 +1,11 @@
|
||||
package io.github.ollama4j.models.request;
|
||||
|
||||
import static io.github.ollama4j.utils.Utils.getObjectMapper;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
|
||||
import static io.github.ollama4j.utils.Utils.getObjectMapper;
|
||||
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
public class CustomModelFileContentsRequest {
|
||||
|
@ -1,11 +1,11 @@
|
||||
package io.github.ollama4j.models.request;
|
||||
|
||||
import static io.github.ollama4j.utils.Utils.getObjectMapper;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
|
||||
import static io.github.ollama4j.utils.Utils.getObjectMapper;
|
||||
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
public class CustomModelFilePathRequest {
|
||||
|
@ -1,17 +1,15 @@
|
||||
package io.github.ollama4j.models.request;
|
||||
|
||||
import static io.github.ollama4j.utils.Utils.getObjectMapper;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.Data;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static io.github.ollama4j.utils.Utils.getObjectMapper;
|
||||
|
||||
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
|
@ -1,11 +1,11 @@
|
||||
package io.github.ollama4j.models.request;
|
||||
|
||||
import static io.github.ollama4j.utils.Utils.getObjectMapper;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
|
||||
import static io.github.ollama4j.utils.Utils.getObjectMapper;
|
||||
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
public class ModelRequest {
|
||||
|
@ -24,6 +24,7 @@ import java.util.List;
|
||||
/**
|
||||
* Specialization class for requests
|
||||
*/
|
||||
@SuppressWarnings("resource")
|
||||
public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
|
||||
|
||||
private static final Logger LOG = LoggerFactory.getLogger(OllamaChatEndpointCaller.class);
|
||||
@ -51,14 +52,19 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
|
||||
* @return true if the Ollama response is in the 'done' state
|
||||
*/
|
||||
@Override
|
||||
protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) {
|
||||
protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer, StringBuilder thinkingBuffer) {
|
||||
try {
|
||||
OllamaChatResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class);
|
||||
// it seems that under heavy load ollama responds with an empty chat message part in the streamed response
|
||||
// thus, we null check the message and hope that the next streamed response has some message content again
|
||||
OllamaChatMessage message = ollamaResponseModel.getMessage();
|
||||
if (message != null) {
|
||||
if (message.getThinking() != null) {
|
||||
thinkingBuffer.append(message.getThinking());
|
||||
}
|
||||
else {
|
||||
responseBuffer.append(message.getContent());
|
||||
}
|
||||
if (tokenHandler != null) {
|
||||
tokenHandler.accept(ollamaResponseModel);
|
||||
}
|
||||
@ -85,13 +91,14 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
|
||||
.POST(
|
||||
body.getBodyPublisher());
|
||||
HttpRequest request = requestBuilder.build();
|
||||
if (isVerbose()) LOG.info("Asking model: " + body);
|
||||
if (isVerbose()) LOG.info("Asking model: {}", body);
|
||||
HttpResponse<InputStream> response =
|
||||
httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
|
||||
|
||||
int statusCode = response.statusCode();
|
||||
InputStream responseBodyStream = response.body();
|
||||
StringBuilder responseBuffer = new StringBuilder();
|
||||
StringBuilder thinkingBuffer = new StringBuilder();
|
||||
OllamaChatResponseModel ollamaChatResponseModel = null;
|
||||
List<OllamaChatToolCalls> wantedToolsForStream = null;
|
||||
try (BufferedReader reader =
|
||||
@ -115,14 +122,20 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
|
||||
OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line,
|
||||
OllamaErrorResponse.class);
|
||||
responseBuffer.append(ollamaResponseModel.getError());
|
||||
} else if (statusCode == 500) {
|
||||
LOG.warn("Status code: 500 (Internal Server Error)");
|
||||
OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line,
|
||||
OllamaErrorResponse.class);
|
||||
responseBuffer.append(ollamaResponseModel.getError());
|
||||
} else {
|
||||
boolean finished = parseResponseAndAddToBuffer(line, responseBuffer);
|
||||
boolean finished = parseResponseAndAddToBuffer(line, responseBuffer, thinkingBuffer);
|
||||
ollamaChatResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class);
|
||||
if (body.stream && ollamaChatResponseModel.getMessage().getToolCalls() != null) {
|
||||
wantedToolsForStream = ollamaChatResponseModel.getMessage().getToolCalls();
|
||||
}
|
||||
if (finished && body.stream) {
|
||||
ollamaChatResponseModel.getMessage().setContent(responseBuffer.toString());
|
||||
ollamaChatResponseModel.getMessage().setThinking(thinkingBuffer.toString());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@ -1,15 +1,15 @@
|
||||
package io.github.ollama4j.models.request;
|
||||
|
||||
import java.util.Map;
|
||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
|
||||
|
||||
import io.github.ollama4j.utils.BooleanToJsonFormatFlagSerializer;
|
||||
import io.github.ollama4j.utils.Utils;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
@Data
|
||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||
public abstract class OllamaCommonRequest {
|
||||
|
@ -1,15 +1,15 @@
|
||||
package io.github.ollama4j.models.request;
|
||||
|
||||
import io.github.ollama4j.OllamaAPI;
|
||||
import io.github.ollama4j.utils.Constants;
|
||||
import lombok.Getter;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.net.URI;
|
||||
import java.net.http.HttpRequest;
|
||||
import java.time.Duration;
|
||||
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import io.github.ollama4j.OllamaAPI;
|
||||
import lombok.Getter;
|
||||
|
||||
/**
|
||||
* Abstract helper class for calling the Ollama API server.
|
||||
*/
|
||||
@ -32,7 +32,7 @@ public abstract class OllamaEndpointCaller {
|
||||
|
||||
protected abstract String getEndpointSuffix();
|
||||
|
||||
protected abstract boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer);
|
||||
protected abstract boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer, StringBuilder thinkingBuffer);
|
||||
|
||||
|
||||
/**
|
||||
@ -44,7 +44,7 @@ public abstract class OllamaEndpointCaller {
|
||||
protected HttpRequest.Builder getRequestBuilderDefault(URI uri) {
|
||||
HttpRequest.Builder requestBuilder =
|
||||
HttpRequest.newBuilder(uri)
|
||||
.header("Content-Type", "application/json")
|
||||
.header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON)
|
||||
.timeout(Duration.ofSeconds(this.requestTimeoutSeconds));
|
||||
if (isAuthCredentialsSet()) {
|
||||
requestBuilder.header("Authorization", this.auth.getAuthHeaderValue());
|
||||
|
@ -2,11 +2,11 @@ package io.github.ollama4j.models.request;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import io.github.ollama4j.exceptions.OllamaBaseException;
|
||||
import io.github.ollama4j.models.response.OllamaErrorResponse;
|
||||
import io.github.ollama4j.models.response.OllamaResult;
|
||||
import io.github.ollama4j.models.generate.OllamaGenerateResponseModel;
|
||||
import io.github.ollama4j.models.generate.OllamaGenerateStreamObserver;
|
||||
import io.github.ollama4j.models.generate.OllamaStreamHandler;
|
||||
import io.github.ollama4j.models.response.OllamaErrorResponse;
|
||||
import io.github.ollama4j.models.response.OllamaResult;
|
||||
import io.github.ollama4j.utils.OllamaRequestBody;
|
||||
import io.github.ollama4j.utils.Utils;
|
||||
import org.slf4j.Logger;
|
||||
@ -22,11 +22,12 @@ import java.net.http.HttpRequest;
|
||||
import java.net.http.HttpResponse;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
|
||||
@SuppressWarnings("resource")
|
||||
public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
|
||||
|
||||
private static final Logger LOG = LoggerFactory.getLogger(OllamaGenerateEndpointCaller.class);
|
||||
|
||||
private OllamaGenerateStreamObserver streamObserver;
|
||||
private OllamaGenerateStreamObserver responseStreamObserver;
|
||||
|
||||
public OllamaGenerateEndpointCaller(String host, Auth basicAuth, long requestTimeoutSeconds, boolean verbose) {
|
||||
super(host, basicAuth, requestTimeoutSeconds, verbose);
|
||||
@ -38,12 +39,17 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) {
|
||||
protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer, StringBuilder thinkingBuffer) {
|
||||
try {
|
||||
OllamaGenerateResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class);
|
||||
if (ollamaResponseModel.getResponse() != null) {
|
||||
responseBuffer.append(ollamaResponseModel.getResponse());
|
||||
if (streamObserver != null) {
|
||||
streamObserver.notify(ollamaResponseModel);
|
||||
}
|
||||
if (ollamaResponseModel.getThinking() != null) {
|
||||
thinkingBuffer.append(ollamaResponseModel.getThinking());
|
||||
}
|
||||
if (responseStreamObserver != null) {
|
||||
responseStreamObserver.notify(ollamaResponseModel);
|
||||
}
|
||||
return ollamaResponseModel.isDone();
|
||||
} catch (JsonProcessingException e) {
|
||||
@ -52,9 +58,8 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
|
||||
}
|
||||
}
|
||||
|
||||
public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler)
|
||||
throws OllamaBaseException, IOException, InterruptedException {
|
||||
streamObserver = new OllamaGenerateStreamObserver(streamHandler);
|
||||
public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler thinkingStreamHandler, OllamaStreamHandler responseStreamHandler) throws OllamaBaseException, IOException, InterruptedException {
|
||||
responseStreamObserver = new OllamaGenerateStreamObserver(thinkingStreamHandler, responseStreamHandler);
|
||||
return callSync(body);
|
||||
}
|
||||
|
||||
@ -67,46 +72,41 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
|
||||
* @throws IOException in case the responseStream can not be read
|
||||
* @throws InterruptedException in case the calling thread is interrupted while waiting for the server's response
|
||||
*/
|
||||
@SuppressWarnings("DuplicatedCode")
|
||||
public OllamaResult callSync(OllamaRequestBody body) throws OllamaBaseException, IOException, InterruptedException {
|
||||
// Create Request
|
||||
long startTime = System.currentTimeMillis();
|
||||
HttpClient httpClient = HttpClient.newHttpClient();
|
||||
URI uri = URI.create(getHost() + getEndpointSuffix());
|
||||
HttpRequest.Builder requestBuilder =
|
||||
getRequestBuilderDefault(uri)
|
||||
.POST(
|
||||
body.getBodyPublisher());
|
||||
HttpRequest.Builder requestBuilder = getRequestBuilderDefault(uri).POST(body.getBodyPublisher());
|
||||
HttpRequest request = requestBuilder.build();
|
||||
if (isVerbose()) LOG.info("Asking model: " + body.toString());
|
||||
HttpResponse<InputStream> response =
|
||||
httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
|
||||
if (isVerbose()) LOG.info("Asking model: {}", body);
|
||||
HttpResponse<InputStream> response = httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
|
||||
|
||||
int statusCode = response.statusCode();
|
||||
InputStream responseBodyStream = response.body();
|
||||
StringBuilder responseBuffer = new StringBuilder();
|
||||
try (BufferedReader reader =
|
||||
new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
|
||||
StringBuilder thinkingBuffer = new StringBuilder();
|
||||
OllamaGenerateResponseModel ollamaGenerateResponseModel = null;
|
||||
try (BufferedReader reader = new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
|
||||
String line;
|
||||
while ((line = reader.readLine()) != null) {
|
||||
if (statusCode == 404) {
|
||||
LOG.warn("Status code: 404 (Not Found)");
|
||||
OllamaErrorResponse ollamaResponseModel =
|
||||
Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class);
|
||||
OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class);
|
||||
responseBuffer.append(ollamaResponseModel.getError());
|
||||
} else if (statusCode == 401) {
|
||||
LOG.warn("Status code: 401 (Unauthorized)");
|
||||
OllamaErrorResponse ollamaResponseModel =
|
||||
Utils.getObjectMapper()
|
||||
.readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponse.class);
|
||||
OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponse.class);
|
||||
responseBuffer.append(ollamaResponseModel.getError());
|
||||
} else if (statusCode == 400) {
|
||||
LOG.warn("Status code: 400 (Bad Request)");
|
||||
OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line,
|
||||
OllamaErrorResponse.class);
|
||||
OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class);
|
||||
responseBuffer.append(ollamaResponseModel.getError());
|
||||
} else {
|
||||
boolean finished = parseResponseAndAddToBuffer(line, responseBuffer);
|
||||
boolean finished = parseResponseAndAddToBuffer(line, responseBuffer, thinkingBuffer);
|
||||
if (finished) {
|
||||
ollamaGenerateResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -114,13 +114,25 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
|
||||
}
|
||||
|
||||
if (statusCode != 200) {
|
||||
LOG.error("Status code " + statusCode);
|
||||
LOG.error("Status code: {}", statusCode);
|
||||
throw new OllamaBaseException(responseBuffer.toString());
|
||||
} else {
|
||||
long endTime = System.currentTimeMillis();
|
||||
OllamaResult ollamaResult =
|
||||
new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode);
|
||||
if (isVerbose()) LOG.info("Model response: " + ollamaResult);
|
||||
OllamaResult ollamaResult = new OllamaResult(responseBuffer.toString(), thinkingBuffer.toString(), endTime - startTime, statusCode);
|
||||
|
||||
ollamaResult.setModel(ollamaGenerateResponseModel.getModel());
|
||||
ollamaResult.setCreatedAt(ollamaGenerateResponseModel.getCreatedAt());
|
||||
ollamaResult.setDone(ollamaGenerateResponseModel.isDone());
|
||||
ollamaResult.setDoneReason(ollamaGenerateResponseModel.getDoneReason());
|
||||
ollamaResult.setContext(ollamaGenerateResponseModel.getContext());
|
||||
ollamaResult.setTotalDuration(ollamaGenerateResponseModel.getTotalDuration());
|
||||
ollamaResult.setLoadDuration(ollamaGenerateResponseModel.getLoadDuration());
|
||||
ollamaResult.setPromptEvalCount(ollamaGenerateResponseModel.getPromptEvalCount());
|
||||
ollamaResult.setPromptEvalDuration(ollamaGenerateResponseModel.getPromptEvalDuration());
|
||||
ollamaResult.setEvalCount(ollamaGenerateResponseModel.getEvalCount());
|
||||
ollamaResult.setEvalDuration(ollamaGenerateResponseModel.getEvalDuration());
|
||||
|
||||
if (isVerbose()) LOG.info("Model response: {}", ollamaResult);
|
||||
return ollamaResult;
|
||||
}
|
||||
}
|
||||
|
@ -1,9 +1,10 @@
|
||||
package io.github.ollama4j.models.response;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class LibraryModel {
|
||||
|
||||
|
@ -2,8 +2,6 @@ package io.github.ollama4j.models.response;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class LibraryModelTag {
|
||||
private String name;
|
||||
|
@ -1,9 +1,9 @@
|
||||
package io.github.ollama4j.models.response;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class ListModelsResponse {
|
||||
private List<Model> models;
|
||||
|
@ -1,13 +1,13 @@
|
||||
package io.github.ollama4j.models.response;
|
||||
|
||||
import java.time.OffsetDateTime;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import io.github.ollama4j.utils.Utils;
|
||||
import lombok.Data;
|
||||
|
||||
import java.time.OffsetDateTime;
|
||||
|
||||
@Data
|
||||
@JsonIgnoreProperties(ignoreUnknown = true)
|
||||
public class Model {
|
||||
|
@ -3,6 +3,7 @@ package io.github.ollama4j.models.response;
|
||||
import io.github.ollama4j.exceptions.OllamaBaseException;
|
||||
import io.github.ollama4j.models.generate.OllamaGenerateRequest;
|
||||
import io.github.ollama4j.models.generate.OllamaGenerateResponseModel;
|
||||
import io.github.ollama4j.utils.Constants;
|
||||
import io.github.ollama4j.utils.Utils;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
@ -25,8 +26,10 @@ import java.time.Duration;
|
||||
public class OllamaAsyncResultStreamer extends Thread {
|
||||
private final HttpRequest.Builder requestBuilder;
|
||||
private final OllamaGenerateRequest ollamaRequestModel;
|
||||
private final OllamaResultStream stream = new OllamaResultStream();
|
||||
private final OllamaResultStream thinkingResponseStream = new OllamaResultStream();
|
||||
private final OllamaResultStream responseStream = new OllamaResultStream();
|
||||
private String completeResponse;
|
||||
private String completeThinkingResponse;
|
||||
|
||||
|
||||
/**
|
||||
@ -53,14 +56,11 @@ public class OllamaAsyncResultStreamer extends Thread {
|
||||
@Getter
|
||||
private long responseTime = 0;
|
||||
|
||||
public OllamaAsyncResultStreamer(
|
||||
HttpRequest.Builder requestBuilder,
|
||||
OllamaGenerateRequest ollamaRequestModel,
|
||||
long requestTimeoutSeconds) {
|
||||
public OllamaAsyncResultStreamer(HttpRequest.Builder requestBuilder, OllamaGenerateRequest ollamaRequestModel, long requestTimeoutSeconds) {
|
||||
this.requestBuilder = requestBuilder;
|
||||
this.ollamaRequestModel = ollamaRequestModel;
|
||||
this.completeResponse = "";
|
||||
this.stream.add("");
|
||||
this.responseStream.add("");
|
||||
this.requestTimeoutSeconds = requestTimeoutSeconds;
|
||||
}
|
||||
|
||||
@ -68,47 +68,63 @@ public class OllamaAsyncResultStreamer extends Thread {
|
||||
public void run() {
|
||||
ollamaRequestModel.setStream(true);
|
||||
HttpClient httpClient = HttpClient.newHttpClient();
|
||||
try {
|
||||
long startTime = System.currentTimeMillis();
|
||||
HttpRequest request =
|
||||
requestBuilder
|
||||
.POST(
|
||||
HttpRequest.BodyPublishers.ofString(
|
||||
Utils.getObjectMapper().writeValueAsString(ollamaRequestModel)))
|
||||
.header("Content-Type", "application/json")
|
||||
.timeout(Duration.ofSeconds(requestTimeoutSeconds))
|
||||
.build();
|
||||
HttpResponse<InputStream> response =
|
||||
httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
|
||||
try {
|
||||
HttpRequest request = requestBuilder.POST(HttpRequest.BodyPublishers.ofString(Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))).header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON).timeout(Duration.ofSeconds(requestTimeoutSeconds)).build();
|
||||
HttpResponse<InputStream> response = httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
|
||||
int statusCode = response.statusCode();
|
||||
this.httpStatusCode = statusCode;
|
||||
|
||||
InputStream responseBodyStream = response.body();
|
||||
try (BufferedReader reader =
|
||||
new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
|
||||
BufferedReader reader = null;
|
||||
try {
|
||||
reader = new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8));
|
||||
String line;
|
||||
StringBuilder thinkingBuffer = new StringBuilder();
|
||||
StringBuilder responseBuffer = new StringBuilder();
|
||||
while ((line = reader.readLine()) != null) {
|
||||
if (statusCode == 404) {
|
||||
OllamaErrorResponse ollamaResponseModel =
|
||||
Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class);
|
||||
stream.add(ollamaResponseModel.getError());
|
||||
OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class);
|
||||
responseStream.add(ollamaResponseModel.getError());
|
||||
responseBuffer.append(ollamaResponseModel.getError());
|
||||
} else {
|
||||
OllamaGenerateResponseModel ollamaResponseModel =
|
||||
Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class);
|
||||
String res = ollamaResponseModel.getResponse();
|
||||
stream.add(res);
|
||||
OllamaGenerateResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class);
|
||||
String thinkingTokens = ollamaResponseModel.getThinking();
|
||||
String responseTokens = ollamaResponseModel.getResponse();
|
||||
if (thinkingTokens == null) {
|
||||
thinkingTokens = "";
|
||||
}
|
||||
if (responseTokens == null) {
|
||||
responseTokens = "";
|
||||
}
|
||||
thinkingResponseStream.add(thinkingTokens);
|
||||
responseStream.add(responseTokens);
|
||||
if (!ollamaResponseModel.isDone()) {
|
||||
responseBuffer.append(res);
|
||||
responseBuffer.append(responseTokens);
|
||||
thinkingBuffer.append(thinkingTokens);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
this.succeeded = true;
|
||||
this.completeThinkingResponse = thinkingBuffer.toString();
|
||||
this.completeResponse = responseBuffer.toString();
|
||||
long endTime = System.currentTimeMillis();
|
||||
responseTime = endTime - startTime;
|
||||
} finally {
|
||||
if (reader != null) {
|
||||
try {
|
||||
reader.close();
|
||||
} catch (IOException e) {
|
||||
// Optionally log or handle
|
||||
}
|
||||
}
|
||||
if (responseBodyStream != null) {
|
||||
try {
|
||||
responseBodyStream.close();
|
||||
} catch (IOException e) {
|
||||
// Optionally log or handle
|
||||
}
|
||||
}
|
||||
}
|
||||
if (statusCode != 200) {
|
||||
throw new OllamaBaseException(this.completeResponse);
|
||||
|
@ -3,47 +3,55 @@ package io.github.ollama4j.models.response;
|
||||
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.core.type.TypeReference;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.Getter;
|
||||
|
||||
import static io.github.ollama4j.utils.Utils.getObjectMapper;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/** The type Ollama result. */
|
||||
import static io.github.ollama4j.utils.Utils.getObjectMapper;
|
||||
|
||||
/**
|
||||
* The type Ollama result.
|
||||
*/
|
||||
@Getter
|
||||
@SuppressWarnings("unused")
|
||||
@Data
|
||||
@JsonIgnoreProperties(ignoreUnknown = true)
|
||||
public class OllamaResult {
|
||||
/**
|
||||
* -- GETTER --
|
||||
* Get the completion/response text
|
||||
*
|
||||
* @return String completion/response text
|
||||
*/
|
||||
private final String response;
|
||||
|
||||
/**
|
||||
* -- GETTER --
|
||||
* Get the thinking text (if available)
|
||||
*/
|
||||
private final String thinking;
|
||||
/**
|
||||
* Get the response status code.
|
||||
*
|
||||
* @return int - response status code
|
||||
*/
|
||||
private int httpStatusCode;
|
||||
|
||||
/**
|
||||
* -- GETTER --
|
||||
* Get the response time in milliseconds.
|
||||
*
|
||||
* @return long - response time in milliseconds
|
||||
*/
|
||||
private long responseTime = 0;
|
||||
|
||||
public OllamaResult(String response, long responseTime, int httpStatusCode) {
|
||||
private String model;
|
||||
private String createdAt;
|
||||
private boolean done;
|
||||
private String doneReason;
|
||||
private List<Integer> context;
|
||||
private Long totalDuration;
|
||||
private Long loadDuration;
|
||||
private Integer promptEvalCount;
|
||||
private Long promptEvalDuration;
|
||||
private Integer evalCount;
|
||||
private Long evalDuration;
|
||||
|
||||
public OllamaResult(String response, String thinking, long responseTime, int httpStatusCode) {
|
||||
this.response = response;
|
||||
this.thinking = thinking;
|
||||
this.responseTime = responseTime;
|
||||
this.httpStatusCode = httpStatusCode;
|
||||
}
|
||||
@ -53,8 +61,20 @@ public class OllamaResult {
|
||||
try {
|
||||
Map<String, Object> responseMap = new HashMap<>();
|
||||
responseMap.put("response", this.response);
|
||||
responseMap.put("thinking", this.thinking);
|
||||
responseMap.put("httpStatusCode", this.httpStatusCode);
|
||||
responseMap.put("responseTime", this.responseTime);
|
||||
responseMap.put("model", this.model);
|
||||
responseMap.put("createdAt", this.createdAt);
|
||||
responseMap.put("done", this.done);
|
||||
responseMap.put("doneReason", this.doneReason);
|
||||
responseMap.put("context", this.context);
|
||||
responseMap.put("totalDuration", this.totalDuration);
|
||||
responseMap.put("loadDuration", this.loadDuration);
|
||||
responseMap.put("promptEvalCount", this.promptEvalCount);
|
||||
responseMap.put("promptEvalDuration", this.promptEvalDuration);
|
||||
responseMap.put("evalCount", this.evalCount);
|
||||
responseMap.put("evalDuration", this.evalDuration);
|
||||
return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(responseMap);
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new RuntimeException(e);
|
||||
|
@ -1,19 +1,18 @@
|
||||
package io.github.ollama4j.models.response;
|
||||
|
||||
import static io.github.ollama4j.utils.Utils.getObjectMapper;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonCreator;
|
||||
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.core.type.TypeReference;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.Getter;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static io.github.ollama4j.utils.Utils.getObjectMapper;
|
||||
|
||||
@Getter
|
||||
@SuppressWarnings("unused")
|
||||
@Data
|
||||
@ -21,13 +20,22 @@ import lombok.NoArgsConstructor;
|
||||
@JsonIgnoreProperties(ignoreUnknown = true)
|
||||
public class OllamaStructuredResult {
|
||||
private String response;
|
||||
|
||||
private String thinking;
|
||||
private int httpStatusCode;
|
||||
|
||||
private long responseTime = 0;
|
||||
|
||||
private String model;
|
||||
|
||||
private @JsonProperty("created_at") String createdAt;
|
||||
private boolean done;
|
||||
private @JsonProperty("done_reason") String doneReason;
|
||||
private List<Integer> context;
|
||||
private @JsonProperty("total_duration") Long totalDuration;
|
||||
private @JsonProperty("load_duration") Long loadDuration;
|
||||
private @JsonProperty("prompt_eval_count") Integer promptEvalCount;
|
||||
private @JsonProperty("prompt_eval_duration") Long promptEvalDuration;
|
||||
private @JsonProperty("eval_count") Integer evalCount;
|
||||
private @JsonProperty("eval_duration") Long evalDuration;
|
||||
|
||||
public OllamaStructuredResult(String response, long responseTime, int httpStatusCode) {
|
||||
this.response = response;
|
||||
this.responseTime = responseTime;
|
||||
|
@ -2,8 +2,6 @@ package io.github.ollama4j.models.response;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class OllamaVersion {
|
||||
private String version;
|
||||
|
@ -19,4 +19,11 @@ public class ToolRegistry {
|
||||
public Collection<Tools.ToolSpecification> getRegisteredSpecs() {
|
||||
return tools.values();
|
||||
}
|
||||
|
||||
/**
|
||||
* Removes all registered tools from the registry.
|
||||
*/
|
||||
public void clear() {
|
||||
tools.clear();
|
||||
}
|
||||
}
|
||||
|
@ -1,53 +1,19 @@
|
||||
package io.github.ollama4j.tools.sampletools;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.net.URI;
|
||||
import java.net.http.HttpClient;
|
||||
import java.net.http.HttpRequest;
|
||||
import java.net.http.HttpResponse;
|
||||
import java.util.Map;
|
||||
|
||||
import com.fasterxml.jackson.databind.JsonNode;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
|
||||
import io.github.ollama4j.tools.Tools;
|
||||
|
||||
public class WeatherTool {
|
||||
private String openWeatherMapAPIKey = null;
|
||||
import java.util.Map;
|
||||
|
||||
public WeatherTool(String openWeatherMapAPIKey) {
|
||||
this.openWeatherMapAPIKey = openWeatherMapAPIKey;
|
||||
@SuppressWarnings("resource")
|
||||
public class WeatherTool {
|
||||
private String paramCityName = "cityName";
|
||||
|
||||
public WeatherTool() {
|
||||
}
|
||||
|
||||
public String getCurrentWeather(Map<String, Object> arguments) {
|
||||
String city = (String) arguments.get("cityName");
|
||||
System.out.println("Finding weather for city: " + city);
|
||||
|
||||
String url = String.format("https://api.openweathermap.org/data/2.5/weather?q=%s&appid=%s&units=metric",
|
||||
city,
|
||||
this.openWeatherMapAPIKey);
|
||||
|
||||
HttpClient client = HttpClient.newHttpClient();
|
||||
HttpRequest request = HttpRequest.newBuilder()
|
||||
.uri(URI.create(url))
|
||||
.build();
|
||||
try {
|
||||
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
|
||||
if (response.statusCode() == 200) {
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
JsonNode root = mapper.readTree(response.body());
|
||||
JsonNode main = root.path("main");
|
||||
double temperature = main.path("temp").asDouble();
|
||||
String description = root.path("weather").get(0).path("description").asText();
|
||||
return String.format("Weather in %s: %.1f°C, %s", city, temperature, description);
|
||||
} else {
|
||||
return "Could not retrieve weather data for " + city + ". Status code: "
|
||||
+ response.statusCode();
|
||||
}
|
||||
} catch (IOException | InterruptedException e) {
|
||||
e.printStackTrace();
|
||||
return "Error retrieving weather data: " + e.getMessage();
|
||||
}
|
||||
String city = (String) arguments.get(paramCityName);
|
||||
return "It is sunny in " + city;
|
||||
}
|
||||
|
||||
public Tools.ToolSpecification getSpecification() {
|
||||
@ -70,7 +36,7 @@ public class WeatherTool {
|
||||
.type("object")
|
||||
.properties(
|
||||
Map.of(
|
||||
"cityName",
|
||||
paramCityName,
|
||||
Tools.PromptFuncDefinition.Property
|
||||
.builder()
|
||||
.type("string")
|
||||
@ -79,7 +45,7 @@ public class WeatherTool {
|
||||
.required(true)
|
||||
.build()))
|
||||
.required(java.util.List
|
||||
.of("cityName"))
|
||||
.of(paramCityName))
|
||||
.build())
|
||||
.build())
|
||||
.build())
|
||||
|
@ -1,11 +1,11 @@
|
||||
package io.github.ollama4j.utils;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonGenerator;
|
||||
import com.fasterxml.jackson.databind.JsonSerializer;
|
||||
import com.fasterxml.jackson.databind.SerializerProvider;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class BooleanToJsonFormatFlagSerializer extends JsonSerializer<Boolean>{
|
||||
|
||||
@Override
|
||||
|
14
src/main/java/io/github/ollama4j/utils/Constants.java
Normal file
14
src/main/java/io/github/ollama4j/utils/Constants.java
Normal file
@ -0,0 +1,14 @@
|
||||
package io.github.ollama4j.utils;
|
||||
|
||||
/**
 * Holder for constant values shared across the library.
 *
 * <p>The class only groups constants; it is final and has a private constructor so it can
 * neither be subclassed nor instantiated.
 */
public final class Constants {
    private Constants() {
        // constants holder, not meant to be instantiated
    }

    /** Media-type values and header names used when building HTTP requests. */
    public static final class HttpConstants {
        private HttpConstants() {
        }

        public static final String APPLICATION_JSON = "application/json";
        public static final String APPLICATION_XML = "application/xml";
        public static final String TEXT_PLAIN = "text/plain";
        public static final String HEADER_KEY_CONTENT_TYPE = "Content-Type";
        public static final String HEADER_KEY_ACCEPT = "Accept";
    }
}
|
@ -1,13 +1,13 @@
|
||||
package io.github.ollama4j.utils;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Base64;
|
||||
import java.util.Collection;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonGenerator;
|
||||
import com.fasterxml.jackson.databind.JsonSerializer;
|
||||
import com.fasterxml.jackson.databind.SerializerProvider;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Base64;
|
||||
import java.util.Collection;
|
||||
|
||||
public class FileToBase64Serializer extends JsonSerializer<Collection<byte[]>> {
|
||||
|
||||
@Override
|
||||
|
@ -1,11 +1,11 @@
|
||||
package io.github.ollama4j.utils;
|
||||
|
||||
import java.net.http.HttpRequest.BodyPublisher;
|
||||
import java.net.http.HttpRequest.BodyPublishers;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
|
||||
import java.net.http.HttpRequest.BodyPublisher;
|
||||
import java.net.http.HttpRequest.BodyPublishers;
|
||||
|
||||
/**
|
||||
* Interface to represent a OllamaRequest as HTTP-Request Body via {@link BodyPublishers}.
|
||||
*/
|
||||
|
@ -1,8 +1,9 @@
|
||||
package io.github.ollama4j.utils;
|
||||
|
||||
import java.util.Map;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
/** Class for options for Ollama model. */
|
||||
@Data
|
||||
public class Options {
|
||||
|
@ -1,6 +1,5 @@
|
||||
package io.github.ollama4j.utils;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
|
||||
/** Builder class for creating options for Ollama model. */
|
||||
|
@ -1,25 +0,0 @@
|
||||
package io.github.ollama4j.utils;
|
||||
|
||||
import io.github.ollama4j.OllamaAPI;
|
||||
|
||||
import java.io.InputStream;
|
||||
import java.util.Scanner;
|
||||
|
||||
public class SamplePrompts {
|
||||
public static String getSampleDatabasePromptWithQuestion(String question) throws Exception {
|
||||
ClassLoader classLoader = OllamaAPI.class.getClassLoader();
|
||||
InputStream inputStream = classLoader.getResourceAsStream("sample-db-prompt-template.txt");
|
||||
if (inputStream != null) {
|
||||
Scanner scanner = new Scanner(inputStream);
|
||||
StringBuilder stringBuffer = new StringBuilder();
|
||||
while (scanner.hasNextLine()) {
|
||||
stringBuffer.append(scanner.nextLine()).append("\n");
|
||||
}
|
||||
scanner.close();
|
||||
return stringBuffer.toString().replaceAll("<question>", question);
|
||||
} else {
|
||||
throw new Exception("Sample database question file not found.");
|
||||
}
|
||||
}
|
||||
|
||||
}
|
@ -1,14 +1,16 @@
|
||||
package io.github.ollama4j.utils;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
|
||||
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.net.URI;
|
||||
import java.net.URISyntaxException;
|
||||
import java.net.URL;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
|
||||
import java.util.Objects;
|
||||
|
||||
public class Utils {
|
||||
|
||||
@ -35,4 +37,9 @@ public class Utils {
|
||||
return out.toByteArray();
|
||||
}
|
||||
}
|
||||
|
||||
public static File getFileFromClasspath(String fileName) {
|
||||
ClassLoader classLoader = Utils.class.getClassLoader();
|
||||
return new File(Objects.requireNonNull(classLoader.getResource(fileName)).getFile());
|
||||
}
|
||||
}
|
||||
|
@ -1,61 +0,0 @@
|
||||
"""
|
||||
Following is the database schema.
|
||||
|
||||
DROP TABLE IF EXISTS product_categories;
|
||||
CREATE TABLE IF NOT EXISTS product_categories
|
||||
(
|
||||
category_id INTEGER PRIMARY KEY, -- Unique ID for each category
|
||||
name VARCHAR(50), -- Name of the category
|
||||
parent INTEGER NULL, -- Parent category - for hierarchical categories
|
||||
FOREIGN KEY (parent) REFERENCES product_categories (category_id)
|
||||
);
|
||||
DROP TABLE IF EXISTS products;
|
||||
CREATE TABLE IF NOT EXISTS products
|
||||
(
|
||||
product_id INTEGER PRIMARY KEY, -- Unique ID for each product
|
||||
name VARCHAR(50), -- Name of the product
|
||||
price DECIMAL(10, 2), -- Price of each unit of the product
|
||||
quantity INTEGER, -- Current quantity in stock
|
||||
category_id INTEGER, -- Unique ID for each product
|
||||
FOREIGN KEY (category_id) REFERENCES product_categories (category_id)
|
||||
);
|
||||
DROP TABLE IF EXISTS customers;
|
||||
CREATE TABLE IF NOT EXISTS customers
|
||||
(
|
||||
customer_id INTEGER PRIMARY KEY, -- Unique ID for each customer
|
||||
name VARCHAR(50), -- Name of the customer
|
||||
address VARCHAR(100) -- Mailing address of the customer
|
||||
);
|
||||
DROP TABLE IF EXISTS salespeople;
|
||||
CREATE TABLE IF NOT EXISTS salespeople
|
||||
(
|
||||
salesperson_id INTEGER PRIMARY KEY, -- Unique ID for each salesperson
|
||||
name VARCHAR(50), -- Name of the salesperson
|
||||
region VARCHAR(50) -- Geographic sales region
|
||||
);
|
||||
DROP TABLE IF EXISTS sales;
|
||||
CREATE TABLE IF NOT EXISTS sales
|
||||
(
|
||||
sale_id INTEGER PRIMARY KEY, -- Unique ID for each sale
|
||||
product_id INTEGER, -- ID of product sold
|
||||
customer_id INTEGER, -- ID of customer who made the purchase
|
||||
salesperson_id INTEGER, -- ID of salesperson who made the sale
|
||||
sale_date DATE, -- Date the sale occurred
|
||||
quantity INTEGER, -- Quantity of product sold
|
||||
FOREIGN KEY (product_id) REFERENCES products (product_id),
|
||||
FOREIGN KEY (customer_id) REFERENCES customers (customer_id),
|
||||
FOREIGN KEY (salesperson_id) REFERENCES salespeople (salesperson_id)
|
||||
);
|
||||
DROP TABLE IF EXISTS product_suppliers;
|
||||
CREATE TABLE IF NOT EXISTS product_suppliers
|
||||
(
|
||||
supplier_id INTEGER PRIMARY KEY, -- Unique ID for each supplier
|
||||
product_id INTEGER, -- Product ID supplied
|
||||
supply_price DECIMAL(10, 2), -- Unit price charged by supplier
|
||||
FOREIGN KEY (product_id) REFERENCES products (product_id)
|
||||
);
|
||||
|
||||
|
||||
Generate only a valid (syntactically correct) executable Postgres SQL query (without any explanation of the query) for the following question:
|
||||
`<question>`:
|
||||
"""
|
@ -1,6 +1,5 @@
|
||||
package io.github.ollama4j.integrationtests;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import io.github.ollama4j.OllamaAPI;
|
||||
import io.github.ollama4j.exceptions.OllamaBaseException;
|
||||
import io.github.ollama4j.exceptions.ToolInvocationException;
|
||||
@ -16,9 +15,6 @@ import io.github.ollama4j.tools.ToolFunction;
|
||||
import io.github.ollama4j.tools.Tools;
|
||||
import io.github.ollama4j.tools.annotations.OllamaToolService;
|
||||
import io.github.ollama4j.utils.OptionsBuilder;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.MethodOrderer.OrderAnnotation;
|
||||
import org.junit.jupiter.api.Order;
|
||||
@ -40,24 +36,24 @@ import static org.junit.jupiter.api.Assertions.*;
|
||||
@TestMethodOrder(OrderAnnotation.class)
|
||||
|
||||
@SuppressWarnings({"HttpUrlsUsage", "SpellCheckingInspection"})
|
||||
public class OllamaAPIIntegrationTest {
|
||||
class OllamaAPIIntegrationTest {
|
||||
private static final Logger LOG = LoggerFactory.getLogger(OllamaAPIIntegrationTest.class);
|
||||
|
||||
private static OllamaContainer ollama;
|
||||
private static OllamaAPI api;
|
||||
|
||||
private static final String EMBEDDING_MODEL_MINILM = "all-minilm";
|
||||
private static final String CHAT_MODEL_QWEN_SMALL = "qwen2.5:0.5b";
|
||||
private static final String CHAT_MODEL_INSTRUCT = "qwen2.5:0.5b-instruct";
|
||||
private static final String CHAT_MODEL_SYSTEM_PROMPT = "llama3.2:1b";
|
||||
private static final String CHAT_MODEL_LLAMA3 = "llama3";
|
||||
private static final String IMAGE_MODEL_LLAVA = "llava";
|
||||
private static final String EMBEDDING_MODEL = "all-minilm";
|
||||
private static final String VISION_MODEL = "moondream:1.8b";
|
||||
private static final String THINKING_TOOL_MODEL = "gpt-oss:20b";
|
||||
private static final String GENERAL_PURPOSE_MODEL = "gemma3:270m";
|
||||
private static final String TOOLS_MODEL = "mistral:7b";
|
||||
|
||||
@BeforeAll
|
||||
public static void setUp() {
|
||||
static void setUp() {
|
||||
try {
|
||||
boolean useExternalOllamaHost = Boolean.parseBoolean(System.getenv("USE_EXTERNAL_OLLAMA_HOST"));
|
||||
String ollamaHost = System.getenv("OLLAMA_HOST");
|
||||
|
||||
if (useExternalOllamaHost) {
|
||||
LOG.info("Using external Ollama host...");
|
||||
api = new OllamaAPI(ollamaHost);
|
||||
@ -80,7 +76,7 @@ public class OllamaAPIIntegrationTest {
|
||||
}
|
||||
api.setRequestTimeoutSeconds(120);
|
||||
api.setVerbose(true);
|
||||
api.setNumberOfRetriesForModelPull(3);
|
||||
api.setNumberOfRetriesForModelPull(5);
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -92,7 +88,7 @@ public class OllamaAPIIntegrationTest {
|
||||
|
||||
@Test
|
||||
@Order(1)
|
||||
public void testVersionAPI() throws URISyntaxException, IOException, OllamaBaseException, InterruptedException {
|
||||
void testVersionAPI() throws URISyntaxException, IOException, OllamaBaseException, InterruptedException {
|
||||
// String expectedVersion = ollama.getDockerImageName().split(":")[1];
|
||||
String actualVersion = api.getVersion();
|
||||
assertNotNull(actualVersion);
|
||||
@ -100,17 +96,22 @@ public class OllamaAPIIntegrationTest {
|
||||
// image version");
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(1)
|
||||
void testPing() throws URISyntaxException, IOException, OllamaBaseException, InterruptedException {
|
||||
boolean pingResponse = api.ping();
|
||||
assertTrue(pingResponse, "Ping should return true");
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(2)
|
||||
public void testListModelsAPI()
|
||||
throws URISyntaxException, IOException, OllamaBaseException, InterruptedException {
|
||||
api.pullModel(EMBEDDING_MODEL_MINILM);
|
||||
void testListModelsAPI() throws URISyntaxException, IOException, OllamaBaseException, InterruptedException {
|
||||
// Fetch the list of models
|
||||
List<Model> models = api.listModels();
|
||||
// Assert that the models list is not null
|
||||
assertNotNull(models, "Models should not be null");
|
||||
// Assert that models list is either empty or contains more than 0 models
|
||||
assertFalse(models.isEmpty(), "Models list should not be empty");
|
||||
assertTrue(models.size() >= 0, "Models list should not be empty");
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -124,9 +125,8 @@ public class OllamaAPIIntegrationTest {
|
||||
|
||||
@Test
|
||||
@Order(3)
|
||||
public void testPullModelAPI()
|
||||
throws URISyntaxException, IOException, OllamaBaseException, InterruptedException {
|
||||
api.pullModel(EMBEDDING_MODEL_MINILM);
|
||||
void testPullModelAPI() throws URISyntaxException, IOException, OllamaBaseException, InterruptedException {
|
||||
api.pullModel(EMBEDDING_MODEL);
|
||||
List<Model> models = api.listModels();
|
||||
assertNotNull(models, "Models should not be null");
|
||||
assertFalse(models.isEmpty(), "Models list should contain elements");
|
||||
@ -135,17 +135,17 @@ public class OllamaAPIIntegrationTest {
|
||||
@Test
|
||||
@Order(4)
|
||||
void testListModelDetails() throws IOException, OllamaBaseException, URISyntaxException, InterruptedException {
|
||||
api.pullModel(EMBEDDING_MODEL_MINILM);
|
||||
ModelDetail modelDetails = api.getModelDetails(EMBEDDING_MODEL_MINILM);
|
||||
api.pullModel(EMBEDDING_MODEL);
|
||||
ModelDetail modelDetails = api.getModelDetails(EMBEDDING_MODEL);
|
||||
assertNotNull(modelDetails);
|
||||
assertTrue(modelDetails.getModelFile().contains(EMBEDDING_MODEL_MINILM));
|
||||
assertTrue(modelDetails.getModelFile().contains(EMBEDDING_MODEL));
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(5)
|
||||
public void testEmbeddings() throws Exception {
|
||||
api.pullModel(EMBEDDING_MODEL_MINILM);
|
||||
OllamaEmbedResponseModel embeddings = api.embed(EMBEDDING_MODEL_MINILM,
|
||||
void testEmbeddings() throws Exception {
|
||||
api.pullModel(EMBEDDING_MODEL);
|
||||
OllamaEmbedResponseModel embeddings = api.embed(EMBEDDING_MODEL,
|
||||
Arrays.asList("Why is the sky blue?", "Why is the grass green?"));
|
||||
assertNotNull(embeddings, "Embeddings should not be null");
|
||||
assertFalse(embeddings.getEmbeddings().isEmpty(), "Embeddings should not be empty");
|
||||
@ -153,58 +153,44 @@ public class OllamaAPIIntegrationTest {
|
||||
|
||||
@Test
|
||||
@Order(6)
|
||||
void testAskModelWithStructuredOutput()
|
||||
void testGenerateWithStructuredOutput()
|
||||
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
|
||||
api.pullModel(CHAT_MODEL_LLAMA3);
|
||||
api.pullModel(TOOLS_MODEL);
|
||||
|
||||
int timeHour = 6;
|
||||
boolean isNightTime = false;
|
||||
|
||||
String prompt = "The Sun is shining, and its " + timeHour + ". Its daytime.";
|
||||
String prompt = "The sun is shining brightly and is directly overhead at the zenith, casting my shadow over my foot, so it must be noon.";
|
||||
|
||||
Map<String, Object> format = new HashMap<>();
|
||||
format.put("type", "object");
|
||||
format.put("properties", new HashMap<String, Object>() {
|
||||
{
|
||||
put("timeHour", new HashMap<String, Object>() {
|
||||
{
|
||||
put("type", "integer");
|
||||
}
|
||||
});
|
||||
put("isNightTime", new HashMap<String, Object>() {
|
||||
put("isNoon", new HashMap<String, Object>() {
|
||||
{
|
||||
put("type", "boolean");
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
format.put("required", Arrays.asList("timeHour", "isNightTime"));
|
||||
format.put("required", List.of("isNoon"));
|
||||
|
||||
OllamaResult result = api.generate(CHAT_MODEL_LLAMA3, prompt, format);
|
||||
OllamaResult result = api.generate(TOOLS_MODEL, prompt, format);
|
||||
|
||||
assertNotNull(result);
|
||||
assertNotNull(result.getResponse());
|
||||
assertFalse(result.getResponse().isEmpty());
|
||||
|
||||
assertEquals(timeHour,
|
||||
result.getStructuredResponse().get("timeHour"));
|
||||
assertEquals(isNightTime,
|
||||
result.getStructuredResponse().get("isNightTime"));
|
||||
|
||||
TimeOfDay timeOfDay = result.as(TimeOfDay.class);
|
||||
|
||||
assertEquals(timeHour, timeOfDay.getTimeHour());
|
||||
assertEquals(isNightTime, timeOfDay.isNightTime());
|
||||
assertEquals(true, result.getStructuredResponse().get("isNoon"));
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(6)
|
||||
void testAskModelWithDefaultOptions()
|
||||
void testGennerateModelWithDefaultOptions()
|
||||
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
|
||||
api.pullModel(CHAT_MODEL_QWEN_SMALL);
|
||||
OllamaResult result = api.generate(CHAT_MODEL_QWEN_SMALL,
|
||||
"What is the capital of France? And what's France's connection with Mona Lisa?", false,
|
||||
new OptionsBuilder().build());
|
||||
api.pullModel(GENERAL_PURPOSE_MODEL);
|
||||
boolean raw = false;
|
||||
boolean thinking = false;
|
||||
OllamaResult result = api.generate(GENERAL_PURPOSE_MODEL,
|
||||
"What is the capital of France? And what's France's connection with Mona Lisa?", raw,
|
||||
thinking, new OptionsBuilder().build());
|
||||
assertNotNull(result);
|
||||
assertNotNull(result.getResponse());
|
||||
assertFalse(result.getResponse().isEmpty());
|
||||
@ -212,32 +198,31 @@ public class OllamaAPIIntegrationTest {
|
||||
|
||||
@Test
|
||||
@Order(7)
|
||||
void testAskModelWithDefaultOptionsStreamed()
|
||||
void testGenerateWithDefaultOptionsStreamed()
|
||||
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
|
||||
api.pullModel(CHAT_MODEL_QWEN_SMALL);
|
||||
api.pullModel(GENERAL_PURPOSE_MODEL);
|
||||
boolean raw = false;
|
||||
StringBuffer sb = new StringBuffer();
|
||||
OllamaResult result = api.generate(CHAT_MODEL_QWEN_SMALL,
|
||||
"What is the capital of France? And what's France's connection with Mona Lisa?", false,
|
||||
OllamaResult result = api.generate(GENERAL_PURPOSE_MODEL,
|
||||
"What is the capital of France? And what's France's connection with Mona Lisa?", raw,
|
||||
new OptionsBuilder().build(), (s) -> {
|
||||
LOG.info(s);
|
||||
String substring = s.substring(sb.toString().length(), s.length());
|
||||
LOG.info(substring);
|
||||
sb.append(substring);
|
||||
sb.append(s);
|
||||
});
|
||||
|
||||
assertNotNull(result);
|
||||
assertNotNull(result.getResponse());
|
||||
assertFalse(result.getResponse().isEmpty());
|
||||
assertEquals(sb.toString().trim(), result.getResponse().trim());
|
||||
assertEquals(sb.toString(), result.getResponse());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(8)
|
||||
void testAskModelWithOptions()
|
||||
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException {
|
||||
api.pullModel(CHAT_MODEL_INSTRUCT);
|
||||
void testGenerateWithOptions() throws OllamaBaseException, IOException, URISyntaxException,
|
||||
InterruptedException, ToolInvocationException {
|
||||
api.pullModel(GENERAL_PURPOSE_MODEL);
|
||||
|
||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_INSTRUCT);
|
||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(GENERAL_PURPOSE_MODEL);
|
||||
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM,
|
||||
"You are a helpful assistant who can generate random person's first and last names in the format [First name, Last name].")
|
||||
.build();
|
||||
@ -253,29 +238,32 @@ public class OllamaAPIIntegrationTest {
|
||||
|
||||
@Test
|
||||
@Order(9)
|
||||
void testChatWithSystemPrompt()
|
||||
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException {
|
||||
api.pullModel(CHAT_MODEL_SYSTEM_PROMPT);
|
||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT);
|
||||
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM,
|
||||
"You are a silent bot that only says 'Shush'. Do not say anything else under any circumstances!")
|
||||
.withMessage(OllamaChatMessageRole.USER, "What's something that's brown and sticky?")
|
||||
.withOptions(new OptionsBuilder().setTemperature(0.8f).build()).build();
|
||||
void testChatWithSystemPrompt() throws OllamaBaseException, IOException, URISyntaxException,
|
||||
InterruptedException, ToolInvocationException {
|
||||
api.pullModel(GENERAL_PURPOSE_MODEL);
|
||||
|
||||
String expectedResponse = "Bhai";
|
||||
|
||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(GENERAL_PURPOSE_MODEL);
|
||||
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, String.format(
|
||||
"[INSTRUCTION-START] You are an obidient and helpful bot named %s. You always answer with only one word and that word is your name. [INSTRUCTION-END]",
|
||||
expectedResponse)).withMessage(OllamaChatMessageRole.USER, "Who are you?")
|
||||
.withOptions(new OptionsBuilder().setTemperature(0.0f).build()).build();
|
||||
|
||||
OllamaChatResult chatResult = api.chat(requestModel);
|
||||
assertNotNull(chatResult);
|
||||
assertNotNull(chatResult.getResponseModel());
|
||||
assertNotNull(chatResult.getResponseModel().getMessage());
|
||||
assertFalse(chatResult.getResponseModel().getMessage().getContent().isBlank());
|
||||
assertTrue(chatResult.getResponseModel().getMessage().getContent().contains("Shush"));
|
||||
assertTrue(chatResult.getResponseModel().getMessage().getContent().contains(expectedResponse));
|
||||
assertEquals(3, chatResult.getChatHistory().size());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(10)
|
||||
public void testChat() throws Exception {
|
||||
api.pullModel(CHAT_MODEL_LLAMA3);
|
||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_LLAMA3);
|
||||
void testChat() throws Exception {
|
||||
api.pullModel(THINKING_TOOL_MODEL);
|
||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(THINKING_TOOL_MODEL);
|
||||
|
||||
// Create the initial user question
|
||||
OllamaChatRequest requestModel = builder
|
||||
@ -288,7 +276,6 @@ public class OllamaAPIIntegrationTest {
|
||||
assertTrue(chatResult.getChatHistory().stream().anyMatch(chat -> chat.getContent().contains("2")),
|
||||
"Expected chat history to contain '2'");
|
||||
|
||||
// Create the next user question: the squared value of the previous answer
|
||||
requestModel = builder.withMessages(chatResult.getChatHistory())
|
||||
.withMessage(OllamaChatMessageRole.USER, "And what is its squared value?").build();
|
||||
|
||||
@ -299,10 +286,8 @@ public class OllamaAPIIntegrationTest {
|
||||
"Expected chat history to contain '4'");
|
||||
|
||||
// Create the next user question: the third question
|
||||
requestModel = builder.withMessages(chatResult.getChatHistory())
|
||||
.withMessage(OllamaChatMessageRole.USER,
|
||||
"What is the largest value between 2, 4 and 6?")
|
||||
.build();
|
||||
requestModel = builder.withMessages(chatResult.getChatHistory()).withMessage(OllamaChatMessageRole.USER,
|
||||
"What is the largest value between 2, 4 and 6?").build();
|
||||
|
||||
// Continue conversation with the model for the third question
|
||||
chatResult = api.chat(requestModel);
|
||||
@ -312,143 +297,103 @@ public class OllamaAPIIntegrationTest {
|
||||
assertTrue(chatResult.getChatHistory().size() > 2,
|
||||
"Chat history should contain more than two messages");
|
||||
assertTrue(chatResult.getChatHistory().get(chatResult.getChatHistory().size() - 1).getContent()
|
||||
.contains("6"),
|
||||
"Response should contain '6'");
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(10)
|
||||
void testChatWithImageFromURL()
|
||||
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException, ToolInvocationException {
|
||||
api.pullModel(IMAGE_MODEL_LLAVA);
|
||||
|
||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(IMAGE_MODEL_LLAVA);
|
||||
OllamaChatRequest requestModel = builder
|
||||
.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
|
||||
Collections.emptyList(),
|
||||
"https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg")
|
||||
.build();
|
||||
api.registerAnnotatedTools(new OllamaAPIIntegrationTest());
|
||||
|
||||
OllamaChatResult chatResult = api.chat(requestModel);
|
||||
assertNotNull(chatResult);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(10)
|
||||
void testChatWithImageFromFileWithHistoryRecognition()
|
||||
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException {
|
||||
api.pullModel(IMAGE_MODEL_LLAVA);
|
||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(IMAGE_MODEL_LLAVA);
|
||||
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER,
|
||||
"What's in the picture?",
|
||||
Collections.emptyList(), List.of(getImageFileFromClasspath("emoji-smile.jpeg")))
|
||||
.build();
|
||||
|
||||
OllamaChatResult chatResult = api.chat(requestModel);
|
||||
assertNotNull(chatResult);
|
||||
assertNotNull(chatResult.getResponseModel());
|
||||
builder.reset();
|
||||
|
||||
requestModel = builder.withMessages(chatResult.getChatHistory())
|
||||
.withMessage(OllamaChatMessageRole.USER, "What's the color?").build();
|
||||
|
||||
chatResult = api.chat(requestModel);
|
||||
assertNotNull(chatResult);
|
||||
assertNotNull(chatResult.getResponseModel());
|
||||
.contains("6"), "Response should contain '6'");
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(11)
|
||||
void testChatWithExplicitToolDefinition()
|
||||
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException {
|
||||
api.pullModel(CHAT_MODEL_SYSTEM_PROMPT);
|
||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT);
|
||||
void testChatWithExplicitToolDefinition() throws OllamaBaseException, IOException, URISyntaxException,
|
||||
InterruptedException, ToolInvocationException {
|
||||
String theToolModel = TOOLS_MODEL;
|
||||
api.pullModel(theToolModel);
|
||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(theToolModel);
|
||||
|
||||
final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder()
|
||||
.functionName("get-employee-details")
|
||||
.functionDescription("Get employee details from the database")
|
||||
.toolPrompt(Tools.PromptFuncDefinition.builder().type("function")
|
||||
.function(Tools.PromptFuncDefinition.PromptFuncSpec.builder()
|
||||
.name("get-employee-details")
|
||||
.description("Get employee details from the database")
|
||||
.parameters(Tools.PromptFuncDefinition.Parameters
|
||||
.builder().type("object")
|
||||
.properties(new Tools.PropsBuilder()
|
||||
.withProperty("employee-name",
|
||||
Tools.PromptFuncDefinition.Property
|
||||
.builder()
|
||||
.type("string")
|
||||
.description("The name of the employee, e.g. John Doe")
|
||||
.required(true)
|
||||
.build())
|
||||
.withProperty("employee-address",
|
||||
Tools.PromptFuncDefinition.Property
|
||||
.builder()
|
||||
.type("string")
|
||||
.description(
|
||||
"The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India")
|
||||
.required(true)
|
||||
.build())
|
||||
.withProperty("employee-phone",
|
||||
Tools.PromptFuncDefinition.Property
|
||||
.builder()
|
||||
.type("string")
|
||||
.description(
|
||||
"The phone number of the employee. Always return a random value. e.g. 9911002233")
|
||||
.required(true)
|
||||
.build())
|
||||
.build())
|
||||
.required(List.of("employee-name"))
|
||||
.build())
|
||||
.build())
|
||||
.build())
|
||||
.toolFunction(arguments -> {
|
||||
// perform DB operations here
|
||||
return String.format(
|
||||
"Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}",
|
||||
UUID.randomUUID(), arguments.get("employee-name"),
|
||||
arguments.get("employee-address"),
|
||||
arguments.get("employee-phone"));
|
||||
}).build();
|
||||
api.registerTool(employeeFinderTool());
|
||||
|
||||
api.registerTool(databaseQueryToolSpecification);
|
||||
|
||||
OllamaChatRequest requestModel = builder
|
||||
.withMessage(OllamaChatMessageRole.USER,
|
||||
"Give me the ID of the employee named 'Rahul Kumar'?")
|
||||
.build();
|
||||
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER,
|
||||
"Give me the ID and address of the employee Rahul Kumar.").build();
|
||||
requestModel.setOptions(new OptionsBuilder().setTemperature(0.9f).build().getOptionsMap());
|
||||
|
||||
OllamaChatResult chatResult = api.chat(requestModel);
|
||||
assertNotNull(chatResult);
|
||||
assertNotNull(chatResult.getResponseModel());
|
||||
assertNotNull(chatResult.getResponseModel().getMessage());
|
||||
assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(),
|
||||
chatResult.getResponseModel().getMessage().getRole().getRoleName());
|
||||
|
||||
assertNotNull(chatResult, "chatResult should not be null");
|
||||
assertNotNull(chatResult.getResponseModel(), "Response model should not be null");
|
||||
assertNotNull(chatResult.getResponseModel().getMessage(), "Response message should not be null");
|
||||
assertEquals(
|
||||
OllamaChatMessageRole.ASSISTANT.getRoleName(),
|
||||
chatResult.getResponseModel().getMessage().getRole().getRoleName(),
|
||||
"Role of the response message should be ASSISTANT"
|
||||
);
|
||||
List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls();
|
||||
assertEquals(1, toolCalls.size());
|
||||
assertEquals(1, toolCalls.size(), "There should be exactly one tool call in the second chat history message");
|
||||
OllamaToolCallsFunction function = toolCalls.get(0).getFunction();
|
||||
assertEquals("get-employee-details", function.getName());
|
||||
assert !function.getArguments().isEmpty();
|
||||
assertEquals("get-employee-details", function.getName(), "Tool function name should be 'get-employee-details'");
|
||||
assertFalse(function.getArguments().isEmpty(), "Tool function arguments should not be empty");
|
||||
Object employeeName = function.getArguments().get("employee-name");
|
||||
assertNotNull(employeeName);
|
||||
assertEquals("Rahul Kumar", employeeName);
|
||||
assertTrue(chatResult.getChatHistory().size() > 2);
|
||||
assertNotNull(employeeName, "Employee name argument should not be null");
|
||||
assertEquals("Rahul Kumar", employeeName, "Employee name argument should be 'Rahul Kumar'");
|
||||
assertTrue(chatResult.getChatHistory().size() > 2, "Chat history should have more than 2 messages");
|
||||
List<OllamaChatToolCalls> finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls();
|
||||
assertNull(finalToolCalls);
|
||||
assertNull(finalToolCalls, "Final tool calls in the response message should be null");
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(14)
|
||||
void testChatWithToolsAndStream() throws OllamaBaseException, IOException, URISyntaxException,
|
||||
InterruptedException, ToolInvocationException {
|
||||
String theToolModel = TOOLS_MODEL;
|
||||
api.pullModel(theToolModel);
|
||||
|
||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(theToolModel);
|
||||
|
||||
api.registerTool(employeeFinderTool());
|
||||
|
||||
OllamaChatRequest requestModel = builder
|
||||
.withMessage(OllamaChatMessageRole.USER, "Give me the ID and address of employee Rahul Kumar")
|
||||
.withKeepAlive("0m").withOptions(new OptionsBuilder().setTemperature(0.9f).build())
|
||||
.build();
|
||||
|
||||
OllamaChatResult chatResult = api.chat(requestModel, (s) -> {
|
||||
LOG.info(s.toUpperCase());
|
||||
}, (s) -> {
|
||||
LOG.info(s.toLowerCase());
|
||||
});
|
||||
|
||||
assertNotNull(chatResult, "chatResult should not be null");
|
||||
assertNotNull(chatResult.getResponseModel(), "Response model should not be null");
|
||||
assertNotNull(chatResult.getResponseModel().getMessage(), "Response message should not be null");
|
||||
assertEquals(
|
||||
OllamaChatMessageRole.ASSISTANT.getRoleName(),
|
||||
chatResult.getResponseModel().getMessage().getRole().getRoleName(),
|
||||
"Role of the response message should be ASSISTANT"
|
||||
);
|
||||
List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls();
|
||||
assertEquals(1, toolCalls.size(), "There should be exactly one tool call in the second chat history message");
|
||||
OllamaToolCallsFunction function = toolCalls.get(0).getFunction();
|
||||
assertEquals("get-employee-details", function.getName(), "Tool function name should be 'get-employee-details'");
|
||||
assertFalse(function.getArguments().isEmpty(), "Tool function arguments should not be empty");
|
||||
Object employeeName = function.getArguments().get("employee-name");
|
||||
assertNotNull(employeeName, "Employee name argument should not be null");
|
||||
assertEquals("Rahul Kumar", employeeName, "Employee name argument should be 'Rahul Kumar'");
|
||||
assertTrue(chatResult.getChatHistory().size() > 2, "Chat history should have more than 2 messages");
|
||||
List<OllamaChatToolCalls> finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls();
|
||||
assertNull(finalToolCalls, "Final tool calls in the response message should be null");
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(12)
|
||||
void testChatWithAnnotatedToolsAndSingleParam()
|
||||
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException, ToolInvocationException {
|
||||
api.pullModel(CHAT_MODEL_SYSTEM_PROMPT);
|
||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT);
|
||||
void testChatWithAnnotatedToolsAndSingleParam() throws OllamaBaseException, IOException, InterruptedException,
|
||||
URISyntaxException, ToolInvocationException {
|
||||
String theToolModel = TOOLS_MODEL;
|
||||
api.pullModel(theToolModel);
|
||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(theToolModel);
|
||||
|
||||
api.registerAnnotatedTools();
|
||||
|
||||
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER,
|
||||
"Compute the most important constant in the world using 5 digits").build();
|
||||
OllamaChatRequest requestModel = builder
|
||||
.withMessage(OllamaChatMessageRole.USER,
|
||||
"Compute the most important constant in the world using 5 digits")
|
||||
.build();
|
||||
|
||||
OllamaChatResult chatResult = api.chat(requestModel);
|
||||
assertNotNull(chatResult);
|
||||
@ -471,17 +416,16 @@ public class OllamaAPIIntegrationTest {
|
||||
|
||||
@Test
|
||||
@Order(13)
|
||||
void testChatWithAnnotatedToolsAndMultipleParams()
|
||||
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException {
|
||||
api.pullModel(CHAT_MODEL_SYSTEM_PROMPT);
|
||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT);
|
||||
void testChatWithAnnotatedToolsAndMultipleParams() throws OllamaBaseException, IOException, URISyntaxException,
|
||||
InterruptedException, ToolInvocationException {
|
||||
String theToolModel = TOOLS_MODEL;
|
||||
api.pullModel(theToolModel);
|
||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(theToolModel);
|
||||
|
||||
api.registerAnnotatedTools(new AnnotatedTool());
|
||||
|
||||
OllamaChatRequest requestModel = builder
|
||||
.withMessage(OllamaChatMessageRole.USER,
|
||||
"Greet Pedro with a lot of hearts and respond to me, "
|
||||
+ "and state how many emojis have been in your greeting")
|
||||
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER,
|
||||
"Greet Rahul with a lot of hearts and respond to me with count of emojis that have been in used in the greeting")
|
||||
.build();
|
||||
|
||||
OllamaChatResult chatResult = api.chat(requestModel);
|
||||
@ -497,28 +441,220 @@ public class OllamaAPIIntegrationTest {
|
||||
assertEquals(2, function.getArguments().size());
|
||||
Object name = function.getArguments().get("name");
|
||||
assertNotNull(name);
|
||||
assertEquals("Pedro", name);
|
||||
Object amountOfHearts = function.getArguments().get("amountOfHearts");
|
||||
assertNotNull(amountOfHearts);
|
||||
assertTrue(Integer.parseInt(amountOfHearts.toString()) > 1);
|
||||
assertEquals("Rahul", name);
|
||||
Object numberOfHearts = function.getArguments().get("numberOfHearts");
|
||||
assertNotNull(numberOfHearts);
|
||||
assertTrue(Integer.parseInt(numberOfHearts.toString()) > 1);
|
||||
assertTrue(chatResult.getChatHistory().size() > 2);
|
||||
List<OllamaChatToolCalls> finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls();
|
||||
assertNull(finalToolCalls);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(14)
|
||||
void testChatWithToolsAndStream()
|
||||
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException {
|
||||
api.pullModel(CHAT_MODEL_SYSTEM_PROMPT);
|
||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT);
|
||||
final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder()
|
||||
@Order(15)
|
||||
void testChatWithStream() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException,
|
||||
ToolInvocationException {
|
||||
api.deregisterTools();
|
||||
api.pullModel(GENERAL_PURPOSE_MODEL);
|
||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(GENERAL_PURPOSE_MODEL);
|
||||
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER,
|
||||
"What is the capital of France? And what's France's connection with Mona Lisa?")
|
||||
.build();
|
||||
requestModel.setThink(false);
|
||||
StringBuffer sb = new StringBuffer();
|
||||
|
||||
OllamaChatResult chatResult = api.chat(requestModel, (s) -> {
|
||||
LOG.info(s.toUpperCase());
|
||||
sb.append(s);
|
||||
}, (s) -> {
|
||||
LOG.info(s.toLowerCase());
|
||||
sb.append(s);
|
||||
});
|
||||
assertNotNull(chatResult);
|
||||
assertNotNull(chatResult.getResponseModel());
|
||||
assertNotNull(chatResult.getResponseModel().getMessage());
|
||||
assertNotNull(chatResult.getResponseModel().getMessage().getContent());
|
||||
assertEquals(sb.toString(), chatResult.getResponseModel().getMessage().getContent());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(15)
|
||||
void testChatWithThinkingAndStream() throws OllamaBaseException, IOException, URISyntaxException,
|
||||
InterruptedException, ToolInvocationException {
|
||||
api.pullModel(THINKING_TOOL_MODEL);
|
||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(THINKING_TOOL_MODEL);
|
||||
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER,
|
||||
"What is the capital of France? And what's France's connection with Mona Lisa?")
|
||||
.withThinking(true).withKeepAlive("0m").build();
|
||||
StringBuffer sb = new StringBuffer();
|
||||
|
||||
OllamaChatResult chatResult = api.chat(requestModel, (s) -> {
|
||||
sb.append(s);
|
||||
LOG.info(s.toUpperCase());
|
||||
}, (s) -> {
|
||||
sb.append(s);
|
||||
LOG.info(s.toLowerCase());
|
||||
});
|
||||
|
||||
assertNotNull(chatResult);
|
||||
assertNotNull(chatResult.getResponseModel());
|
||||
assertNotNull(chatResult.getResponseModel().getMessage());
|
||||
assertNotNull(chatResult.getResponseModel().getMessage().getContent());
|
||||
assertEquals(sb.toString(), chatResult.getResponseModel().getMessage().getThinking()
|
||||
+ chatResult.getResponseModel().getMessage().getContent());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(10)
|
||||
void testChatWithImageFromURL() throws OllamaBaseException, IOException, InterruptedException,
|
||||
URISyntaxException, ToolInvocationException {
|
||||
api.pullModel(VISION_MODEL);
|
||||
|
||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(VISION_MODEL);
|
||||
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER,
|
||||
"What's in the picture?", Collections.emptyList(),
|
||||
"https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg")
|
||||
.build();
|
||||
api.registerAnnotatedTools(new OllamaAPIIntegrationTest());
|
||||
|
||||
OllamaChatResult chatResult = api.chat(requestModel);
|
||||
assertNotNull(chatResult);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(10)
|
||||
void testChatWithImageFromFileWithHistoryRecognition() throws OllamaBaseException, IOException,
|
||||
URISyntaxException, InterruptedException, ToolInvocationException {
|
||||
api.pullModel(VISION_MODEL);
|
||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(VISION_MODEL);
|
||||
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER,
|
||||
"What's in the picture?", Collections.emptyList(),
|
||||
List.of(getImageFileFromClasspath("emoji-smile.jpeg"))).build();
|
||||
|
||||
OllamaChatResult chatResult = api.chat(requestModel);
|
||||
assertNotNull(chatResult);
|
||||
assertNotNull(chatResult.getResponseModel());
|
||||
builder.reset();
|
||||
|
||||
requestModel = builder.withMessages(chatResult.getChatHistory())
|
||||
.withMessage(OllamaChatMessageRole.USER, "What's the color?").build();
|
||||
|
||||
chatResult = api.chat(requestModel);
|
||||
assertNotNull(chatResult);
|
||||
assertNotNull(chatResult.getResponseModel());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(17)
|
||||
void testGenerateWithOptionsAndImageURLs()
|
||||
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
|
||||
api.pullModel(VISION_MODEL);
|
||||
|
||||
OllamaResult result = api.generateWithImageURLs(VISION_MODEL, "What is in this image?",
|
||||
List.of("https://i.pinimg.com/736x/f9/4e/cb/f94ecba040696a3a20b484d2e15159ec.jpg"),
|
||||
new OptionsBuilder().build());
|
||||
assertNotNull(result);
|
||||
assertNotNull(result.getResponse());
|
||||
assertFalse(result.getResponse().isEmpty());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(18)
|
||||
void testGenerateWithOptionsAndImageFiles()
|
||||
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
|
||||
api.pullModel(VISION_MODEL);
|
||||
File imageFile = getImageFileFromClasspath("roses.jpg");
|
||||
try {
|
||||
OllamaResult result = api.generateWithImageFiles(VISION_MODEL, "What is in this image?",
|
||||
List.of(imageFile), new OptionsBuilder().build());
|
||||
assertNotNull(result);
|
||||
assertNotNull(result.getResponse());
|
||||
assertFalse(result.getResponse().isEmpty());
|
||||
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||
fail(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(20)
|
||||
void testGenerateWithOptionsAndImageFilesStreamed()
|
||||
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
|
||||
api.pullModel(VISION_MODEL);
|
||||
|
||||
File imageFile = getImageFileFromClasspath("roses.jpg");
|
||||
|
||||
StringBuffer sb = new StringBuffer();
|
||||
|
||||
OllamaResult result = api.generateWithImageFiles(VISION_MODEL, "What is in this image?",
|
||||
List.of(imageFile), new OptionsBuilder().build(), (s) -> {
|
||||
LOG.info(s);
|
||||
sb.append(s);
|
||||
});
|
||||
assertNotNull(result);
|
||||
assertNotNull(result.getResponse());
|
||||
assertFalse(result.getResponse().isEmpty());
|
||||
assertEquals(sb.toString(), result.getResponse());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(20)
|
||||
void testGenerateWithThinking()
|
||||
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
|
||||
api.pullModel(THINKING_TOOL_MODEL);
|
||||
|
||||
boolean raw = false;
|
||||
boolean think = true;
|
||||
|
||||
OllamaResult result = api.generate(THINKING_TOOL_MODEL, "Who are you?", raw, think,
|
||||
new OptionsBuilder().build());
|
||||
assertNotNull(result);
|
||||
assertNotNull(result.getResponse());
|
||||
assertFalse(result.getResponse().isEmpty());
|
||||
assertNotNull(result.getThinking());
|
||||
assertFalse(result.getThinking().isEmpty());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(20)
|
||||
void testGenerateWithThinkingAndStreamHandler()
|
||||
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
|
||||
api.pullModel(THINKING_TOOL_MODEL);
|
||||
|
||||
boolean raw = false;
|
||||
|
||||
StringBuffer sb = new StringBuffer();
|
||||
OllamaResult result = api.generate(THINKING_TOOL_MODEL, "Who are you?", raw,
|
||||
new OptionsBuilder().build(),
|
||||
(thinkingToken) -> {
|
||||
sb.append(thinkingToken);
|
||||
LOG.info(thinkingToken);
|
||||
},
|
||||
(resToken) -> {
|
||||
sb.append(resToken);
|
||||
LOG.info(resToken);
|
||||
}
|
||||
);
|
||||
assertNotNull(result);
|
||||
assertNotNull(result.getResponse());
|
||||
assertFalse(result.getResponse().isEmpty());
|
||||
assertNotNull(result.getThinking());
|
||||
assertFalse(result.getThinking().isEmpty());
|
||||
assertEquals(sb.toString(), result.getThinking() + result.getResponse());
|
||||
}
|
||||
|
||||
private File getImageFileFromClasspath(String fileName) {
|
||||
ClassLoader classLoader = getClass().getClassLoader();
|
||||
return new File(Objects.requireNonNull(classLoader.getResource(fileName)).getFile());
|
||||
}
|
||||
|
||||
private Tools.ToolSpecification employeeFinderTool() {
|
||||
return Tools.ToolSpecification.builder()
|
||||
.functionName("get-employee-details")
|
||||
.functionDescription("Get employee details from the database")
|
||||
.functionDescription("Get details for a person or an employee")
|
||||
.toolPrompt(Tools.PromptFuncDefinition.builder().type("function")
|
||||
.function(Tools.PromptFuncDefinition.PromptFuncSpec.builder()
|
||||
.name("get-employee-details")
|
||||
.description("Get employee details from the database")
|
||||
.description("Get details for a person or an employee")
|
||||
.parameters(Tools.PromptFuncDefinition.Parameters
|
||||
.builder().type("object")
|
||||
.properties(new Tools.PropsBuilder()
|
||||
@ -533,16 +669,14 @@ public class OllamaAPIIntegrationTest {
|
||||
Tools.PromptFuncDefinition.Property
|
||||
.builder()
|
||||
.type("string")
|
||||
.description(
|
||||
"The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India")
|
||||
.description("The address of the employee, Always eturns a random address. For example, Church St, Bengaluru, India")
|
||||
.required(true)
|
||||
.build())
|
||||
.withProperty("employee-phone",
|
||||
Tools.PromptFuncDefinition.Property
|
||||
.builder()
|
||||
.type("string")
|
||||
.description(
|
||||
"The phone number of the employee. Always return a random value. e.g. 9911002233")
|
||||
.description("The phone number of the employee. Always returns a random phone number. For example, 9911002233")
|
||||
.required(true)
|
||||
.build())
|
||||
.build())
|
||||
@ -553,129 +687,22 @@ public class OllamaAPIIntegrationTest {
|
||||
.toolFunction(new ToolFunction() {
|
||||
@Override
|
||||
public Object apply(Map<String, Object> arguments) {
|
||||
LOG.info("Invoking employee finder tool with arguments: {}", arguments);
|
||||
String employeeName = arguments.get("employee-name").toString();
|
||||
String address = null;
|
||||
String phone = null;
|
||||
if (employeeName.equalsIgnoreCase("Rahul Kumar")) {
|
||||
address = "Pune, Maharashtra, India";
|
||||
phone = "9911223344";
|
||||
} else {
|
||||
address = "Karol Bagh, Delhi, India";
|
||||
phone = "9911002233";
|
||||
}
|
||||
// perform DB operations here
|
||||
return String.format(
|
||||
"Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}",
|
||||
UUID.randomUUID(), arguments.get("employee-name"),
|
||||
arguments.get("employee-address"),
|
||||
arguments.get("employee-phone"));
|
||||
UUID.randomUUID(), employeeName, address, phone);
|
||||
}
|
||||
}).build();
|
||||
|
||||
api.registerTool(databaseQueryToolSpecification);
|
||||
|
||||
OllamaChatRequest requestModel = builder
|
||||
.withMessage(OllamaChatMessageRole.USER,
|
||||
"Give me the ID of the employee named 'Rahul Kumar'?")
|
||||
.build();
|
||||
|
||||
StringBuffer sb = new StringBuffer();
|
||||
|
||||
OllamaChatResult chatResult = api.chat(requestModel, (s) -> {
|
||||
LOG.info(s);
|
||||
String substring = s.substring(sb.toString().length());
|
||||
LOG.info(substring);
|
||||
sb.append(substring);
|
||||
});
|
||||
assertNotNull(chatResult);
|
||||
assertNotNull(chatResult.getResponseModel());
|
||||
assertNotNull(chatResult.getResponseModel().getMessage());
|
||||
assertNotNull(chatResult.getResponseModel().getMessage().getContent());
|
||||
assertEquals(sb.toString().trim(), chatResult.getResponseModel().getMessage().getContent().trim());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(15)
|
||||
void testChatWithStream() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException {
|
||||
api.pullModel(CHAT_MODEL_SYSTEM_PROMPT);
|
||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT);
|
||||
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER,
|
||||
"What is the capital of France? And what's France's connection with Mona Lisa?")
|
||||
.build();
|
||||
|
||||
StringBuffer sb = new StringBuffer();
|
||||
|
||||
OllamaChatResult chatResult = api.chat(requestModel, (s) -> {
|
||||
LOG.info(s);
|
||||
String substring = s.substring(sb.toString().length(), s.length());
|
||||
LOG.info(substring);
|
||||
sb.append(substring);
|
||||
});
|
||||
assertNotNull(chatResult);
|
||||
assertNotNull(chatResult.getResponseModel());
|
||||
assertNotNull(chatResult.getResponseModel().getMessage());
|
||||
assertNotNull(chatResult.getResponseModel().getMessage().getContent());
|
||||
assertEquals(sb.toString().trim(), chatResult.getResponseModel().getMessage().getContent().trim());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(17)
|
||||
void testAskModelWithOptionsAndImageURLs()
|
||||
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
|
||||
api.pullModel(IMAGE_MODEL_LLAVA);
|
||||
|
||||
OllamaResult result = api.generateWithImageURLs(IMAGE_MODEL_LLAVA, "What is in this image?",
|
||||
List.of("https://upload.wikimedia.org/wikipedia/commons/thumb/a/aa/Noto_Emoji_v2.034_1f642.svg/360px-Noto_Emoji_v2.034_1f642.svg.png"),
|
||||
new OptionsBuilder().build());
|
||||
assertNotNull(result);
|
||||
assertNotNull(result.getResponse());
|
||||
assertFalse(result.getResponse().isEmpty());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(18)
|
||||
void testAskModelWithOptionsAndImageFiles()
|
||||
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
|
||||
api.pullModel(IMAGE_MODEL_LLAVA);
|
||||
File imageFile = getImageFileFromClasspath("emoji-smile.jpeg");
|
||||
try {
|
||||
OllamaResult result = api.generateWithImageFiles(IMAGE_MODEL_LLAVA, "What is in this image?",
|
||||
List.of(imageFile),
|
||||
new OptionsBuilder().build());
|
||||
assertNotNull(result);
|
||||
assertNotNull(result.getResponse());
|
||||
assertFalse(result.getResponse().isEmpty());
|
||||
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||
fail(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(20)
|
||||
void testAskModelWithOptionsAndImageFilesStreamed()
|
||||
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
|
||||
api.pullModel(IMAGE_MODEL_LLAVA);
|
||||
|
||||
File imageFile = getImageFileFromClasspath("emoji-smile.jpeg");
|
||||
|
||||
StringBuffer sb = new StringBuffer();
|
||||
|
||||
OllamaResult result = api.generateWithImageFiles(IMAGE_MODEL_LLAVA, "What is in this image?",
|
||||
List.of(imageFile),
|
||||
new OptionsBuilder().build(), (s) -> {
|
||||
LOG.info(s);
|
||||
String substring = s.substring(sb.toString().length(), s.length());
|
||||
LOG.info(substring);
|
||||
sb.append(substring);
|
||||
});
|
||||
assertNotNull(result);
|
||||
assertNotNull(result.getResponse());
|
||||
assertFalse(result.getResponse().isEmpty());
|
||||
assertEquals(sb.toString().trim(), result.getResponse().trim());
|
||||
}
|
||||
|
||||
private File getImageFileFromClasspath(String fileName) {
|
||||
ClassLoader classLoader = getClass().getClassLoader();
|
||||
return new File(Objects.requireNonNull(classLoader.getResource(fileName)).getFile());
|
||||
}
|
||||
}
|
||||
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
class TimeOfDay {
|
||||
@JsonProperty("timeHour")
|
||||
private int timeHour;
|
||||
@JsonProperty("isNightTime")
|
||||
private boolean nightTime;
|
||||
}
|
||||
|
@ -24,8 +24,8 @@ import java.io.FileWriter;
|
||||
import java.io.IOException;
|
||||
import java.net.URISyntaxException;
|
||||
import java.time.Duration;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
@ -41,7 +41,8 @@ public class WithAuth {
|
||||
private static final String OLLAMA_VERSION = "0.6.1";
|
||||
private static final String NGINX_VERSION = "nginx:1.23.4-alpine";
|
||||
private static final String BEARER_AUTH_TOKEN = "secret-token";
|
||||
private static final String CHAT_MODEL_LLAMA3 = "llama3";
|
||||
private static final String GENERAL_PURPOSE_MODEL = "gemma3:270m";
|
||||
// private static final String THINKING_MODEL = "gpt-oss:20b";
|
||||
|
||||
|
||||
private static OllamaContainer ollama;
|
||||
@ -49,7 +50,7 @@ public class WithAuth {
|
||||
private static OllamaAPI api;
|
||||
|
||||
@BeforeAll
|
||||
public static void setUp() {
|
||||
static void setUp() {
|
||||
ollama = createOllamaContainer();
|
||||
ollama.start();
|
||||
|
||||
@ -68,7 +69,7 @@ public class WithAuth {
|
||||
LOG.info(
|
||||
"The Ollama service is now accessible via the Nginx proxy with bearer-auth authentication mode.\n" +
|
||||
"→ Ollama URL: {}\n" +
|
||||
"→ Proxy URL: {}}",
|
||||
"→ Proxy URL: {}",
|
||||
ollamaUrl, nginxUrl
|
||||
);
|
||||
LOG.info("OllamaAPI initialized with bearer auth token: {}", BEARER_AUTH_TOKEN);
|
||||
@ -132,14 +133,14 @@ public class WithAuth {
|
||||
|
||||
@Test
|
||||
@Order(1)
|
||||
void testOllamaBehindProxy() throws InterruptedException {
|
||||
void testOllamaBehindProxy() {
|
||||
api.setBearerAuth(BEARER_AUTH_TOKEN);
|
||||
assertTrue(api.ping(), "Expected OllamaAPI to successfully ping through NGINX with valid auth token.");
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(1)
|
||||
void testWithWrongToken() throws InterruptedException {
|
||||
void testWithWrongToken() {
|
||||
api.setBearerAuth("wrong-token");
|
||||
assertFalse(api.ping(), "Expected OllamaAPI ping to fail through NGINX with an invalid auth token.");
|
||||
}
|
||||
@ -149,46 +150,30 @@ public class WithAuth {
|
||||
void testAskModelWithStructuredOutput()
|
||||
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
|
||||
api.setBearerAuth(BEARER_AUTH_TOKEN);
|
||||
String model = GENERAL_PURPOSE_MODEL;
|
||||
api.pullModel(model);
|
||||
|
||||
api.pullModel(CHAT_MODEL_LLAMA3);
|
||||
|
||||
int timeHour = 6;
|
||||
boolean isNightTime = false;
|
||||
|
||||
String prompt = "The Sun is shining, and its " + timeHour + ". Its daytime.";
|
||||
String prompt = "The sun is shining brightly and is directly overhead at the zenith, casting my shadow over my foot, so it must be noon.";
|
||||
|
||||
Map<String, Object> format = new HashMap<>();
|
||||
format.put("type", "object");
|
||||
format.put("properties", new HashMap<String, Object>() {
|
||||
{
|
||||
put("timeHour", new HashMap<String, Object>() {
|
||||
{
|
||||
put("type", "integer");
|
||||
}
|
||||
});
|
||||
put("isNightTime", new HashMap<String, Object>() {
|
||||
put("isNoon", new HashMap<String, Object>() {
|
||||
{
|
||||
put("type", "boolean");
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
format.put("required", Arrays.asList("timeHour", "isNightTime"));
|
||||
format.put("required", List.of("isNoon"));
|
||||
|
||||
OllamaResult result = api.generate(CHAT_MODEL_LLAMA3, prompt, format);
|
||||
OllamaResult result = api.generate(model, prompt, format);
|
||||
|
||||
assertNotNull(result);
|
||||
assertNotNull(result.getResponse());
|
||||
assertFalse(result.getResponse().isEmpty());
|
||||
|
||||
assertEquals(timeHour,
|
||||
result.getStructuredResponse().get("timeHour"));
|
||||
assertEquals(isNightTime,
|
||||
result.getStructuredResponse().get("isNightTime"));
|
||||
|
||||
TimeOfDay timeOfDay = result.as(TimeOfDay.class);
|
||||
|
||||
assertEquals(timeHour, timeOfDay.getTimeHour());
|
||||
assertEquals(isNightTime, timeOfDay.isNightTime());
|
||||
assertEquals(true, result.getStructuredResponse().get("isNoon"));
|
||||
}
|
||||
}
|
||||
|
@ -13,9 +13,9 @@ public class AnnotatedTool {
|
||||
}
|
||||
|
||||
@ToolSpec(desc = "Says hello to a friend!")
|
||||
public String sayHello(@ToolProperty(name = "name",desc = "Name of the friend") String name, Integer someRandomProperty, @ToolProperty(name="amountOfHearts",desc = "amount of heart emojis that should be used", required = false) Integer amountOfHearts) {
|
||||
String hearts = amountOfHearts!=null ? "♡".repeat(amountOfHearts) : "";
|
||||
return "Hello " + name +" ("+someRandomProperty+") " + hearts;
|
||||
public String sayHello(@ToolProperty(name = "name", desc = "Name of the friend") String name, @ToolProperty(name = "numberOfHearts", desc = "number of heart emojis that should be used", required = false) Integer numberOfHearts) {
|
||||
String hearts = numberOfHearts != null ? "♡".repeat(numberOfHearts) : "";
|
||||
return "Hello, " + name + "! " + hearts;
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -21,7 +21,8 @@ import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.fail;
|
||||
import static org.mockito.Mockito.*;
|
||||
|
||||
class TestMockedAPIs {
|
||||
@ -138,10 +139,10 @@ class TestMockedAPIs {
|
||||
String prompt = "some prompt text";
|
||||
OptionsBuilder optionsBuilder = new OptionsBuilder();
|
||||
try {
|
||||
when(ollamaAPI.generate(model, prompt, false, optionsBuilder.build()))
|
||||
.thenReturn(new OllamaResult("", 0, 200));
|
||||
ollamaAPI.generate(model, prompt, false, optionsBuilder.build());
|
||||
verify(ollamaAPI, times(1)).generate(model, prompt, false, optionsBuilder.build());
|
||||
when(ollamaAPI.generate(model, prompt, false, false, optionsBuilder.build()))
|
||||
.thenReturn(new OllamaResult("", "", 0, 200));
|
||||
ollamaAPI.generate(model, prompt, false, false, optionsBuilder.build());
|
||||
verify(ollamaAPI, times(1)).generate(model, prompt, false, false, optionsBuilder.build());
|
||||
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
@ -155,7 +156,7 @@ class TestMockedAPIs {
|
||||
try {
|
||||
when(ollamaAPI.generateWithImageFiles(
|
||||
model, prompt, Collections.emptyList(), new OptionsBuilder().build()))
|
||||
.thenReturn(new OllamaResult("", 0, 200));
|
||||
.thenReturn(new OllamaResult("", "", 0, 200));
|
||||
ollamaAPI.generateWithImageFiles(
|
||||
model, prompt, Collections.emptyList(), new OptionsBuilder().build());
|
||||
verify(ollamaAPI, times(1))
|
||||
@ -174,7 +175,7 @@ class TestMockedAPIs {
|
||||
try {
|
||||
when(ollamaAPI.generateWithImageURLs(
|
||||
model, prompt, Collections.emptyList(), new OptionsBuilder().build()))
|
||||
.thenReturn(new OllamaResult("", 0, 200));
|
||||
.thenReturn(new OllamaResult("", "", 0, 200));
|
||||
ollamaAPI.generateWithImageURLs(
|
||||
model, prompt, Collections.emptyList(), new OptionsBuilder().build());
|
||||
verify(ollamaAPI, times(1))
|
||||
@ -190,10 +191,10 @@ class TestMockedAPIs {
|
||||
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
|
||||
String model = OllamaModelType.LLAMA2;
|
||||
String prompt = "some prompt text";
|
||||
when(ollamaAPI.generateAsync(model, prompt, false))
|
||||
when(ollamaAPI.generateAsync(model, prompt, false, false))
|
||||
.thenReturn(new OllamaAsyncResultStreamer(null, null, 3));
|
||||
ollamaAPI.generateAsync(model, prompt, false);
|
||||
verify(ollamaAPI, times(1)).generateAsync(model, prompt, false);
|
||||
ollamaAPI.generateAsync(model, prompt, false, false);
|
||||
verify(ollamaAPI, times(1)).generateAsync(model, prompt, false, false);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -1,11 +1,12 @@
|
||||
package io.github.ollama4j.unittests.jackson;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.fail;
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import io.github.ollama4j.utils.Utils;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.fail;
|
||||
|
||||
public abstract class AbstractSerializationTest<T> {
|
||||
|
||||
protected ObjectMapper mapper = Utils.getObjectMapper();
|
||||
|
@ -1,20 +1,19 @@
|
||||
package io.github.ollama4j.unittests.jackson;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrowsExactly;
|
||||
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
|
||||
import io.github.ollama4j.models.chat.OllamaChatRequest;
|
||||
import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
|
||||
import io.github.ollama4j.utils.OptionsBuilder;
|
||||
import org.json.JSONObject;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
import io.github.ollama4j.models.chat.OllamaChatRequest;
|
||||
import org.json.JSONObject;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
|
||||
import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
|
||||
import io.github.ollama4j.utils.OptionsBuilder;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrowsExactly;
|
||||
|
||||
public class TestChatRequestSerialization extends AbstractSerializationTest<OllamaChatRequest> {
|
||||
|
||||
|
@ -1,12 +1,12 @@
|
||||
package io.github.ollama4j.unittests.jackson;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
import io.github.ollama4j.models.embeddings.OllamaEmbedRequestBuilder;
|
||||
import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel;
|
||||
import io.github.ollama4j.utils.OptionsBuilder;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import io.github.ollama4j.utils.OptionsBuilder;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
public class TestEmbedRequestSerialization extends AbstractSerializationTest<OllamaEmbedRequestModel> {
|
||||
|
||||
|
@ -1,15 +1,13 @@
|
||||
package io.github.ollama4j.unittests.jackson;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
import io.github.ollama4j.models.generate.OllamaGenerateRequest;
|
||||
import io.github.ollama4j.models.generate.OllamaGenerateRequestBuilder;
|
||||
import io.github.ollama4j.utils.OptionsBuilder;
|
||||
import org.json.JSONObject;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
|
||||
import io.github.ollama4j.models.generate.OllamaGenerateRequestBuilder;
|
||||
import io.github.ollama4j.utils.OptionsBuilder;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
public class TestGenerateRequestSerialization extends AbstractSerializationTest<OllamaGenerateRequest> {
|
||||
|
||||
|
@ -3,40 +3,66 @@ package io.github.ollama4j.unittests.jackson;
|
||||
import io.github.ollama4j.models.response.Model;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
public class TestModelRequestSerialization extends AbstractSerializationTest<Model> {
|
||||
|
||||
@Test
|
||||
public void testDeserializationOfModelResponseWithOffsetTime() {
|
||||
String serializedTestStringWithOffsetTime = "{\n"
|
||||
+ "\"name\": \"codellama:13b\",\n"
|
||||
+ "\"modified_at\": \"2023-11-04T14:56:49.277302595-07:00\",\n"
|
||||
+ "\"size\": 7365960935,\n"
|
||||
+ "\"digest\": \"9f438cb9cd581fc025612d27f7c1a6669ff83a8bb0ed86c94fcf4c5440555697\",\n"
|
||||
+ "\"details\": {\n"
|
||||
+ "\"format\": \"gguf\",\n"
|
||||
+ "\"family\": \"llama\",\n"
|
||||
+ "\"families\": null,\n"
|
||||
+ "\"parameter_size\": \"13B\",\n"
|
||||
+ "\"quantization_level\": \"Q4_0\"\n"
|
||||
+ "}}";
|
||||
deserialize(serializedTestStringWithOffsetTime,Model.class);
|
||||
String serializedTestStringWithOffsetTime = "{\n" +
|
||||
" \"name\": \"codellama:13b\",\n" +
|
||||
" \"modified_at\": \"2023-11-04T14:56:49.277302595-07:00\",\n" +
|
||||
" \"size\": 7365960935,\n" +
|
||||
" \"digest\": \"9f438cb9cd581fc025612d27f7c1a6669ff83a8bb0ed86c94fcf4c5440555697\",\n" +
|
||||
" \"details\": {\n" +
|
||||
" \"format\": \"gguf\",\n" +
|
||||
" \"family\": \"llama\",\n" +
|
||||
" \"families\": null,\n" +
|
||||
" \"parameter_size\": \"13B\",\n" +
|
||||
" \"quantization_level\": \"Q4_0\"\n" +
|
||||
" }\n" +
|
||||
"}";
|
||||
Model model = deserialize(serializedTestStringWithOffsetTime, Model.class);
|
||||
assertNotNull(model);
|
||||
assertEquals("codellama:13b", model.getName());
|
||||
assertEquals("2023-11-04T21:56:49.277302595Z", model.getModifiedAt().toString());
|
||||
assertEquals(7365960935L, model.getSize());
|
||||
assertEquals("9f438cb9cd581fc025612d27f7c1a6669ff83a8bb0ed86c94fcf4c5440555697", model.getDigest());
|
||||
assertNotNull(model.getModelMeta());
|
||||
assertEquals("gguf", model.getModelMeta().getFormat());
|
||||
assertEquals("llama", model.getModelMeta().getFamily());
|
||||
assertNull(model.getModelMeta().getFamilies());
|
||||
assertEquals("13B", model.getModelMeta().getParameterSize());
|
||||
assertEquals("Q4_0", model.getModelMeta().getQuantizationLevel());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDeserializationOfModelResponseWithZuluTime() {
|
||||
String serializedTestStringWithZuluTimezone = "{\n"
|
||||
+ "\"name\": \"codellama:13b\",\n"
|
||||
+ "\"modified_at\": \"2023-11-04T14:56:49.277302595Z\",\n"
|
||||
+ "\"size\": 7365960935,\n"
|
||||
+ "\"digest\": \"9f438cb9cd581fc025612d27f7c1a6669ff83a8bb0ed86c94fcf4c5440555697\",\n"
|
||||
+ "\"details\": {\n"
|
||||
+ "\"format\": \"gguf\",\n"
|
||||
+ "\"family\": \"llama\",\n"
|
||||
+ "\"families\": null,\n"
|
||||
+ "\"parameter_size\": \"13B\",\n"
|
||||
+ "\"quantization_level\": \"Q4_0\"\n"
|
||||
+ "}}";
|
||||
deserialize(serializedTestStringWithZuluTimezone,Model.class);
|
||||
String serializedTestStringWithZuluTimezone = "{\n" +
|
||||
" \"name\": \"codellama:13b\",\n" +
|
||||
" \"modified_at\": \"2023-11-04T14:56:49.277302595Z\",\n" +
|
||||
" \"size\": 7365960935,\n" +
|
||||
" \"digest\": \"9f438cb9cd581fc025612d27f7c1a6669ff83a8bb0ed86c94fcf4c5440555697\",\n" +
|
||||
" \"details\": {\n" +
|
||||
" \"format\": \"gguf\",\n" +
|
||||
" \"family\": \"llama\",\n" +
|
||||
" \"families\": null,\n" +
|
||||
" \"parameter_size\": \"13B\",\n" +
|
||||
" \"quantization_level\": \"Q4_0\"\n" +
|
||||
" }\n" +
|
||||
"}";
|
||||
Model model = deserialize(serializedTestStringWithZuluTimezone, Model.class);
|
||||
assertNotNull(model);
|
||||
assertEquals("codellama:13b", model.getName());
|
||||
assertEquals("2023-11-04T14:56:49.277302595Z", model.getModifiedAt().toString());
|
||||
assertEquals(7365960935L, model.getSize());
|
||||
assertEquals("9f438cb9cd581fc025612d27f7c1a6669ff83a8bb0ed86c94fcf4c5440555697", model.getDigest());
|
||||
assertNotNull(model.getModelMeta());
|
||||
assertEquals("gguf", model.getModelMeta().getFormat());
|
||||
assertEquals("llama", model.getModelMeta().getFamily());
|
||||
assertNull(model.getModelMeta().getFamilies());
|
||||
assertEquals("13B", model.getModelMeta().getParameterSize());
|
||||
assertEquals("Q4_0", model.getModelMeta().getQuantizationLevel());
|
||||
}
|
||||
|
||||
}
|
||||
|
BIN
src/test/resources/roses.jpg
Normal file
BIN
src/test/resources/roses.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 63 KiB |
Loading…
x
Reference in New Issue
Block a user