diff --git a/.github/workflows/build-on-pull-request.yml b/.github/workflows/build-on-pull-request.yml index dfa287d..92eb888 100644 --- a/.github/workflows/build-on-pull-request.yml +++ b/.github/workflows/build-on-pull-request.yml @@ -1,20 +1,21 @@ -name: Run Tests +name: Build and Test on Pull Request on: pull_request: - # types: [opened, reopened, synchronize, edited] - branches: [ "main" ] + types: [opened, reopened, synchronize] + branches: + - main paths: - - 'src/**' # Run if changes occur in the 'src' folder - - 'pom.xml' # Run if changes occur in the 'pom.xml' file + - 'src/**' + - 'pom.xml' concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} cancel-in-progress: true jobs: - run-tests: - + build: + name: Build Java Project runs-on: ubuntu-latest permissions: contents: read @@ -26,18 +27,26 @@ jobs: with: java-version: '11' distribution: 'adopt-hotspot' - server-id: github # Value of the distributionManagement/repository/id field of the pom.xml - settings-path: ${{ github.workspace }} # location for the settings.xml file + server-id: github + settings-path: ${{ github.workspace }} - name: Build with Maven run: mvn --file pom.xml -U clean package - - name: Run unit tests - run: mvn --file pom.xml -U clean test -Punit-tests + run-tests: + name: Run Unit and Integration Tests + needs: build + uses: ./.github/workflows/run-tests.yml + with: + branch: ${{ github.head_ref || github.ref_name }} - - name: Run integration tests - run: mvn --file pom.xml -U clean verify -Pintegration-tests + build-docs: + name: Build Documentation + needs: [build, run-tests] + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 - name: Use Node.js uses: actions/setup-node@v3 with: diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index 705ff27..ef5a16e 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -1,18 +1,29 @@ -name: Run Unit and Integration Tests +name: Run Tests on: -# push: -# branches: -# - main + # push: + # branches: + # - main + + workflow_call: + inputs: + branch: + description: 'Branch name to run the tests on' + required: true + default: 'main' + type: string + workflow_dispatch: inputs: branch: description: 'Branch name to run the tests on' required: true default: 'main' + type: string jobs: run-tests: + name: Unit and Integration Tests runs-on: ubuntu-latest steps: @@ -21,17 +32,6 @@ jobs: with: ref: ${{ github.event.inputs.branch }} - - name: Use workflow from checked out branch - run: | - if [ -f .github/workflows/run-tests.yml ]; then - echo "Using workflow from checked out branch." - cp .github/workflows/run-tests.yml /tmp/run-tests.yml - exit 0 - else - echo "Workflow file not found in checked out branch." - exit 1 - fi - - name: Set up Ollama run: | curl -fsSL https://ollama.com/install.sh | sh @@ -51,4 +51,4 @@ jobs: run: mvn clean verify -Pintegration-tests env: USE_EXTERNAL_OLLAMA_HOST: "true" - OLLAMA_HOST: "http://localhost:11434" + OLLAMA_HOST: "http://localhost:11434" \ No newline at end of file diff --git a/docs/docs/apis-generate/generate.md b/docs/docs/apis-generate/generate.md index 463b7bb..6f2b1f7 100644 --- a/docs/docs/apis-generate/generate.md +++ b/docs/docs/apis-generate/generate.md @@ -29,7 +29,7 @@ You will get a response similar to: ### Try asking a question, receiving the answer streamed - + You will get a response similar to: diff --git a/pom.xml b/pom.xml index 087ca96..794c5ae 100644 --- a/pom.xml +++ b/pom.xml @@ -14,11 +14,12 @@ 11 - ${git.commit.time} + ${git.commit.time} + UTF-8 3.0.0-M5 3.0.0-M5 - 1.18.30 + 1.18.38 @@ -46,6 +47,19 @@ + + org.apache.maven.plugins + maven-compiler-plugin + + + + org.projectlombok + lombok + ${lombok.version} + + + + org.apache.maven.plugins maven-source-plugin @@ -146,7 +160,7 @@ yyyy-MM-dd'T'HH:mm:ss'Z' - Etc/UTC + Etc/UTC @@ -412,4 +426,4 @@ - + \ No newline at end of file diff --git a/src/main/java/io/github/ollama4j/OllamaAPI.java b/src/main/java/io/github/ollama4j/OllamaAPI.java index 5689faa..65831e1 100644 --- a/src/main/java/io/github/ollama4j/OllamaAPI.java +++ b/src/main/java/io/github/ollama4j/OllamaAPI.java @@ -22,6 +22,7 @@ import io.github.ollama4j.tools.*; import io.github.ollama4j.tools.annotations.OllamaToolService; import io.github.ollama4j.tools.annotations.ToolProperty; import io.github.ollama4j.tools.annotations.ToolSpec; +import io.github.ollama4j.utils.Constants; import io.github.ollama4j.utils.Options; import io.github.ollama4j.utils.Utils; import lombok.Setter; @@ -55,33 +56,54 @@ import java.util.stream.Collectors; public class OllamaAPI { private static final Logger logger = LoggerFactory.getLogger(OllamaAPI.class); + private final String host; + private Auth auth; + private final ToolRegistry toolRegistry = new ToolRegistry(); + /** - * -- SETTER -- - * Set request timeout in seconds. Default is 3 seconds. + * The request timeout in seconds for API calls. + *

+ * Default is 10 seconds. This value determines how long the client will wait + * for a response + * from the Ollama server before timing out. */ @Setter private long requestTimeoutSeconds = 10; + /** - * -- SETTER -- - * Set/unset logging of responses + * Enables or disables verbose logging of responses. + *

+ * If set to {@code true}, the API will log detailed information about requests + * and responses. + * Default is {@code true}. */ @Setter private boolean verbose = true; + /** + * The maximum number of retries for tool calls during chat interactions. + *

+ * This value controls how many times the API will attempt to call a tool in the + * event of a failure. + * Default is 3. + */ @Setter private int maxChatToolCallRetries = 3; - private Auth auth; - + /** + * The number of retries to attempt when pulling a model from the Ollama server. + *

+ * If set to 0, no retries will be performed. If greater than 0, the API will + * retry pulling the model + * up to the specified number of times in case of failure. + *

+ * Default is 0 (no retries). + */ + @Setter + @SuppressWarnings({"FieldMayBeFinal", "FieldCanBeLocal"}) private int numberOfRetriesForModelPull = 0; - public void setNumberOfRetriesForModelPull(int numberOfRetriesForModelPull) { - this.numberOfRetriesForModelPull = numberOfRetriesForModelPull; - } - - private final ToolRegistry toolRegistry = new ToolRegistry(); - /** * Instantiates the Ollama API with default Ollama host: * http://localhost:11434 @@ -102,7 +124,7 @@ public class OllamaAPI { this.host = host; } if (this.verbose) { - logger.info("Ollama API initialized with host: " + this.host); + logger.info("Ollama API initialized with host: {}", this.host); } } @@ -135,14 +157,17 @@ public class OllamaAPI { public boolean ping() { String url = this.host + "/api/tags"; HttpClient httpClient = HttpClient.newHttpClient(); - HttpRequest httpRequest = null; + HttpRequest httpRequest; try { - httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") - .header("Content-type", "application/json").GET().build(); + httpRequest = getRequestBuilderDefault(new URI(url)) + .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON) + .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON) + .GET() + .build(); } catch (URISyntaxException e) { throw new RuntimeException(e); } - HttpResponse response = null; + HttpResponse response; try { response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString()); } catch (HttpConnectTimeoutException e) { @@ -168,8 +193,10 @@ public class OllamaAPI { HttpClient httpClient = HttpClient.newHttpClient(); HttpRequest httpRequest = null; try { - httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") - .header("Content-type", "application/json").GET().build(); + httpRequest = getRequestBuilderDefault(new URI(url)) + .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON) + .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON) + .GET().build(); } catch (URISyntaxException e) { throw new RuntimeException(e); } @@ -196,8 +223,10 @@ public class OllamaAPI { public List listModels() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { String url = this.host + "/api/tags"; HttpClient httpClient = HttpClient.newHttpClient(); - HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") - .header("Content-type", "application/json").GET().build(); + HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)) + .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON) + .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON).GET() + .build(); HttpResponse response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); String responseString = response.body(); @@ -229,8 +258,10 @@ public class OllamaAPI { throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { String url = "https://ollama.com/library"; HttpClient httpClient = HttpClient.newHttpClient(); - HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") - .header("Content-type", "application/json").GET().build(); + HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)) + .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON) + .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON).GET() + .build(); HttpResponse response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); String responseString = response.body(); @@ -296,8 +327,10 @@ public class OllamaAPI { throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { String url = String.format("https://ollama.com/library/%s/tags", libraryModel.getName()); HttpClient httpClient = HttpClient.newHttpClient(); - HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") - .header("Content-type", "application/json").GET().build(); + HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)) + .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON) + .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON).GET() + .build(); HttpResponse response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); String responseString = response.body(); @@ -338,6 +371,14 @@ public class OllamaAPI { /** * Finds a specific model using model name and tag from Ollama library. *

+ * Deprecated: This method relies on the HTML structure of the Ollama + * website, + * which is subject to change at any time. As a result, it is difficult to keep + * this API + * method consistently updated and reliable. Therefore, this method is + * deprecated and + * may be removed in future releases. + *

* This method retrieves the model from the Ollama library by its name, then * fetches its tags. * It searches through the tags of the model to find one that matches the @@ -355,7 +396,11 @@ public class OllamaAPI { * @throws URISyntaxException If there is an error with the URI syntax. * @throws InterruptedException If the operation is interrupted. * @throws NoSuchElementException If the model or the tag is not found. + * @deprecated This method relies on the HTML structure of the Ollama website, + * which can change at any time and break this API. It is deprecated + * and may be removed in the future. */ + @Deprecated public LibraryModelTag findModelTagFromLibrary(String modelName, String tag) throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { List libraryModels = this.listModelsFromLibrary(); @@ -363,40 +408,71 @@ public class OllamaAPI { .findFirst().orElseThrow( () -> new NoSuchElementException(String.format("Model by name '%s' not found", modelName))); LibraryModelDetail libraryModelDetail = this.getLibraryModelDetails(libraryModel); - LibraryModelTag libraryModelTag = libraryModelDetail.getTags().stream() - .filter(tagName -> tagName.getTag().equals(tag)).findFirst() + return libraryModelDetail.getTags().stream().filter(tagName -> tagName.getTag().equals(tag)).findFirst() .orElseThrow(() -> new NoSuchElementException( String.format("Tag '%s' for model '%s' not found", tag, modelName))); - return libraryModelTag; } /** * Pull a model on the Ollama server from the list of available models. + *

+ * If {@code numberOfRetriesForModelPull} is greater than 0, this method will + * retry pulling the model + * up to the specified number of times if an {@link OllamaBaseException} occurs, + * using exponential backoff + * between retries (delay doubles after each failed attempt, starting at 1 + * second). + *

+ * The backoff is only applied between retries, not after the final attempt. * * @param modelName the name of the model - * @throws OllamaBaseException if the response indicates an error status + * @throws OllamaBaseException if the response indicates an error status or all + * retries fail * @throws IOException if an I/O error occurs during the HTTP request - * @throws InterruptedException if the operation is interrupted + * @throws InterruptedException if the operation is interrupted or the thread is + * interrupted during backoff * @throws URISyntaxException if the URI for the request is malformed */ public void pullModel(String modelName) throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { if (numberOfRetriesForModelPull == 0) { this.doPullModel(modelName); - } else { - int numberOfRetries = 0; - while (numberOfRetries < numberOfRetriesForModelPull) { - try { - this.doPullModel(modelName); - return; - } catch (OllamaBaseException e) { - logger.error("Failed to pull model " + modelName + ", retrying..."); - numberOfRetries++; - } + return; + } + int numberOfRetries = 0; + long baseDelayMillis = 3000L; // 1 second base delay + while (numberOfRetries < numberOfRetriesForModelPull) { + try { + this.doPullModel(modelName); + return; + } catch (OllamaBaseException e) { + handlePullRetry(modelName, numberOfRetries, numberOfRetriesForModelPull, baseDelayMillis); + numberOfRetries++; } - throw new OllamaBaseException( - "Failed to pull model " + modelName + " after " + numberOfRetriesForModelPull + " retries"); + } + throw new OllamaBaseException( + "Failed to pull model " + modelName + " after " + numberOfRetriesForModelPull + " retries"); + } + + /** + * Handles retry backoff for pullModel. + */ + private void handlePullRetry(String modelName, int currentRetry, int maxRetries, long baseDelayMillis) + throws InterruptedException { + int attempt = currentRetry + 1; + if (attempt < maxRetries) { + long backoffMillis = baseDelayMillis * (1L << currentRetry); + logger.error("Failed to pull model {}, retrying in {}s... (attempt {}/{})", + modelName, backoffMillis / 1000, attempt, maxRetries); + try { + Thread.sleep(backoffMillis); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + throw ie; + } + } else { + logger.error("Failed to pull model {} after {} attempts, no more retries.", modelName, maxRetries); } } @@ -404,10 +480,9 @@ public class OllamaAPI { throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { String url = this.host + "/api/pull"; String jsonData = new ModelRequest(modelName).toString(); - HttpRequest request = getRequestBuilderDefault(new URI(url)) - .POST(HttpRequest.BodyPublishers.ofString(jsonData)) - .header("Accept", "application/json") - .header("Content-type", "application/json") + HttpRequest request = getRequestBuilderDefault(new URI(url)).POST(HttpRequest.BodyPublishers.ofString(jsonData)) + .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON) + .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON) .build(); HttpClient client = HttpClient.newHttpClient(); HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofInputStream()); @@ -428,7 +503,7 @@ public class OllamaAPI { if (modelPullResponse.getStatus() != null) { if (verbose) { - logger.info(modelName + ": " + modelPullResponse.getStatus()); + logger.info("{}: {}", modelName, modelPullResponse.getStatus()); } // Check if status is "success" and set success flag to true. if ("success".equalsIgnoreCase(modelPullResponse.getStatus())) { @@ -452,8 +527,10 @@ public class OllamaAPI { public String getVersion() throws URISyntaxException, IOException, InterruptedException, OllamaBaseException { String url = this.host + "/api/version"; HttpClient httpClient = HttpClient.newHttpClient(); - HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") - .header("Content-type", "application/json").GET().build(); + HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)) + .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON) + .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON).GET() + .build(); HttpResponse response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); String responseString = response.body(); @@ -498,8 +575,10 @@ public class OllamaAPI { throws IOException, OllamaBaseException, InterruptedException, URISyntaxException { String url = this.host + "/api/show"; String jsonData = new ModelRequest(modelName).toString(); - HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") - .header("Content-type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build(); + HttpRequest request = getRequestBuilderDefault(new URI(url)) + .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON) + .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON) + .POST(HttpRequest.BodyPublishers.ofString(jsonData)).build(); HttpClient client = HttpClient.newHttpClient(); HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); @@ -529,8 +608,9 @@ public class OllamaAPI { throws IOException, InterruptedException, OllamaBaseException, URISyntaxException { String url = this.host + "/api/create"; String jsonData = new CustomModelFilePathRequest(modelName, modelFilePath).toString(); - HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") - .header("Content-Type", "application/json") + HttpRequest request = getRequestBuilderDefault(new URI(url)) + .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON) + .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON) .POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build(); HttpClient client = HttpClient.newHttpClient(); HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); @@ -569,8 +649,9 @@ public class OllamaAPI { throws IOException, InterruptedException, OllamaBaseException, URISyntaxException { String url = this.host + "/api/create"; String jsonData = new CustomModelFileContentsRequest(modelName, modelFileContents).toString(); - HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") - .header("Content-Type", "application/json") + HttpRequest request = getRequestBuilderDefault(new URI(url)) + .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON) + .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON) .POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build(); HttpClient client = HttpClient.newHttpClient(); HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); @@ -602,8 +683,9 @@ public class OllamaAPI { throws IOException, InterruptedException, OllamaBaseException, URISyntaxException { String url = this.host + "/api/create"; String jsonData = customModelRequest.toString(); - HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") - .header("Content-Type", "application/json") + HttpRequest request = getRequestBuilderDefault(new URI(url)) + .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON) + .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON) .POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build(); HttpClient client = HttpClient.newHttpClient(); HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); @@ -637,7 +719,9 @@ public class OllamaAPI { String jsonData = new ModelRequest(modelName).toString(); HttpRequest request = getRequestBuilderDefault(new URI(url)) .method("DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)) - .header("Accept", "application/json").header("Content-type", "application/json").build(); + .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON) + .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON) + .build(); HttpClient client = HttpClient.newHttpClient(); HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); @@ -683,7 +767,8 @@ public class OllamaAPI { URI uri = URI.create(this.host + "/api/embeddings"); String jsonData = modelRequest.toString(); HttpClient httpClient = HttpClient.newHttpClient(); - HttpRequest.Builder requestBuilder = getRequestBuilderDefault(uri).header("Accept", "application/json") + HttpRequest.Builder requestBuilder = getRequestBuilderDefault(uri) + .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON) .POST(HttpRequest.BodyPublishers.ofString(jsonData)); HttpRequest request = requestBuilder.build(); HttpResponse response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); @@ -728,7 +813,8 @@ public class OllamaAPI { String jsonData = Utils.getObjectMapper().writeValueAsString(modelRequest); HttpClient httpClient = HttpClient.newHttpClient(); - HttpRequest request = HttpRequest.newBuilder(uri).header("Accept", "application/json") + HttpRequest request = HttpRequest.newBuilder(uri) + .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON) .POST(HttpRequest.BodyPublishers.ofString(jsonData)).build(); HttpResponse response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); @@ -744,33 +830,112 @@ public class OllamaAPI { /** * Generate response for a question to a model running on Ollama server. This is - * a sync/blocking - * call. + * a sync/blocking call. This API does not support "thinking" models. * - * @param model the ollama model to ask the question to - * @param prompt the prompt/question text - * @param options the Options object - More - * details on the options - * @param streamHandler optional callback consumer that will be applied every - * time a streamed response is received. If not set, the - * stream parameter of the request is set to false. + * @param model the ollama model to ask the question to + * @param prompt the prompt/question text + * @param raw if true no formatting will be applied to the + * prompt. You + * may choose to use the raw parameter if you are + * specifying a full templated prompt in your + * request to + * the API + * @param options the Options object - More + * details on the options + * @param responseStreamHandler optional callback consumer that will be applied + * every + * time a streamed response is received. If not + * set, the + * stream parameter of the request is set to false. * @return OllamaResult that includes response text and time taken for response * @throws OllamaBaseException if the response indicates an error status * @throws IOException if an I/O error occurs during the HTTP request * @throws InterruptedException if the operation is interrupted */ public OllamaResult generate(String model, String prompt, boolean raw, Options options, - OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { + OllamaStreamHandler responseStreamHandler) throws OllamaBaseException, IOException, InterruptedException { OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt); ollamaRequestModel.setRaw(raw); + ollamaRequestModel.setThink(false); ollamaRequestModel.setOptions(options.getOptionsMap()); - return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler); + return generateSyncForOllamaRequestModel(ollamaRequestModel, null, responseStreamHandler); + } + + /** + * Generate thinking and response tokens for a question to a thinking model + * running on Ollama server. This is + * a sync/blocking call. + * + * @param model the ollama model to ask the question to + * @param prompt the prompt/question text + * @param raw if true no formatting will be applied to the + * prompt. You + * may choose to use the raw parameter if you are + * specifying a full templated prompt in your + * request to + * the API + * @param options the Options object - More + * details on the options + * @param responseStreamHandler optional callback consumer that will be applied + * every + * time a streamed response is received. If not + * set, the + * stream parameter of the request is set to false. + * @return OllamaResult that includes response text and time taken for response + * @throws OllamaBaseException if the response indicates an error status + * @throws IOException if an I/O error occurs during the HTTP request + * @throws InterruptedException if the operation is interrupted + */ + public OllamaResult generate(String model, String prompt, boolean raw, Options options, + OllamaStreamHandler thinkingStreamHandler, OllamaStreamHandler responseStreamHandler) + throws OllamaBaseException, IOException, InterruptedException { + OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt); + ollamaRequestModel.setRaw(raw); + ollamaRequestModel.setThink(true); + ollamaRequestModel.setOptions(options.getOptionsMap()); + return generateSyncForOllamaRequestModel(ollamaRequestModel, thinkingStreamHandler, responseStreamHandler); + } + + /** + * Generates response using the specified AI model and prompt (in blocking + * mode). + *

+ * Uses + * {@link #generate(String, String, boolean, Options, OllamaStreamHandler)} + * + * @param model The name or identifier of the AI model to use for generating + * the response. + * @param prompt The input text or prompt to provide to the AI model. + * @param raw In some cases, you may wish to bypass the templating system + * and provide a full prompt. In this case, you can use the raw + * parameter to disable templating. Also note that raw mode will + * not return a context. + * @param options Additional options or configurations to use when generating + * the response. + * @param think if true the model will "think" step-by-step before + * generating the final response + * @return {@link OllamaResult} + * @throws OllamaBaseException if the response indicates an error status + * @throws IOException if an I/O error occurs during the HTTP request + * @throws InterruptedException if the operation is interrupted + */ + public OllamaResult generate(String model, String prompt, boolean raw, boolean think, Options options) + throws OllamaBaseException, IOException, InterruptedException { + if (think) { + return generate(model, prompt, raw, options, null, null); + } else { + return generate(model, prompt, raw, options, null); + } } /** * Generates structured output from the specified AI model and prompt. + *

+ * Note: When formatting is specified, the 'think' parameter is not allowed. * * @param model The name or identifier of the AI model to use for generating * the response. @@ -783,6 +948,7 @@ public class OllamaAPI { * @throws IOException if an I/O error occurs during the HTTP request. * @throws InterruptedException if the operation is interrupted. */ + @SuppressWarnings("LoggingSimilarMessage") public OllamaResult generate(String model, String prompt, Map format) throws OllamaBaseException, IOException, InterruptedException { URI uri = URI.create(this.host + "/api/generate"); @@ -797,51 +963,52 @@ public class OllamaAPI { HttpClient httpClient = HttpClient.newHttpClient(); HttpRequest request = getRequestBuilderDefault(uri) - .header("Accept", "application/json") - .header("Content-type", "application/json") - .POST(HttpRequest.BodyPublishers.ofString(jsonData)) - .build(); + .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON) + .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON) + .POST(HttpRequest.BodyPublishers.ofString(jsonData)).build(); + if (verbose) { + try { + String prettyJson = Utils.getObjectMapper().writerWithDefaultPrettyPrinter() + .writeValueAsString(Utils.getObjectMapper().readValue(jsonData, Object.class)); + logger.info("Asking model:\n{}", prettyJson); + } catch (Exception e) { + logger.info("Asking model: {}", jsonData); + } + } HttpResponse response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); String responseBody = response.body(); - if (statusCode == 200) { OllamaStructuredResult structuredResult = Utils.getObjectMapper().readValue(responseBody, OllamaStructuredResult.class); - OllamaResult ollamaResult = new OllamaResult(structuredResult.getResponse(), + OllamaResult ollamaResult = new OllamaResult(structuredResult.getResponse(), structuredResult.getThinking(), structuredResult.getResponseTime(), statusCode); + + ollamaResult.setModel(structuredResult.getModel()); + ollamaResult.setCreatedAt(structuredResult.getCreatedAt()); + ollamaResult.setDone(structuredResult.isDone()); + ollamaResult.setDoneReason(structuredResult.getDoneReason()); + ollamaResult.setContext(structuredResult.getContext()); + ollamaResult.setTotalDuration(structuredResult.getTotalDuration()); + ollamaResult.setLoadDuration(structuredResult.getLoadDuration()); + ollamaResult.setPromptEvalCount(structuredResult.getPromptEvalCount()); + ollamaResult.setPromptEvalDuration(structuredResult.getPromptEvalDuration()); + ollamaResult.setEvalCount(structuredResult.getEvalCount()); + ollamaResult.setEvalDuration(structuredResult.getEvalDuration()); + if (verbose) { + logger.info("Model response:\n{}", ollamaResult); + } return ollamaResult; } else { + if (verbose) { + logger.info("Model response:\n{}", + Utils.getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(responseBody)); + } throw new OllamaBaseException(statusCode + " - " + responseBody); } } - /** - * Generates response using the specified AI model and prompt (in blocking - * mode). - *

- * Uses {@link #generate(String, String, boolean, Options, OllamaStreamHandler)} - * - * @param model The name or identifier of the AI model to use for generating - * the response. - * @param prompt The input text or prompt to provide to the AI model. - * @param raw In some cases, you may wish to bypass the templating system - * and provide a full prompt. In this case, you can use the raw - * parameter to disable templating. Also note that raw mode will - * not return a context. - * @param options Additional options or configurations to use when generating - * the response. - * @return {@link OllamaResult} - * @throws OllamaBaseException if the response indicates an error status - * @throws IOException if an I/O error occurs during the HTTP request - * @throws InterruptedException if the operation is interrupted - */ - public OllamaResult generate(String model, String prompt, boolean raw, Options options) - throws OllamaBaseException, IOException, InterruptedException { - return generate(model, prompt, raw, options, null); - } - /** * Generates response using the specified AI model and prompt (in blocking * mode), and then invokes a set of tools @@ -893,8 +1060,7 @@ public class OllamaAPI { logger.warn("Response from model does not contain any tool calls. Returning the response as is."); return toolResult; } - toolFunctionCallSpecs = objectMapper.readValue( - toolsResponse, + toolFunctionCallSpecs = objectMapper.readValue(toolsResponse, objectMapper.getTypeFactory().constructCollectionType(List.class, ToolFunctionCallSpec.class)); } for (ToolFunctionCallSpec toolFunctionCallSpec : toolFunctionCallSpecs) { @@ -905,19 +1071,47 @@ public class OllamaAPI { } /** - * Generate response for a question to a model running on Ollama server and get - * a callback handle - * that can be used to check for status and get the response from the model - * later. This would be - * an async/non-blocking call. + * Asynchronously generates a response for a prompt using a model running on the + * Ollama server. + *

+ * This method returns an {@link OllamaAsyncResultStreamer} handle that can be + * used to poll for + * status and retrieve streamed "thinking" and response tokens from the model. + * The call is non-blocking. + *

* - * @param model the ollama model to ask the question to - * @param prompt the prompt/question text - * @return the ollama async result callback handle + *

+ * Example usage: + *

+ * + *
{@code
+     * OllamaAsyncResultStreamer resultStreamer = ollamaAPI.generateAsync("gpt-oss:20b", "Who are you", false, true);
+     * int pollIntervalMilliseconds = 1000;
+     * while (true) {
+     *     String thinkingTokens = resultStreamer.getThinkingResponseStream().poll();
+     *     String responseTokens = resultStreamer.getResponseStream().poll();
+     *     System.out.print(thinkingTokens != null ? thinkingTokens.toUpperCase() : "");
+     *     System.out.print(responseTokens != null ? responseTokens.toLowerCase() : "");
+     *     Thread.sleep(pollIntervalMilliseconds);
+     *     if (!resultStreamer.isAlive())
+     *         break;
+     * }
+     * System.out.println("Complete thinking response: " + resultStreamer.getCompleteThinkingResponse());
+     * System.out.println("Complete response: " + resultStreamer.getCompleteResponse());
+     * }
+ * + * @param model the Ollama model to use for generating the response + * @param prompt the prompt or question text to send to the model + * @param raw if {@code true}, returns the raw response from the model + * @param think if {@code true}, streams "thinking" tokens as well as response + * tokens + * @return an {@link OllamaAsyncResultStreamer} handle for polling and + * retrieving streamed results */ - public OllamaAsyncResultStreamer generateAsync(String model, String prompt, boolean raw) { + public OllamaAsyncResultStreamer generateAsync(String model, String prompt, boolean raw, boolean think) { OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt); ollamaRequestModel.setRaw(raw); + ollamaRequestModel.setThink(think); URI uri = URI.create(this.host + "/api/generate"); OllamaAsyncResultStreamer ollamaAsyncResultStreamer = new OllamaAsyncResultStreamer( getRequestBuilderDefault(uri), ollamaRequestModel, requestTimeoutSeconds); @@ -953,7 +1147,7 @@ public class OllamaAPI { } OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt, images); ollamaRequestModel.setOptions(options.getOptionsMap()); - return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler); + return generateSyncForOllamaRequestModel(ollamaRequestModel, null, streamHandler); } /** @@ -1001,7 +1195,7 @@ public class OllamaAPI { } OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt, images); ollamaRequestModel.setOptions(options.getOptionsMap()); - return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler); + return generateSyncForOllamaRequestModel(ollamaRequestModel, null, streamHandler); } /** @@ -1023,38 +1217,47 @@ public class OllamaAPI { /** * Synchronously generates a response using a list of image byte arrays. *

- * This method encodes the provided byte arrays into Base64 and sends them to the Ollama server. + * This method encodes the provided byte arrays into Base64 and sends them to + * the Ollama server. * * @param model the Ollama model to use for generating the response * @param prompt the prompt or question text to send to the model * @param images the list of image data as byte arrays - * @param options the Options object - More details on the options - * @param streamHandler optional callback that will be invoked with each streamed response; if null, streaming is disabled - * @return OllamaResult containing the response text and the time taken for the response + * @param options the Options object - More + * details on the options + * @param streamHandler optional callback that will be invoked with each + * streamed response; if null, streaming is disabled + * @return OllamaResult containing the response text and the time taken for the + * response * @throws OllamaBaseException if the response indicates an error status * @throws IOException if an I/O error occurs during the HTTP request * @throws InterruptedException if the operation is interrupted */ - public OllamaResult generateWithImages(String model, String prompt, List images, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { + public OllamaResult generateWithImages(String model, String prompt, List images, Options options, + OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { List encodedImages = new ArrayList<>(); for (byte[] image : images) { encodedImages.add(encodeByteArrayToBase64(image)); } OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt, encodedImages); ollamaRequestModel.setOptions(options.getOptionsMap()); - return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler); + return generateSyncForOllamaRequestModel(ollamaRequestModel, null, streamHandler); } /** - * Convenience method to call the Ollama API using image byte arrays without streaming responses. + * Convenience method to call the Ollama API using image byte arrays without + * streaming responses. *

- * Uses {@link #generateWithImages(String, String, List, Options, OllamaStreamHandler)} + * Uses + * {@link #generateWithImages(String, String, List, Options, OllamaStreamHandler)} * * @throws OllamaBaseException if the response indicates an error status * @throws IOException if an I/O error occurs during the HTTP request * @throws InterruptedException if the operation is interrupted */ - public OllamaResult generateWithImages(String model, String prompt, List images, Options options) throws OllamaBaseException, IOException, InterruptedException { + public OllamaResult generateWithImages(String model, String prompt, List images, Options options) + throws OllamaBaseException, IOException, InterruptedException { return generateWithImages(model, prompt, images, options, null); } @@ -1069,10 +1272,12 @@ public class OllamaAPI { * history including the newly acquired assistant response. * @throws OllamaBaseException any response code than 200 has been returned * @throws IOException in case the responseStream can not be read - * @throws InterruptedException in case the server is not reachable or network + * @throws InterruptedException in case the server is not reachable or + * network * issues happen * @throws OllamaBaseException if the response indicates an error status - * @throws IOException if an I/O error occurs during the HTTP request + * @throws IOException if an I/O error occurs during the HTTP + * request * @throws InterruptedException if the operation is interrupted * @throws ToolInvocationException if the tool invocation fails */ @@ -1092,16 +1297,18 @@ public class OllamaAPI { * @return {@link OllamaChatResult} * @throws OllamaBaseException any response code than 200 has been returned * @throws IOException in case the responseStream can not be read - * @throws InterruptedException in case the server is not reachable or network + * @throws InterruptedException in case the server is not reachable or + * network * issues happen * @throws OllamaBaseException if the response indicates an error status - * @throws IOException if an I/O error occurs during the HTTP request + * @throws IOException if an I/O error occurs during the HTTP + * request * @throws InterruptedException if the operation is interrupted * @throws ToolInvocationException if the tool invocation fails */ public OllamaChatResult chat(OllamaChatRequest request) throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException { - return chat(request, null); + return chat(request, null, null); } /** @@ -1110,23 +1317,27 @@ public class OllamaAPI { *

* Hint: the OllamaChatRequestModel#getStream() property is not implemented. * - * @param request request object to be sent to the server - * @param streamHandler callback handler to handle the last message from stream - * (caution: all previous tokens from stream will be - * concatenated) + * @param request request object to be sent to the server + * @param responseStreamHandler callback handler to handle the last message from + * stream + * @param thinkingStreamHandler callback handler to handle the last thinking + * message from stream * @return {@link OllamaChatResult} * @throws OllamaBaseException any response code than 200 has been returned * @throws IOException in case the responseStream can not be read - * @throws InterruptedException in case the server is not reachable or network + * @throws InterruptedException in case the server is not reachable or + * network * issues happen * @throws OllamaBaseException if the response indicates an error status - * @throws IOException if an I/O error occurs during the HTTP request + * @throws IOException if an I/O error occurs during the HTTP + * request * @throws InterruptedException if the operation is interrupted * @throws ToolInvocationException if the tool invocation fails */ - public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler) + public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler thinkingStreamHandler, + OllamaStreamHandler responseStreamHandler) throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException { - return chatStreaming(request, new OllamaChatStreamObserver(streamHandler)); + return chatStreaming(request, new OllamaChatStreamObserver(thinkingStreamHandler, responseStreamHandler)); } /** @@ -1177,8 +1388,11 @@ public class OllamaAPI { } Map arguments = toolCall.getFunction().getArguments(); Object res = toolFunction.apply(arguments); + String argumentKeys = arguments.keySet().stream() + .map(Object::toString) + .collect(Collectors.joining(", ")); request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL, - "[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() + ") : " + res + "[/TOOL_RESULTS]")); + "[TOOL_RESULTS] " + toolName + "(" + argumentKeys + "): " + res + " [/TOOL_RESULTS]")); } if (tokenHandler != null) { @@ -1224,6 +1438,17 @@ public class OllamaAPI { } } + /** + * Deregisters all tools from the tool registry. + * This method removes all registered tools, effectively clearing the registry. + */ + public void deregisterTools() { + toolRegistry.clear(); + if (this.verbose) { + logger.debug("All tools have been deregistered."); + } + } + /** * Registers tools based on the annotations found on the methods of the caller's * class and its providers. @@ -1380,10 +1605,12 @@ public class OllamaAPI { * the request will be streamed; otherwise, a regular synchronous request will * be made. * - * @param ollamaRequestModel the request model containing necessary parameters - * for the Ollama API request. - * @param streamHandler the stream handler to process streaming responses, - * or null for non-streaming requests. + * @param ollamaRequestModel the request model containing necessary + * parameters + * for the Ollama API request. + * @param responseStreamHandler the stream handler to process streaming + * responses, + * or null for non-streaming requests. * @return the result of the Ollama API request. * @throws OllamaBaseException if the request fails due to an issue with the * Ollama API. @@ -1392,13 +1619,14 @@ public class OllamaAPI { * @throws InterruptedException if the thread is interrupted during the request. */ private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequest ollamaRequestModel, - OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { + OllamaStreamHandler thinkingStreamHandler, OllamaStreamHandler responseStreamHandler) + throws OllamaBaseException, IOException, InterruptedException { OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, auth, requestTimeoutSeconds, verbose); OllamaResult result; - if (streamHandler != null) { + if (responseStreamHandler != null) { ollamaRequestModel.setStream(true); - result = requestCaller.call(ollamaRequestModel, streamHandler); + result = requestCaller.call(ollamaRequestModel, thinkingStreamHandler, responseStreamHandler); } else { result = requestCaller.callSync(ollamaRequestModel); } @@ -1412,7 +1640,8 @@ public class OllamaAPI { * @return HttpRequest.Builder */ private HttpRequest.Builder getRequestBuilderDefault(URI uri) { - HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(uri).header("Content-Type", "application/json") + HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(uri) + .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON) .timeout(Duration.ofSeconds(requestTimeoutSeconds)); if (isBasicAuthCredentialsSet()) { requestBuilder.header("Authorization", auth.getAuthHeaderValue()); diff --git a/src/main/java/io/github/ollama4j/impl/ConsoleOutputStreamHandler.java b/src/main/java/io/github/ollama4j/impl/ConsoleOutputStreamHandler.java index c9f8e36..d990006 100644 --- a/src/main/java/io/github/ollama4j/impl/ConsoleOutputStreamHandler.java +++ b/src/main/java/io/github/ollama4j/impl/ConsoleOutputStreamHandler.java @@ -3,12 +3,8 @@ package io.github.ollama4j.impl; import io.github.ollama4j.models.generate.OllamaStreamHandler; public class ConsoleOutputStreamHandler implements OllamaStreamHandler { - private final StringBuffer response = new StringBuffer(); - @Override public void accept(String message) { - String substr = message.substring(response.length()); - response.append(substr); - System.out.print(substr); + System.out.print(message); } } diff --git a/src/main/java/io/github/ollama4j/models/chat/OllamaChatMessage.java b/src/main/java/io/github/ollama4j/models/chat/OllamaChatMessage.java index 86b7726..e3d7912 100644 --- a/src/main/java/io/github/ollama4j/models/chat/OllamaChatMessage.java +++ b/src/main/java/io/github/ollama4j/models/chat/OllamaChatMessage.java @@ -1,21 +1,15 @@ package io.github.ollama4j.models.chat; -import static io.github.ollama4j.utils.Utils.getObjectMapper; - +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.annotation.JsonSerialize; -import com.fasterxml.jackson.annotation.JsonIgnoreProperties; - import io.github.ollama4j.utils.FileToBase64Serializer; +import lombok.*; import java.util.List; -import lombok.AllArgsConstructor; -import lombok.Data; -import lombok.NoArgsConstructor; -import lombok.NonNull; -import lombok.RequiredArgsConstructor; +import static io.github.ollama4j.utils.Utils.getObjectMapper; /** * Defines a single Message to be used inside a chat request against the ollama /api/chat endpoint. @@ -35,6 +29,8 @@ public class OllamaChatMessage { @NonNull private String content; + private String thinking; + private @JsonProperty("tool_calls") List toolCalls; @JsonSerialize(using = FileToBase64Serializer.class) diff --git a/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequest.java b/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequest.java index 5d19703..7b19e02 100644 --- a/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequest.java +++ b/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequest.java @@ -1,43 +1,46 @@ package io.github.ollama4j.models.chat; -import java.util.List; - import io.github.ollama4j.models.request.OllamaCommonRequest; import io.github.ollama4j.tools.Tools; import io.github.ollama4j.utils.OllamaRequestBody; - import lombok.Getter; import lombok.Setter; +import java.util.List; + /** * Defines a Request to use against the ollama /api/chat endpoint. * * @see Generate - * Chat Completion + * "https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion">Generate + * Chat Completion */ @Getter @Setter public class OllamaChatRequest extends OllamaCommonRequest implements OllamaRequestBody { - private List messages; + private List messages; - private List tools; + private List tools; - public OllamaChatRequest() {} + private boolean think; - public OllamaChatRequest(String model, List messages) { - this.model = model; - this.messages = messages; - } - - @Override - public boolean equals(Object o) { - if (!(o instanceof OllamaChatRequest)) { - return false; + public OllamaChatRequest() { } - return this.toString().equals(o.toString()); - } + public OllamaChatRequest(String model, boolean think, List messages) { + this.model = model; + this.messages = messages; + this.think = think; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof OllamaChatRequest)) { + return false; + } + + return this.toString().equals(o.toString()); + } } diff --git a/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequestBuilder.java b/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequestBuilder.java index 9094546..4a9caf9 100644 --- a/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequestBuilder.java +++ b/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequestBuilder.java @@ -22,7 +22,7 @@ public class OllamaChatRequestBuilder { private static final Logger LOG = LoggerFactory.getLogger(OllamaChatRequestBuilder.class); private OllamaChatRequestBuilder(String model, List messages) { - request = new OllamaChatRequest(model, messages); + request = new OllamaChatRequest(model, false, messages); } private OllamaChatRequest request; @@ -36,14 +36,20 @@ public class OllamaChatRequestBuilder { } public void reset() { - request = new OllamaChatRequest(request.getModel(), new ArrayList<>()); + request = new OllamaChatRequest(request.getModel(), request.isThink(), new ArrayList<>()); } - public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content){ - return withMessage(role,content, Collections.emptyList()); + public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content) { + return withMessage(role, content, Collections.emptyList()); } - public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List toolCalls,List images) { + public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List toolCalls) { + List messages = this.request.getMessages(); + messages.add(new OllamaChatMessage(role, content, null, toolCalls, null)); + return this; + } + + public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List toolCalls, List images) { List messages = this.request.getMessages(); List binaryImages = images.stream().map(file -> { @@ -55,11 +61,11 @@ public class OllamaChatRequestBuilder { } }).collect(Collectors.toList()); - messages.add(new OllamaChatMessage(role, content,toolCalls, binaryImages)); + messages.add(new OllamaChatMessage(role, content, null, toolCalls, binaryImages)); return this; } - public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content,List toolCalls, String... imageUrls) { + public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List toolCalls, String... imageUrls) { List messages = this.request.getMessages(); List binaryImages = null; if (imageUrls.length > 0) { @@ -75,7 +81,7 @@ public class OllamaChatRequestBuilder { } } - messages.add(new OllamaChatMessage(role, content,toolCalls, binaryImages)); + messages.add(new OllamaChatMessage(role, content, null, toolCalls, binaryImages)); return this; } @@ -108,4 +114,8 @@ public class OllamaChatRequestBuilder { return this; } + public OllamaChatRequestBuilder withThinking(boolean think) { + this.request.setThink(think); + return this; + } } diff --git a/src/main/java/io/github/ollama4j/models/chat/OllamaChatResult.java b/src/main/java/io/github/ollama4j/models/chat/OllamaChatResult.java index f8ebb05..5fbf7e3 100644 --- a/src/main/java/io/github/ollama4j/models/chat/OllamaChatResult.java +++ b/src/main/java/io/github/ollama4j/models/chat/OllamaChatResult.java @@ -1,10 +1,10 @@ package io.github.ollama4j.models.chat; -import java.util.List; - import com.fasterxml.jackson.core.JsonProcessingException; import lombok.Getter; +import java.util.List; + import static io.github.ollama4j.utils.Utils.getObjectMapper; /** diff --git a/src/main/java/io/github/ollama4j/models/chat/OllamaChatStreamObserver.java b/src/main/java/io/github/ollama4j/models/chat/OllamaChatStreamObserver.java index af181da..2ccdb74 100644 --- a/src/main/java/io/github/ollama4j/models/chat/OllamaChatStreamObserver.java +++ b/src/main/java/io/github/ollama4j/models/chat/OllamaChatStreamObserver.java @@ -6,14 +6,46 @@ import lombok.RequiredArgsConstructor; @RequiredArgsConstructor public class OllamaChatStreamObserver implements OllamaTokenHandler { - private final OllamaStreamHandler streamHandler; + private final OllamaStreamHandler thinkingStreamHandler; + private final OllamaStreamHandler responseStreamHandler; + private String message = ""; @Override public void accept(OllamaChatResponseModel token) { - if (streamHandler != null) { - message += token.getMessage().getContent(); - streamHandler.accept(message); + if (responseStreamHandler == null || token == null || token.getMessage() == null) { + return; + } + + String thinking = token.getMessage().getThinking(); + String content = token.getMessage().getContent(); + + boolean hasThinking = thinking != null && !thinking.isEmpty(); + boolean hasContent = !content.isEmpty(); + +// if (hasThinking && !hasContent) { +//// message += thinking; +// message = thinking; +// } else { +//// message += content; +// message = content; +// } +// +// responseStreamHandler.accept(message); + + + if (!hasContent && hasThinking && thinkingStreamHandler != null) { + // message = message + thinking; + + // use only new tokens received, instead of appending the tokens to the previous + // ones and sending the full string again + thinkingStreamHandler.accept(thinking); + } else if (hasContent && responseStreamHandler != null) { + // message = message + response; + + // use only new tokens received, instead of appending the tokens to the previous + // ones and sending the full string again + responseStreamHandler.accept(content); } } } diff --git a/src/main/java/io/github/ollama4j/models/embeddings/OllamaEmbeddingResponseModel.java b/src/main/java/io/github/ollama4j/models/embeddings/OllamaEmbeddingResponseModel.java index dcf7b47..2d0d90a 100644 --- a/src/main/java/io/github/ollama4j/models/embeddings/OllamaEmbeddingResponseModel.java +++ b/src/main/java/io/github/ollama4j/models/embeddings/OllamaEmbeddingResponseModel.java @@ -1,9 +1,9 @@ package io.github.ollama4j.models.embeddings; import com.fasterxml.jackson.annotation.JsonProperty; +import lombok.Data; import java.util.List; -import lombok.Data; @SuppressWarnings("unused") @Data diff --git a/src/main/java/io/github/ollama4j/models/embeddings/OllamaEmbeddingsRequestModel.java b/src/main/java/io/github/ollama4j/models/embeddings/OllamaEmbeddingsRequestModel.java index d68624c..7d113f0 100644 --- a/src/main/java/io/github/ollama4j/models/embeddings/OllamaEmbeddingsRequestModel.java +++ b/src/main/java/io/github/ollama4j/models/embeddings/OllamaEmbeddingsRequestModel.java @@ -1,7 +1,5 @@ package io.github.ollama4j.models.embeddings; -import static io.github.ollama4j.utils.Utils.getObjectMapper; -import java.util.Map; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.core.JsonProcessingException; import lombok.Data; @@ -9,6 +7,10 @@ import lombok.NoArgsConstructor; import lombok.NonNull; import lombok.RequiredArgsConstructor; +import java.util.Map; + +import static io.github.ollama4j.utils.Utils.getObjectMapper; + @Data @RequiredArgsConstructor @NoArgsConstructor diff --git a/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateRequest.java b/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateRequest.java index de767dc..3763f0a 100644 --- a/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateRequest.java +++ b/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateRequest.java @@ -3,12 +3,11 @@ package io.github.ollama4j.models.generate; import io.github.ollama4j.models.request.OllamaCommonRequest; import io.github.ollama4j.utils.OllamaRequestBody; - -import java.util.List; - import lombok.Getter; import lombok.Setter; +import java.util.List; + @Getter @Setter public class OllamaGenerateRequest extends OllamaCommonRequest implements OllamaRequestBody{ @@ -19,6 +18,7 @@ public class OllamaGenerateRequest extends OllamaCommonRequest implements Ollama private String system; private String context; private boolean raw; + private boolean think; public OllamaGenerateRequest() { } diff --git a/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateResponseModel.java b/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateResponseModel.java index 9fb975e..a3d23ec 100644 --- a/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateResponseModel.java +++ b/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateResponseModel.java @@ -2,9 +2,9 @@ package io.github.ollama4j.models.generate; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; +import lombok.Data; import java.util.List; -import lombok.Data; @Data @JsonIgnoreProperties(ignoreUnknown = true) @@ -12,12 +12,14 @@ public class OllamaGenerateResponseModel { private String model; private @JsonProperty("created_at") String createdAt; private String response; + private String thinking; private boolean done; + private @JsonProperty("done_reason") String doneReason; private List context; private @JsonProperty("total_duration") Long totalDuration; private @JsonProperty("load_duration") Long loadDuration; - private @JsonProperty("prompt_eval_duration") Long promptEvalDuration; - private @JsonProperty("eval_duration") Long evalDuration; private @JsonProperty("prompt_eval_count") Integer promptEvalCount; + private @JsonProperty("prompt_eval_duration") Long promptEvalDuration; private @JsonProperty("eval_count") Integer evalCount; + private @JsonProperty("eval_duration") Long evalDuration; } diff --git a/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateStreamObserver.java b/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateStreamObserver.java index bc47fa0..67ae571 100644 --- a/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateStreamObserver.java +++ b/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateStreamObserver.java @@ -5,14 +5,16 @@ import java.util.List; public class OllamaGenerateStreamObserver { - private OllamaStreamHandler streamHandler; + private final OllamaStreamHandler thinkingStreamHandler; + private final OllamaStreamHandler responseStreamHandler; - private List responseParts = new ArrayList<>(); + private final List responseParts = new ArrayList<>(); private String message = ""; - public OllamaGenerateStreamObserver(OllamaStreamHandler streamHandler) { - this.streamHandler = streamHandler; + public OllamaGenerateStreamObserver(OllamaStreamHandler thinkingStreamHandler, OllamaStreamHandler responseStreamHandler) { + this.responseStreamHandler = responseStreamHandler; + this.thinkingStreamHandler = thinkingStreamHandler; } public void notify(OllamaGenerateResponseModel currentResponsePart) { @@ -21,9 +23,24 @@ public class OllamaGenerateStreamObserver { } protected void handleCurrentResponsePart(OllamaGenerateResponseModel currentResponsePart) { - message = message + currentResponsePart.getResponse(); - streamHandler.accept(message); + String response = currentResponsePart.getResponse(); + String thinking = currentResponsePart.getThinking(); + + boolean hasResponse = response != null && !response.isEmpty(); + boolean hasThinking = thinking != null && !thinking.isEmpty(); + + if (!hasResponse && hasThinking && thinkingStreamHandler != null) { + // message = message + thinking; + + // use only new tokens received, instead of appending the tokens to the previous + // ones and sending the full string again + thinkingStreamHandler.accept(thinking); + } else if (hasResponse && responseStreamHandler != null) { + // message = message + response; + + // use only new tokens received, instead of appending the tokens to the previous + // ones and sending the full string again + responseStreamHandler.accept(response); + } } - - } diff --git a/src/main/java/io/github/ollama4j/models/request/BasicAuth.java b/src/main/java/io/github/ollama4j/models/request/BasicAuth.java index c58b240..13f6a59 100644 --- a/src/main/java/io/github/ollama4j/models/request/BasicAuth.java +++ b/src/main/java/io/github/ollama4j/models/request/BasicAuth.java @@ -1,13 +1,14 @@ package io.github.ollama4j.models.request; -import java.util.Base64; - import lombok.AllArgsConstructor; import lombok.Data; -import lombok.NoArgsConstructor; +import lombok.EqualsAndHashCode; + +import java.util.Base64; @Data @AllArgsConstructor +@EqualsAndHashCode(callSuper = false) public class BasicAuth extends Auth { private String username; private String password; diff --git a/src/main/java/io/github/ollama4j/models/request/BearerAuth.java b/src/main/java/io/github/ollama4j/models/request/BearerAuth.java index 8236042..4d876f2 100644 --- a/src/main/java/io/github/ollama4j/models/request/BearerAuth.java +++ b/src/main/java/io/github/ollama4j/models/request/BearerAuth.java @@ -2,18 +2,20 @@ package io.github.ollama4j.models.request; import lombok.AllArgsConstructor; import lombok.Data; +import lombok.EqualsAndHashCode; @Data @AllArgsConstructor +@EqualsAndHashCode(callSuper = false) public class BearerAuth extends Auth { - private String bearerToken; + private String bearerToken; - /** - * Get authentication header value. - * - * @return authentication header value with bearer token - */ - public String getAuthHeaderValue() { - return "Bearer "+ bearerToken; - } + /** + * Get authentication header value. + * + * @return authentication header value with bearer token + */ + public String getAuthHeaderValue() { + return "Bearer " + bearerToken; + } } diff --git a/src/main/java/io/github/ollama4j/models/request/CustomModelFileContentsRequest.java b/src/main/java/io/github/ollama4j/models/request/CustomModelFileContentsRequest.java index 6841476..52bc684 100644 --- a/src/main/java/io/github/ollama4j/models/request/CustomModelFileContentsRequest.java +++ b/src/main/java/io/github/ollama4j/models/request/CustomModelFileContentsRequest.java @@ -1,11 +1,11 @@ package io.github.ollama4j.models.request; -import static io.github.ollama4j.utils.Utils.getObjectMapper; - import com.fasterxml.jackson.core.JsonProcessingException; import lombok.AllArgsConstructor; import lombok.Data; +import static io.github.ollama4j.utils.Utils.getObjectMapper; + @Data @AllArgsConstructor public class CustomModelFileContentsRequest { diff --git a/src/main/java/io/github/ollama4j/models/request/CustomModelFilePathRequest.java b/src/main/java/io/github/ollama4j/models/request/CustomModelFilePathRequest.java index 2fcda43..578e1c0 100644 --- a/src/main/java/io/github/ollama4j/models/request/CustomModelFilePathRequest.java +++ b/src/main/java/io/github/ollama4j/models/request/CustomModelFilePathRequest.java @@ -1,11 +1,11 @@ package io.github.ollama4j.models.request; -import static io.github.ollama4j.utils.Utils.getObjectMapper; - import com.fasterxml.jackson.core.JsonProcessingException; import lombok.AllArgsConstructor; import lombok.Data; +import static io.github.ollama4j.utils.Utils.getObjectMapper; + @Data @AllArgsConstructor public class CustomModelFilePathRequest { diff --git a/src/main/java/io/github/ollama4j/models/request/CustomModelRequest.java b/src/main/java/io/github/ollama4j/models/request/CustomModelRequest.java index 15725f0..b2ecb91 100644 --- a/src/main/java/io/github/ollama4j/models/request/CustomModelRequest.java +++ b/src/main/java/io/github/ollama4j/models/request/CustomModelRequest.java @@ -1,17 +1,15 @@ package io.github.ollama4j.models.request; -import static io.github.ollama4j.utils.Utils.getObjectMapper; - import com.fasterxml.jackson.core.JsonProcessingException; import lombok.AllArgsConstructor; -import lombok.Data; -import lombok.Data; -import lombok.AllArgsConstructor; import lombok.Builder; +import lombok.Data; import java.util.List; import java.util.Map; +import static io.github.ollama4j.utils.Utils.getObjectMapper; + @Data @AllArgsConstructor diff --git a/src/main/java/io/github/ollama4j/models/request/ModelRequest.java b/src/main/java/io/github/ollama4j/models/request/ModelRequest.java index 923cd87..eca4d41 100644 --- a/src/main/java/io/github/ollama4j/models/request/ModelRequest.java +++ b/src/main/java/io/github/ollama4j/models/request/ModelRequest.java @@ -1,11 +1,11 @@ package io.github.ollama4j.models.request; -import static io.github.ollama4j.utils.Utils.getObjectMapper; - import com.fasterxml.jackson.core.JsonProcessingException; import lombok.AllArgsConstructor; import lombok.Data; +import static io.github.ollama4j.utils.Utils.getObjectMapper; + @Data @AllArgsConstructor public class ModelRequest { diff --git a/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java b/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java index 09a3870..724e028 100644 --- a/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java +++ b/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java @@ -24,6 +24,7 @@ import java.util.List; /** * Specialization class for requests */ +@SuppressWarnings("resource") public class OllamaChatEndpointCaller extends OllamaEndpointCaller { private static final Logger LOG = LoggerFactory.getLogger(OllamaChatEndpointCaller.class); @@ -46,19 +47,24 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller { * in case the JSON Object cannot be parsed to a {@link OllamaChatResponseModel}. Thus, the ResponseModel should * never be null. * - * @param line streamed line of ollama stream response + * @param line streamed line of ollama stream response * @param responseBuffer Stringbuffer to add latest response message part to * @return TRUE, if ollama-Response has 'done' state */ @Override - protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) { + protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer, StringBuilder thinkingBuffer) { try { OllamaChatResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class); // it seems that under heavy load ollama responds with an empty chat message part in the streamed response // thus, we null check the message and hope that the next streamed response has some message content again OllamaChatMessage message = ollamaResponseModel.getMessage(); - if(message != null) { - responseBuffer.append(message.getContent()); + if (message != null) { + if (message.getThinking() != null) { + thinkingBuffer.append(message.getThinking()); + } + else { + responseBuffer.append(message.getContent()); + } if (tokenHandler != null) { tokenHandler.accept(ollamaResponseModel); } @@ -85,13 +91,14 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller { .POST( body.getBodyPublisher()); HttpRequest request = requestBuilder.build(); - if (isVerbose()) LOG.info("Asking model: " + body); + if (isVerbose()) LOG.info("Asking model: {}", body); HttpResponse response = httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream()); int statusCode = response.statusCode(); InputStream responseBodyStream = response.body(); StringBuilder responseBuffer = new StringBuilder(); + StringBuilder thinkingBuffer = new StringBuilder(); OllamaChatResponseModel ollamaChatResponseModel = null; List wantedToolsForStream = null; try (BufferedReader reader = @@ -115,14 +122,20 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller { OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class); responseBuffer.append(ollamaResponseModel.getError()); + } else if (statusCode == 500) { + LOG.warn("Status code: 500 (Internal Server Error)"); + OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line, + OllamaErrorResponse.class); + responseBuffer.append(ollamaResponseModel.getError()); } else { - boolean finished = parseResponseAndAddToBuffer(line, responseBuffer); - ollamaChatResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class); - if(body.stream && ollamaChatResponseModel.getMessage().getToolCalls() != null){ + boolean finished = parseResponseAndAddToBuffer(line, responseBuffer, thinkingBuffer); + ollamaChatResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class); + if (body.stream && ollamaChatResponseModel.getMessage().getToolCalls() != null) { wantedToolsForStream = ollamaChatResponseModel.getMessage().getToolCalls(); } if (finished && body.stream) { ollamaChatResponseModel.getMessage().setContent(responseBuffer.toString()); + ollamaChatResponseModel.getMessage().setThinking(thinkingBuffer.toString()); break; } } @@ -132,11 +145,11 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller { LOG.error("Status code " + statusCode); throw new OllamaBaseException(responseBuffer.toString()); } else { - if(wantedToolsForStream != null) { + if (wantedToolsForStream != null) { ollamaChatResponseModel.getMessage().setToolCalls(wantedToolsForStream); } OllamaChatResult ollamaResult = - new OllamaChatResult(ollamaChatResponseModel,body.getMessages()); + new OllamaChatResult(ollamaChatResponseModel, body.getMessages()); if (isVerbose()) LOG.info("Model response: " + ollamaResult); return ollamaResult; } diff --git a/src/main/java/io/github/ollama4j/models/request/OllamaCommonRequest.java b/src/main/java/io/github/ollama4j/models/request/OllamaCommonRequest.java index 0ab6cbc..879d801 100644 --- a/src/main/java/io/github/ollama4j/models/request/OllamaCommonRequest.java +++ b/src/main/java/io/github/ollama4j/models/request/OllamaCommonRequest.java @@ -1,15 +1,15 @@ package io.github.ollama4j.models.request; -import java.util.Map; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.annotation.JsonSerialize; - import io.github.ollama4j.utils.BooleanToJsonFormatFlagSerializer; import io.github.ollama4j.utils.Utils; import lombok.Data; +import java.util.Map; + @Data @JsonInclude(JsonInclude.Include.NON_NULL) public abstract class OllamaCommonRequest { diff --git a/src/main/java/io/github/ollama4j/models/request/OllamaEndpointCaller.java b/src/main/java/io/github/ollama4j/models/request/OllamaEndpointCaller.java index 1f42ef8..c7bdba0 100644 --- a/src/main/java/io/github/ollama4j/models/request/OllamaEndpointCaller.java +++ b/src/main/java/io/github/ollama4j/models/request/OllamaEndpointCaller.java @@ -1,15 +1,15 @@ package io.github.ollama4j.models.request; +import io.github.ollama4j.OllamaAPI; +import io.github.ollama4j.utils.Constants; +import lombok.Getter; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import java.net.URI; import java.net.http.HttpRequest; import java.time.Duration; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import io.github.ollama4j.OllamaAPI; -import lombok.Getter; - /** * Abstract helperclass to call the ollama api server. */ @@ -32,7 +32,7 @@ public abstract class OllamaEndpointCaller { protected abstract String getEndpointSuffix(); - protected abstract boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer); + protected abstract boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer, StringBuilder thinkingBuffer); /** @@ -44,7 +44,7 @@ public abstract class OllamaEndpointCaller { protected HttpRequest.Builder getRequestBuilderDefault(URI uri) { HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(uri) - .header("Content-Type", "application/json") + .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON) .timeout(Duration.ofSeconds(this.requestTimeoutSeconds)); if (isAuthCredentialsSet()) { requestBuilder.header("Authorization", this.auth.getAuthHeaderValue()); diff --git a/src/main/java/io/github/ollama4j/models/request/OllamaGenerateEndpointCaller.java b/src/main/java/io/github/ollama4j/models/request/OllamaGenerateEndpointCaller.java index 461ec75..a63a384 100644 --- a/src/main/java/io/github/ollama4j/models/request/OllamaGenerateEndpointCaller.java +++ b/src/main/java/io/github/ollama4j/models/request/OllamaGenerateEndpointCaller.java @@ -2,11 +2,11 @@ package io.github.ollama4j.models.request; import com.fasterxml.jackson.core.JsonProcessingException; import io.github.ollama4j.exceptions.OllamaBaseException; -import io.github.ollama4j.models.response.OllamaErrorResponse; -import io.github.ollama4j.models.response.OllamaResult; import io.github.ollama4j.models.generate.OllamaGenerateResponseModel; import io.github.ollama4j.models.generate.OllamaGenerateStreamObserver; import io.github.ollama4j.models.generate.OllamaStreamHandler; +import io.github.ollama4j.models.response.OllamaErrorResponse; +import io.github.ollama4j.models.response.OllamaResult; import io.github.ollama4j.utils.OllamaRequestBody; import io.github.ollama4j.utils.Utils; import org.slf4j.Logger; @@ -22,11 +22,12 @@ import java.net.http.HttpRequest; import java.net.http.HttpResponse; import java.nio.charset.StandardCharsets; +@SuppressWarnings("resource") public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller { private static final Logger LOG = LoggerFactory.getLogger(OllamaGenerateEndpointCaller.class); - private OllamaGenerateStreamObserver streamObserver; + private OllamaGenerateStreamObserver responseStreamObserver; public OllamaGenerateEndpointCaller(String host, Auth basicAuth, long requestTimeoutSeconds, boolean verbose) { super(host, basicAuth, requestTimeoutSeconds, verbose); @@ -38,12 +39,17 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller { } @Override - protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) { + protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer, StringBuilder thinkingBuffer) { try { OllamaGenerateResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class); - responseBuffer.append(ollamaResponseModel.getResponse()); - if (streamObserver != null) { - streamObserver.notify(ollamaResponseModel); + if (ollamaResponseModel.getResponse() != null) { + responseBuffer.append(ollamaResponseModel.getResponse()); + } + if (ollamaResponseModel.getThinking() != null) { + thinkingBuffer.append(ollamaResponseModel.getThinking()); + } + if (responseStreamObserver != null) { + responseStreamObserver.notify(ollamaResponseModel); } return ollamaResponseModel.isDone(); } catch (JsonProcessingException e) { @@ -52,9 +58,8 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller { } } - public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler) - throws OllamaBaseException, IOException, InterruptedException { - streamObserver = new OllamaGenerateStreamObserver(streamHandler); + public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler thinkingStreamHandler, OllamaStreamHandler responseStreamHandler) throws OllamaBaseException, IOException, InterruptedException { + responseStreamObserver = new OllamaGenerateStreamObserver(thinkingStreamHandler, responseStreamHandler); return callSync(body); } @@ -67,46 +72,41 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller { * @throws IOException in case the responseStream can not be read * @throws InterruptedException in case the server is not reachable or network issues happen */ + @SuppressWarnings("DuplicatedCode") public OllamaResult callSync(OllamaRequestBody body) throws OllamaBaseException, IOException, InterruptedException { // Create Request long startTime = System.currentTimeMillis(); HttpClient httpClient = HttpClient.newHttpClient(); URI uri = URI.create(getHost() + getEndpointSuffix()); - HttpRequest.Builder requestBuilder = - getRequestBuilderDefault(uri) - .POST( - body.getBodyPublisher()); + HttpRequest.Builder requestBuilder = getRequestBuilderDefault(uri).POST(body.getBodyPublisher()); HttpRequest request = requestBuilder.build(); - if (isVerbose()) LOG.info("Asking model: " + body.toString()); - HttpResponse response = - httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream()); + if (isVerbose()) LOG.info("Asking model: {}", body); + HttpResponse response = httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream()); int statusCode = response.statusCode(); InputStream responseBodyStream = response.body(); StringBuilder responseBuffer = new StringBuilder(); - try (BufferedReader reader = - new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { + StringBuilder thinkingBuffer = new StringBuilder(); + OllamaGenerateResponseModel ollamaGenerateResponseModel = null; + try (BufferedReader reader = new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { String line; while ((line = reader.readLine()) != null) { if (statusCode == 404) { LOG.warn("Status code: 404 (Not Found)"); - OllamaErrorResponse ollamaResponseModel = - Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class); + OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class); responseBuffer.append(ollamaResponseModel.getError()); } else if (statusCode == 401) { LOG.warn("Status code: 401 (Unauthorized)"); - OllamaErrorResponse ollamaResponseModel = - Utils.getObjectMapper() - .readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponse.class); + OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponse.class); responseBuffer.append(ollamaResponseModel.getError()); } else if (statusCode == 400) { LOG.warn("Status code: 400 (Bad Request)"); - OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line, - OllamaErrorResponse.class); + OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class); responseBuffer.append(ollamaResponseModel.getError()); } else { - boolean finished = parseResponseAndAddToBuffer(line, responseBuffer); + boolean finished = parseResponseAndAddToBuffer(line, responseBuffer, thinkingBuffer); if (finished) { + ollamaGenerateResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class); break; } } @@ -114,13 +114,25 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller { } if (statusCode != 200) { - LOG.error("Status code " + statusCode); + LOG.error("Status code: {}", statusCode); throw new OllamaBaseException(responseBuffer.toString()); } else { long endTime = System.currentTimeMillis(); - OllamaResult ollamaResult = - new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode); - if (isVerbose()) LOG.info("Model response: " + ollamaResult); + OllamaResult ollamaResult = new OllamaResult(responseBuffer.toString(), thinkingBuffer.toString(), endTime - startTime, statusCode); + + ollamaResult.setModel(ollamaGenerateResponseModel.getModel()); + ollamaResult.setCreatedAt(ollamaGenerateResponseModel.getCreatedAt()); + ollamaResult.setDone(ollamaGenerateResponseModel.isDone()); + ollamaResult.setDoneReason(ollamaGenerateResponseModel.getDoneReason()); + ollamaResult.setContext(ollamaGenerateResponseModel.getContext()); + ollamaResult.setTotalDuration(ollamaGenerateResponseModel.getTotalDuration()); + ollamaResult.setLoadDuration(ollamaGenerateResponseModel.getLoadDuration()); + ollamaResult.setPromptEvalCount(ollamaGenerateResponseModel.getPromptEvalCount()); + ollamaResult.setPromptEvalDuration(ollamaGenerateResponseModel.getPromptEvalDuration()); + ollamaResult.setEvalCount(ollamaGenerateResponseModel.getEvalCount()); + ollamaResult.setEvalDuration(ollamaGenerateResponseModel.getEvalDuration()); + + if (isVerbose()) LOG.info("Model response: {}", ollamaResult); return ollamaResult; } } diff --git a/src/main/java/io/github/ollama4j/models/response/LibraryModel.java b/src/main/java/io/github/ollama4j/models/response/LibraryModel.java index 82aba42..c5f1627 100644 --- a/src/main/java/io/github/ollama4j/models/response/LibraryModel.java +++ b/src/main/java/io/github/ollama4j/models/response/LibraryModel.java @@ -1,9 +1,10 @@ package io.github.ollama4j.models.response; -import java.util.ArrayList; -import java.util.List; import lombok.Data; +import java.util.ArrayList; +import java.util.List; + @Data public class LibraryModel { diff --git a/src/main/java/io/github/ollama4j/models/response/LibraryModelTag.java b/src/main/java/io/github/ollama4j/models/response/LibraryModelTag.java index d720dd0..cd65d32 100644 --- a/src/main/java/io/github/ollama4j/models/response/LibraryModelTag.java +++ b/src/main/java/io/github/ollama4j/models/response/LibraryModelTag.java @@ -2,8 +2,6 @@ package io.github.ollama4j.models.response; import lombok.Data; -import java.util.List; - @Data public class LibraryModelTag { private String name; diff --git a/src/main/java/io/github/ollama4j/models/response/ListModelsResponse.java b/src/main/java/io/github/ollama4j/models/response/ListModelsResponse.java index 62f151b..e22b796 100644 --- a/src/main/java/io/github/ollama4j/models/response/ListModelsResponse.java +++ b/src/main/java/io/github/ollama4j/models/response/ListModelsResponse.java @@ -1,9 +1,9 @@ package io.github.ollama4j.models.response; -import java.util.List; - import lombok.Data; +import java.util.List; + @Data public class ListModelsResponse { private List models; diff --git a/src/main/java/io/github/ollama4j/models/response/Model.java b/src/main/java/io/github/ollama4j/models/response/Model.java index ae64f38..a616404 100644 --- a/src/main/java/io/github/ollama4j/models/response/Model.java +++ b/src/main/java/io/github/ollama4j/models/response/Model.java @@ -1,13 +1,13 @@ package io.github.ollama4j.models.response; -import java.time.OffsetDateTime; - import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.core.JsonProcessingException; import io.github.ollama4j.utils.Utils; import lombok.Data; +import java.time.OffsetDateTime; + @Data @JsonIgnoreProperties(ignoreUnknown = true) public class Model { diff --git a/src/main/java/io/github/ollama4j/models/response/OllamaAsyncResultStreamer.java b/src/main/java/io/github/ollama4j/models/response/OllamaAsyncResultStreamer.java index fd43696..f4a68f7 100644 --- a/src/main/java/io/github/ollama4j/models/response/OllamaAsyncResultStreamer.java +++ b/src/main/java/io/github/ollama4j/models/response/OllamaAsyncResultStreamer.java @@ -3,6 +3,7 @@ package io.github.ollama4j.models.response; import io.github.ollama4j.exceptions.OllamaBaseException; import io.github.ollama4j.models.generate.OllamaGenerateRequest; import io.github.ollama4j.models.generate.OllamaGenerateResponseModel; +import io.github.ollama4j.utils.Constants; import io.github.ollama4j.utils.Utils; import lombok.Data; import lombok.EqualsAndHashCode; @@ -25,8 +26,10 @@ import java.time.Duration; public class OllamaAsyncResultStreamer extends Thread { private final HttpRequest.Builder requestBuilder; private final OllamaGenerateRequest ollamaRequestModel; - private final OllamaResultStream stream = new OllamaResultStream(); + private final OllamaResultStream thinkingResponseStream = new OllamaResultStream(); + private final OllamaResultStream responseStream = new OllamaResultStream(); private String completeResponse; + private String completeThinkingResponse; /** @@ -53,14 +56,11 @@ public class OllamaAsyncResultStreamer extends Thread { @Getter private long responseTime = 0; - public OllamaAsyncResultStreamer( - HttpRequest.Builder requestBuilder, - OllamaGenerateRequest ollamaRequestModel, - long requestTimeoutSeconds) { + public OllamaAsyncResultStreamer(HttpRequest.Builder requestBuilder, OllamaGenerateRequest ollamaRequestModel, long requestTimeoutSeconds) { this.requestBuilder = requestBuilder; this.ollamaRequestModel = ollamaRequestModel; this.completeResponse = ""; - this.stream.add(""); + this.responseStream.add(""); this.requestTimeoutSeconds = requestTimeoutSeconds; } @@ -68,47 +68,63 @@ public class OllamaAsyncResultStreamer extends Thread { public void run() { ollamaRequestModel.setStream(true); HttpClient httpClient = HttpClient.newHttpClient(); + long startTime = System.currentTimeMillis(); try { - long startTime = System.currentTimeMillis(); - HttpRequest request = - requestBuilder - .POST( - HttpRequest.BodyPublishers.ofString( - Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))) - .header("Content-Type", "application/json") - .timeout(Duration.ofSeconds(requestTimeoutSeconds)) - .build(); - HttpResponse response = - httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream()); + HttpRequest request = requestBuilder.POST(HttpRequest.BodyPublishers.ofString(Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))).header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON).timeout(Duration.ofSeconds(requestTimeoutSeconds)).build(); + HttpResponse response = httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream()); int statusCode = response.statusCode(); this.httpStatusCode = statusCode; InputStream responseBodyStream = response.body(); - try (BufferedReader reader = - new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { + BufferedReader reader = null; + try { + reader = new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8)); String line; + StringBuilder thinkingBuffer = new StringBuilder(); StringBuilder responseBuffer = new StringBuilder(); while ((line = reader.readLine()) != null) { if (statusCode == 404) { - OllamaErrorResponse ollamaResponseModel = - Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class); - stream.add(ollamaResponseModel.getError()); + OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class); + responseStream.add(ollamaResponseModel.getError()); responseBuffer.append(ollamaResponseModel.getError()); } else { - OllamaGenerateResponseModel ollamaResponseModel = - Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class); - String res = ollamaResponseModel.getResponse(); - stream.add(res); + OllamaGenerateResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class); + String thinkingTokens = ollamaResponseModel.getThinking(); + String responseTokens = ollamaResponseModel.getResponse(); + if (thinkingTokens == null) { + thinkingTokens = ""; + } + if (responseTokens == null) { + responseTokens = ""; + } + thinkingResponseStream.add(thinkingTokens); + responseStream.add(responseTokens); if (!ollamaResponseModel.isDone()) { - responseBuffer.append(res); + responseBuffer.append(responseTokens); + thinkingBuffer.append(thinkingTokens); } } } - this.succeeded = true; + this.completeThinkingResponse = thinkingBuffer.toString(); this.completeResponse = responseBuffer.toString(); long endTime = System.currentTimeMillis(); responseTime = endTime - startTime; + } finally { + if (reader != null) { + try { + reader.close(); + } catch (IOException e) { + // Optionally log or handle + } + } + if (responseBodyStream != null) { + try { + responseBodyStream.close(); + } catch (IOException e) { + // Optionally log or handle + } + } } if (statusCode != 200) { throw new OllamaBaseException(this.completeResponse); diff --git a/src/main/java/io/github/ollama4j/models/response/OllamaResult.java b/src/main/java/io/github/ollama4j/models/response/OllamaResult.java index 4b538f9..ce6d5e3 100644 --- a/src/main/java/io/github/ollama4j/models/response/OllamaResult.java +++ b/src/main/java/io/github/ollama4j/models/response/OllamaResult.java @@ -3,116 +3,136 @@ package io.github.ollama4j.models.response; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; - import lombok.Data; import lombok.Getter; -import static io.github.ollama4j.utils.Utils.getObjectMapper; - import java.util.HashMap; +import java.util.List; import java.util.Map; -/** The type Ollama result. */ +import static io.github.ollama4j.utils.Utils.getObjectMapper; + +/** + * The type Ollama result. + */ @Getter @SuppressWarnings("unused") @Data @JsonIgnoreProperties(ignoreUnknown = true) public class OllamaResult { - /** - * -- GETTER -- - * Get the completion/response text - * - * @return String completion/response text - */ - private final String response; + /** + * Get the completion/response text + */ + private final String response; + /** + * Get the thinking text (if available) + */ + private final String thinking; + /** + * Get the response status code. + */ + private int httpStatusCode; + /** + * Get the response time in milliseconds. + */ + private long responseTime = 0; - /** - * -- GETTER -- - * Get the response status code. - * - * @return int - response status code - */ - private int httpStatusCode; + private String model; + private String createdAt; + private boolean done; + private String doneReason; + private List context; + private Long totalDuration; + private Long loadDuration; + private Integer promptEvalCount; + private Long promptEvalDuration; + private Integer evalCount; + private Long evalDuration; - /** - * -- GETTER -- - * Get the response time in milliseconds. - * - * @return long - response time in milliseconds - */ - private long responseTime = 0; - - public OllamaResult(String response, long responseTime, int httpStatusCode) { - this.response = response; - this.responseTime = responseTime; - this.httpStatusCode = httpStatusCode; - } - - @Override - public String toString() { - try { - Map responseMap = new HashMap<>(); - responseMap.put("response", this.response); - responseMap.put("httpStatusCode", this.httpStatusCode); - responseMap.put("responseTime", this.responseTime); - return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(responseMap); - } catch (JsonProcessingException e) { - throw new RuntimeException(e); - } - } - - /** - * Get the structured response if the response is a JSON object. - * - * @return Map - structured response - * @throws IllegalArgumentException if the response is not a valid JSON object - */ - public Map getStructuredResponse() { - String responseStr = this.getResponse(); - if (responseStr == null || responseStr.trim().isEmpty()) { - throw new IllegalArgumentException("Response is empty or null"); + public OllamaResult(String response, String thinking, long responseTime, int httpStatusCode) { + this.response = response; + this.thinking = thinking; + this.responseTime = responseTime; + this.httpStatusCode = httpStatusCode; } - try { - // Check if the response is a valid JSON - if ((!responseStr.trim().startsWith("{") && !responseStr.trim().startsWith("[")) || - (!responseStr.trim().endsWith("}") && !responseStr.trim().endsWith("]"))) { - throw new IllegalArgumentException("Response is not a valid JSON object"); - } - - Map response = getObjectMapper().readValue(responseStr, - new TypeReference>() { - }); - return response; - } catch (JsonProcessingException e) { - throw new IllegalArgumentException("Failed to parse response as JSON: " + e.getMessage(), e); - } - } - - /** - * Get the structured response mapped to a specific class type. - * - * @param The type of class to map the response to - * @param clazz The class to map the response to - * @return An instance of the specified class with the response data - * @throws IllegalArgumentException if the response is not a valid JSON or is empty - * @throws RuntimeException if there is an error mapping the response - */ - public T as(Class clazz) { - String responseStr = this.getResponse(); - if (responseStr == null || responseStr.trim().isEmpty()) { - throw new IllegalArgumentException("Response is empty or null"); + @Override + public String toString() { + try { + Map responseMap = new HashMap<>(); + responseMap.put("response", this.response); + responseMap.put("thinking", this.thinking); + responseMap.put("httpStatusCode", this.httpStatusCode); + responseMap.put("responseTime", this.responseTime); + responseMap.put("model", this.model); + responseMap.put("createdAt", this.createdAt); + responseMap.put("done", this.done); + responseMap.put("doneReason", this.doneReason); + responseMap.put("context", this.context); + responseMap.put("totalDuration", this.totalDuration); + responseMap.put("loadDuration", this.loadDuration); + responseMap.put("promptEvalCount", this.promptEvalCount); + responseMap.put("promptEvalDuration", this.promptEvalDuration); + responseMap.put("evalCount", this.evalCount); + responseMap.put("evalDuration", this.evalDuration); + return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(responseMap); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } } - try { - // Check if the response is a valid JSON - if ((!responseStr.trim().startsWith("{") && !responseStr.trim().startsWith("[")) || - (!responseStr.trim().endsWith("}") && !responseStr.trim().endsWith("]"))) { - throw new IllegalArgumentException("Response is not a valid JSON object"); - } - return getObjectMapper().readValue(responseStr, clazz); - } catch (JsonProcessingException e) { - throw new IllegalArgumentException("Failed to parse response as JSON: " + e.getMessage(), e); + /** + * Get the structured response if the response is a JSON object. + * + * @return Map - structured response + * @throws IllegalArgumentException if the response is not a valid JSON object + */ + public Map getStructuredResponse() { + String responseStr = this.getResponse(); + if (responseStr == null || responseStr.trim().isEmpty()) { + throw new IllegalArgumentException("Response is empty or null"); + } + + try { + // Check if the response is a valid JSON + if ((!responseStr.trim().startsWith("{") && !responseStr.trim().startsWith("[")) || + (!responseStr.trim().endsWith("}") && !responseStr.trim().endsWith("]"))) { + throw new IllegalArgumentException("Response is not a valid JSON object"); + } + + Map response = getObjectMapper().readValue(responseStr, + new TypeReference>() { + }); + return response; + } catch (JsonProcessingException e) { + throw new IllegalArgumentException("Failed to parse response as JSON: " + e.getMessage(), e); + } + } + + /** + * Get the structured response mapped to a specific class type. + * + * @param The type of class to map the response to + * @param clazz The class to map the response to + * @return An instance of the specified class with the response data + * @throws IllegalArgumentException if the response is not a valid JSON or is empty + * @throws RuntimeException if there is an error mapping the response + */ + public T as(Class clazz) { + String responseStr = this.getResponse(); + if (responseStr == null || responseStr.trim().isEmpty()) { + throw new IllegalArgumentException("Response is empty or null"); + } + + try { + // Check if the response is a valid JSON + if ((!responseStr.trim().startsWith("{") && !responseStr.trim().startsWith("[")) || + (!responseStr.trim().endsWith("}") && !responseStr.trim().endsWith("]"))) { + throw new IllegalArgumentException("Response is not a valid JSON object"); + } + return getObjectMapper().readValue(responseStr, clazz); + } catch (JsonProcessingException e) { + throw new IllegalArgumentException("Failed to parse response as JSON: " + e.getMessage(), e); + } } - } } diff --git a/src/main/java/io/github/ollama4j/models/response/OllamaStructuredResult.java b/src/main/java/io/github/ollama4j/models/response/OllamaStructuredResult.java index 9ae3e71..01bf446 100644 --- a/src/main/java/io/github/ollama4j/models/response/OllamaStructuredResult.java +++ b/src/main/java/io/github/ollama4j/models/response/OllamaStructuredResult.java @@ -1,19 +1,18 @@ package io.github.ollama4j.models.response; -import static io.github.ollama4j.utils.Utils.getObjectMapper; - -import java.util.Map; - -import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; - import lombok.Data; import lombok.Getter; import lombok.NoArgsConstructor; +import java.util.List; +import java.util.Map; + +import static io.github.ollama4j.utils.Utils.getObjectMapper; + @Getter @SuppressWarnings("unused") @Data @@ -21,13 +20,22 @@ import lombok.NoArgsConstructor; @JsonIgnoreProperties(ignoreUnknown = true) public class OllamaStructuredResult { private String response; - + private String thinking; private int httpStatusCode; - private long responseTime = 0; - private String model; + private @JsonProperty("created_at") String createdAt; + private boolean done; + private @JsonProperty("done_reason") String doneReason; + private List context; + private @JsonProperty("total_duration") Long totalDuration; + private @JsonProperty("load_duration") Long loadDuration; + private @JsonProperty("prompt_eval_count") Integer promptEvalCount; + private @JsonProperty("prompt_eval_duration") Long promptEvalDuration; + private @JsonProperty("eval_count") Integer evalCount; + private @JsonProperty("eval_duration") Long evalDuration; + public OllamaStructuredResult(String response, long responseTime, int httpStatusCode) { this.response = response; this.responseTime = responseTime; diff --git a/src/main/java/io/github/ollama4j/models/response/OllamaVersion.java b/src/main/java/io/github/ollama4j/models/response/OllamaVersion.java index eac177b..11b7524 100644 --- a/src/main/java/io/github/ollama4j/models/response/OllamaVersion.java +++ b/src/main/java/io/github/ollama4j/models/response/OllamaVersion.java @@ -2,8 +2,6 @@ package io.github.ollama4j.models.response; import lombok.Data; -import java.util.List; - @Data public class OllamaVersion { private String version; diff --git a/src/main/java/io/github/ollama4j/tools/ToolRegistry.java b/src/main/java/io/github/ollama4j/tools/ToolRegistry.java index 5ab8be3..b106042 100644 --- a/src/main/java/io/github/ollama4j/tools/ToolRegistry.java +++ b/src/main/java/io/github/ollama4j/tools/ToolRegistry.java @@ -9,14 +9,21 @@ public class ToolRegistry { public ToolFunction getToolFunction(String name) { final Tools.ToolSpecification toolSpecification = tools.get(name); - return toolSpecification !=null ? toolSpecification.getToolFunction() : null ; + return toolSpecification != null ? toolSpecification.getToolFunction() : null; } - public void addTool (String name, Tools.ToolSpecification specification) { + public void addTool(String name, Tools.ToolSpecification specification) { tools.put(name, specification); } - public Collection getRegisteredSpecs(){ + public Collection getRegisteredSpecs() { return tools.values(); } + + /** + * Removes all registered tools from the registry. + */ + public void clear() { + tools.clear(); + } } diff --git a/src/main/java/io/github/ollama4j/tools/sampletools/WeatherTool.java b/src/main/java/io/github/ollama4j/tools/sampletools/WeatherTool.java index eb0ba72..7a32ab0 100644 --- a/src/main/java/io/github/ollama4j/tools/sampletools/WeatherTool.java +++ b/src/main/java/io/github/ollama4j/tools/sampletools/WeatherTool.java @@ -1,88 +1,54 @@ package io.github.ollama4j.tools.sampletools; -import java.io.IOException; -import java.net.URI; -import java.net.http.HttpClient; -import java.net.http.HttpRequest; -import java.net.http.HttpResponse; -import java.util.Map; - -import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.ObjectMapper; - import io.github.ollama4j.tools.Tools; +import java.util.Map; + +@SuppressWarnings("resource") public class WeatherTool { - private String openWeatherMapAPIKey = null; + private String paramCityName = "cityName"; - public WeatherTool(String openWeatherMapAPIKey) { - this.openWeatherMapAPIKey = openWeatherMapAPIKey; - } + public WeatherTool() { + } - public String getCurrentWeather(Map arguments) { - String city = (String) arguments.get("cityName"); - System.out.println("Finding weather for city: " + city); + public String getCurrentWeather(Map arguments) { + String city = (String) arguments.get(paramCityName); + return "It is sunny in " + city; + } - String url = String.format("https://api.openweathermap.org/data/2.5/weather?q=%s&appid=%s&units=metric", - city, - this.openWeatherMapAPIKey); - - HttpClient client = HttpClient.newHttpClient(); - HttpRequest request = HttpRequest.newBuilder() - .uri(URI.create(url)) - .build(); - try { - HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); - if (response.statusCode() == 200) { - ObjectMapper mapper = new ObjectMapper(); - JsonNode root = mapper.readTree(response.body()); - JsonNode main = root.path("main"); - double temperature = main.path("temp").asDouble(); - String description = root.path("weather").get(0).path("description").asText(); - return String.format("Weather in %s: %.1f°C, %s", city, temperature, description); - } else { - return "Could not retrieve weather data for " + city + ". Status code: " - + response.statusCode(); - } - } catch (IOException | InterruptedException e) { - e.printStackTrace(); - return "Error retrieving weather data: " + e.getMessage(); - } - } - - public Tools.ToolSpecification getSpecification() { - return Tools.ToolSpecification.builder() - .functionName("weather-reporter") - .functionDescription( - "You are a tool who simply finds the city name from the user's message input/query about weather.") - .toolFunction(this::getCurrentWeather) - .toolPrompt( - Tools.PromptFuncDefinition.builder() - .type("prompt") - .function( - Tools.PromptFuncDefinition.PromptFuncSpec - .builder() - .name("get-city-name") - .description("Get the city name") - .parameters( - Tools.PromptFuncDefinition.Parameters - .builder() - .type("object") - .properties( - Map.of( - "cityName", - Tools.PromptFuncDefinition.Property - .builder() - .type("string") - .description( - "The name of the city. e.g. Bengaluru") - .required(true) - .build())) - .required(java.util.List - .of("cityName")) - .build()) - .build()) + public Tools.ToolSpecification getSpecification() { + return Tools.ToolSpecification.builder() + .functionName("weather-reporter") + .functionDescription( + "You are a tool who simply finds the city name from the user's message input/query about weather.") + .toolFunction(this::getCurrentWeather) + .toolPrompt( + Tools.PromptFuncDefinition.builder() + .type("prompt") + .function( + Tools.PromptFuncDefinition.PromptFuncSpec + .builder() + .name("get-city-name") + .description("Get the city name") + .parameters( + Tools.PromptFuncDefinition.Parameters + .builder() + .type("object") + .properties( + Map.of( + paramCityName, + Tools.PromptFuncDefinition.Property + .builder() + .type("string") + .description( + "The name of the city. e.g. Bengaluru") + .required(true) + .build())) + .required(java.util.List + .of(paramCityName)) .build()) - .build(); - } + .build()) + .build()) + .build(); + } } diff --git a/src/main/java/io/github/ollama4j/utils/BooleanToJsonFormatFlagSerializer.java b/src/main/java/io/github/ollama4j/utils/BooleanToJsonFormatFlagSerializer.java index a94e4d1..590b59e 100644 --- a/src/main/java/io/github/ollama4j/utils/BooleanToJsonFormatFlagSerializer.java +++ b/src/main/java/io/github/ollama4j/utils/BooleanToJsonFormatFlagSerializer.java @@ -1,11 +1,11 @@ package io.github.ollama4j.utils; -import java.io.IOException; - import com.fasterxml.jackson.core.JsonGenerator; import com.fasterxml.jackson.databind.JsonSerializer; import com.fasterxml.jackson.databind.SerializerProvider; +import java.io.IOException; + public class BooleanToJsonFormatFlagSerializer extends JsonSerializer{ @Override diff --git a/src/main/java/io/github/ollama4j/utils/Constants.java b/src/main/java/io/github/ollama4j/utils/Constants.java new file mode 100644 index 0000000..dfe5377 --- /dev/null +++ b/src/main/java/io/github/ollama4j/utils/Constants.java @@ -0,0 +1,14 @@ +package io.github.ollama4j.utils; + +public final class Constants { + public static final class HttpConstants { + private HttpConstants() { + } + + public static final String APPLICATION_JSON = "application/json"; + public static final String APPLICATION_XML = "application/xml"; + public static final String TEXT_PLAIN = "text/plain"; + public static final String HEADER_KEY_CONTENT_TYPE = "Content-Type"; + public static final String HEADER_KEY_ACCEPT = "Accept"; + } +} diff --git a/src/main/java/io/github/ollama4j/utils/FileToBase64Serializer.java b/src/main/java/io/github/ollama4j/utils/FileToBase64Serializer.java index b8b05e5..c54d83f 100644 --- a/src/main/java/io/github/ollama4j/utils/FileToBase64Serializer.java +++ b/src/main/java/io/github/ollama4j/utils/FileToBase64Serializer.java @@ -1,13 +1,13 @@ package io.github.ollama4j.utils; -import java.io.IOException; -import java.util.Base64; -import java.util.Collection; - import com.fasterxml.jackson.core.JsonGenerator; import com.fasterxml.jackson.databind.JsonSerializer; import com.fasterxml.jackson.databind.SerializerProvider; +import java.io.IOException; +import java.util.Base64; +import java.util.Collection; + public class FileToBase64Serializer extends JsonSerializer> { @Override diff --git a/src/main/java/io/github/ollama4j/utils/OllamaRequestBody.java b/src/main/java/io/github/ollama4j/utils/OllamaRequestBody.java index 1dc2265..805cec4 100644 --- a/src/main/java/io/github/ollama4j/utils/OllamaRequestBody.java +++ b/src/main/java/io/github/ollama4j/utils/OllamaRequestBody.java @@ -1,11 +1,11 @@ package io.github.ollama4j.utils; -import java.net.http.HttpRequest.BodyPublisher; -import java.net.http.HttpRequest.BodyPublishers; - import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.core.JsonProcessingException; +import java.net.http.HttpRequest.BodyPublisher; +import java.net.http.HttpRequest.BodyPublishers; + /** * Interface to represent a OllamaRequest as HTTP-Request Body via {@link BodyPublishers}. */ diff --git a/src/main/java/io/github/ollama4j/utils/Options.java b/src/main/java/io/github/ollama4j/utils/Options.java index c6e5e53..c4ea79d 100644 --- a/src/main/java/io/github/ollama4j/utils/Options.java +++ b/src/main/java/io/github/ollama4j/utils/Options.java @@ -1,8 +1,9 @@ package io.github.ollama4j.utils; -import java.util.Map; import lombok.Data; +import java.util.Map; + /** Class for options for Ollama model. */ @Data public class Options { diff --git a/src/main/java/io/github/ollama4j/utils/OptionsBuilder.java b/src/main/java/io/github/ollama4j/utils/OptionsBuilder.java index 4148170..6ee8392 100644 --- a/src/main/java/io/github/ollama4j/utils/OptionsBuilder.java +++ b/src/main/java/io/github/ollama4j/utils/OptionsBuilder.java @@ -1,6 +1,5 @@ package io.github.ollama4j.utils; -import java.io.IOException; import java.util.HashMap; /** Builder class for creating options for Ollama model. */ diff --git a/src/main/java/io/github/ollama4j/utils/SamplePrompts.java b/src/main/java/io/github/ollama4j/utils/SamplePrompts.java deleted file mode 100644 index 89a7f83..0000000 --- a/src/main/java/io/github/ollama4j/utils/SamplePrompts.java +++ /dev/null @@ -1,25 +0,0 @@ -package io.github.ollama4j.utils; - -import io.github.ollama4j.OllamaAPI; - -import java.io.InputStream; -import java.util.Scanner; - -public class SamplePrompts { - public static String getSampleDatabasePromptWithQuestion(String question) throws Exception { - ClassLoader classLoader = OllamaAPI.class.getClassLoader(); - InputStream inputStream = classLoader.getResourceAsStream("sample-db-prompt-template.txt"); - if (inputStream != null) { - Scanner scanner = new Scanner(inputStream); - StringBuilder stringBuffer = new StringBuilder(); - while (scanner.hasNextLine()) { - stringBuffer.append(scanner.nextLine()).append("\n"); - } - scanner.close(); - return stringBuffer.toString().replaceAll("", question); - } else { - throw new Exception("Sample database question file not found."); - } - } - -} diff --git a/src/main/java/io/github/ollama4j/utils/Utils.java b/src/main/java/io/github/ollama4j/utils/Utils.java index d854df1..6d2aa5e 100644 --- a/src/main/java/io/github/ollama4j/utils/Utils.java +++ b/src/main/java/io/github/ollama4j/utils/Utils.java @@ -1,38 +1,45 @@ package io.github.ollama4j.utils; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; + import java.io.ByteArrayOutputStream; +import java.io.File; import java.io.IOException; import java.io.InputStream; import java.net.URI; import java.net.URISyntaxException; import java.net.URL; - -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; +import java.util.Objects; public class Utils { - private static ObjectMapper objectMapper; + private static ObjectMapper objectMapper; - public static ObjectMapper getObjectMapper() { - if(objectMapper == null) { - objectMapper = new ObjectMapper(); - objectMapper.registerModule(new JavaTimeModule()); + public static ObjectMapper getObjectMapper() { + if (objectMapper == null) { + objectMapper = new ObjectMapper(); + objectMapper.registerModule(new JavaTimeModule()); + } + return objectMapper; } - return objectMapper; - } - public static byte[] loadImageBytesFromUrl(String imageUrl) - throws IOException, URISyntaxException { - URL url = new URI(imageUrl).toURL(); - try (InputStream in = url.openStream(); - ByteArrayOutputStream out = new ByteArrayOutputStream()) { - byte[] buffer = new byte[1024]; - int bytesRead; - while ((bytesRead = in.read(buffer)) != -1) { - out.write(buffer, 0, bytesRead); - } - return out.toByteArray(); + public static byte[] loadImageBytesFromUrl(String imageUrl) + throws IOException, URISyntaxException { + URL url = new URI(imageUrl).toURL(); + try (InputStream in = url.openStream(); + ByteArrayOutputStream out = new ByteArrayOutputStream()) { + byte[] buffer = new byte[1024]; + int bytesRead; + while ((bytesRead = in.read(buffer)) != -1) { + out.write(buffer, 0, bytesRead); + } + return out.toByteArray(); + } + } + + public static File getFileFromClasspath(String fileName) { + ClassLoader classLoader = Utils.class.getClassLoader(); + return new File(Objects.requireNonNull(classLoader.getResource(fileName)).getFile()); } - } } diff --git a/src/main/resources/sample-db-prompt-template.txt b/src/main/resources/sample-db-prompt-template.txt deleted file mode 100644 index 177f648..0000000 --- a/src/main/resources/sample-db-prompt-template.txt +++ /dev/null @@ -1,61 +0,0 @@ -""" -Following is the database schema. - -DROP TABLE IF EXISTS product_categories; -CREATE TABLE IF NOT EXISTS product_categories -( - category_id INTEGER PRIMARY KEY, -- Unique ID for each category - name VARCHAR(50), -- Name of the category - parent INTEGER NULL, -- Parent category - for hierarchical categories - FOREIGN KEY (parent) REFERENCES product_categories (category_id) -); -DROP TABLE IF EXISTS products; -CREATE TABLE IF NOT EXISTS products -( - product_id INTEGER PRIMARY KEY, -- Unique ID for each product - name VARCHAR(50), -- Name of the product - price DECIMAL(10, 2), -- Price of each unit of the product - quantity INTEGER, -- Current quantity in stock - category_id INTEGER, -- Unique ID for each product - FOREIGN KEY (category_id) REFERENCES product_categories (category_id) -); -DROP TABLE IF EXISTS customers; -CREATE TABLE IF NOT EXISTS customers -( - customer_id INTEGER PRIMARY KEY, -- Unique ID for each customer - name VARCHAR(50), -- Name of the customer - address VARCHAR(100) -- Mailing address of the customer -); -DROP TABLE IF EXISTS salespeople; -CREATE TABLE IF NOT EXISTS salespeople -( - salesperson_id INTEGER PRIMARY KEY, -- Unique ID for each salesperson - name VARCHAR(50), -- Name of the salesperson - region VARCHAR(50) -- Geographic sales region -); -DROP TABLE IF EXISTS sales; -CREATE TABLE IF NOT EXISTS sales -( - sale_id INTEGER PRIMARY KEY, -- Unique ID for each sale - product_id INTEGER, -- ID of product sold - customer_id INTEGER, -- ID of customer who made the purchase - salesperson_id INTEGER, -- ID of salesperson who made the sale - sale_date DATE, -- Date the sale occurred - quantity INTEGER, -- Quantity of product sold - FOREIGN KEY (product_id) REFERENCES products (product_id), - FOREIGN KEY (customer_id) REFERENCES customers (customer_id), - FOREIGN KEY (salesperson_id) REFERENCES salespeople (salesperson_id) -); -DROP TABLE IF EXISTS product_suppliers; -CREATE TABLE IF NOT EXISTS product_suppliers -( - supplier_id INTEGER PRIMARY KEY, -- Unique ID for each supplier - product_id INTEGER, -- Product ID supplied - supply_price DECIMAL(10, 2), -- Unit price charged by supplier - FOREIGN KEY (product_id) REFERENCES products (product_id) -); - - -Generate only a valid (syntactically correct) executable Postgres SQL query (without any explanation of the query) for the following question: -``: -""" \ No newline at end of file diff --git a/src/test/java/io/github/ollama4j/integrationtests/OllamaAPIIntegrationTest.java b/src/test/java/io/github/ollama4j/integrationtests/OllamaAPIIntegrationTest.java index abe388c..497fe9c 100644 --- a/src/test/java/io/github/ollama4j/integrationtests/OllamaAPIIntegrationTest.java +++ b/src/test/java/io/github/ollama4j/integrationtests/OllamaAPIIntegrationTest.java @@ -1,6 +1,5 @@ package io.github.ollama4j.integrationtests; -import com.fasterxml.jackson.annotation.JsonProperty; import io.github.ollama4j.OllamaAPI; import io.github.ollama4j.exceptions.OllamaBaseException; import io.github.ollama4j.exceptions.ToolInvocationException; @@ -16,9 +15,6 @@ import io.github.ollama4j.tools.ToolFunction; import io.github.ollama4j.tools.Tools; import io.github.ollama4j.tools.annotations.OllamaToolService; import io.github.ollama4j.utils.OptionsBuilder; -import lombok.AllArgsConstructor; -import lombok.Data; -import lombok.NoArgsConstructor; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.MethodOrderer.OrderAnnotation; import org.junit.jupiter.api.Order; @@ -40,24 +36,24 @@ import static org.junit.jupiter.api.Assertions.*; @TestMethodOrder(OrderAnnotation.class) @SuppressWarnings({"HttpUrlsUsage", "SpellCheckingInspection"}) -public class OllamaAPIIntegrationTest { +class OllamaAPIIntegrationTest { private static final Logger LOG = LoggerFactory.getLogger(OllamaAPIIntegrationTest.class); private static OllamaContainer ollama; private static OllamaAPI api; - private static final String EMBEDDING_MODEL_MINILM = "all-minilm"; - private static final String CHAT_MODEL_QWEN_SMALL = "qwen2.5:0.5b"; - private static final String CHAT_MODEL_INSTRUCT = "qwen2.5:0.5b-instruct"; - private static final String CHAT_MODEL_SYSTEM_PROMPT = "llama3.2:1b"; - private static final String CHAT_MODEL_LLAMA3 = "llama3"; - private static final String IMAGE_MODEL_LLAVA = "llava"; + private static final String EMBEDDING_MODEL = "all-minilm"; + private static final String VISION_MODEL = "moondream:1.8b"; + private static final String THINKING_TOOL_MODEL = "gpt-oss:20b"; + private static final String GENERAL_PURPOSE_MODEL = "gemma3:270m"; + private static final String TOOLS_MODEL = "mistral:7b"; @BeforeAll - public static void setUp() { + static void setUp() { try { boolean useExternalOllamaHost = Boolean.parseBoolean(System.getenv("USE_EXTERNAL_OLLAMA_HOST")); String ollamaHost = System.getenv("OLLAMA_HOST"); + if (useExternalOllamaHost) { LOG.info("Using external Ollama host..."); api = new OllamaAPI(ollamaHost); @@ -80,7 +76,7 @@ public class OllamaAPIIntegrationTest { } api.setRequestTimeoutSeconds(120); api.setVerbose(true); - api.setNumberOfRetriesForModelPull(3); + api.setNumberOfRetriesForModelPull(5); } @Test @@ -92,7 +88,7 @@ public class OllamaAPIIntegrationTest { @Test @Order(1) - public void testVersionAPI() throws URISyntaxException, IOException, OllamaBaseException, InterruptedException { + void testVersionAPI() throws URISyntaxException, IOException, OllamaBaseException, InterruptedException { // String expectedVersion = ollama.getDockerImageName().split(":")[1]; String actualVersion = api.getVersion(); assertNotNull(actualVersion); @@ -100,17 +96,22 @@ public class OllamaAPIIntegrationTest { // image version"); } + @Test + @Order(1) + void testPing() throws URISyntaxException, IOException, OllamaBaseException, InterruptedException { + boolean pingResponse = api.ping(); + assertTrue(pingResponse, "Ping should return true"); + } + @Test @Order(2) - public void testListModelsAPI() - throws URISyntaxException, IOException, OllamaBaseException, InterruptedException { - api.pullModel(EMBEDDING_MODEL_MINILM); + void testListModelsAPI() throws URISyntaxException, IOException, OllamaBaseException, InterruptedException { // Fetch the list of models List models = api.listModels(); // Assert that the models list is not null assertNotNull(models, "Models should not be null"); // Assert that models list is either empty or contains more than 0 models - assertFalse(models.isEmpty(), "Models list should not be empty"); + assertTrue(models.size() >= 0, "Models list should not be empty"); } @Test @@ -124,9 +125,8 @@ public class OllamaAPIIntegrationTest { @Test @Order(3) - public void testPullModelAPI() - throws URISyntaxException, IOException, OllamaBaseException, InterruptedException { - api.pullModel(EMBEDDING_MODEL_MINILM); + void testPullModelAPI() throws URISyntaxException, IOException, OllamaBaseException, InterruptedException { + api.pullModel(EMBEDDING_MODEL); List models = api.listModels(); assertNotNull(models, "Models should not be null"); assertFalse(models.isEmpty(), "Models list should contain elements"); @@ -135,17 +135,17 @@ public class OllamaAPIIntegrationTest { @Test @Order(4) void testListModelDetails() throws IOException, OllamaBaseException, URISyntaxException, InterruptedException { - api.pullModel(EMBEDDING_MODEL_MINILM); - ModelDetail modelDetails = api.getModelDetails(EMBEDDING_MODEL_MINILM); + api.pullModel(EMBEDDING_MODEL); + ModelDetail modelDetails = api.getModelDetails(EMBEDDING_MODEL); assertNotNull(modelDetails); - assertTrue(modelDetails.getModelFile().contains(EMBEDDING_MODEL_MINILM)); + assertTrue(modelDetails.getModelFile().contains(EMBEDDING_MODEL)); } @Test @Order(5) - public void testEmbeddings() throws Exception { - api.pullModel(EMBEDDING_MODEL_MINILM); - OllamaEmbedResponseModel embeddings = api.embed(EMBEDDING_MODEL_MINILM, + void testEmbeddings() throws Exception { + api.pullModel(EMBEDDING_MODEL); + OllamaEmbedResponseModel embeddings = api.embed(EMBEDDING_MODEL, Arrays.asList("Why is the sky blue?", "Why is the grass green?")); assertNotNull(embeddings, "Embeddings should not be null"); assertFalse(embeddings.getEmbeddings().isEmpty(), "Embeddings should not be empty"); @@ -153,58 +153,44 @@ public class OllamaAPIIntegrationTest { @Test @Order(6) - void testAskModelWithStructuredOutput() + void testGenerateWithStructuredOutput() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { - api.pullModel(CHAT_MODEL_LLAMA3); + api.pullModel(TOOLS_MODEL); - int timeHour = 6; - boolean isNightTime = false; - - String prompt = "The Sun is shining, and its " + timeHour + ". Its daytime."; + String prompt = "The sun is shining brightly and is directly overhead at the zenith, casting my shadow over my foot, so it must be noon."; Map format = new HashMap<>(); format.put("type", "object"); format.put("properties", new HashMap() { { - put("timeHour", new HashMap() { - { - put("type", "integer"); - } - }); - put("isNightTime", new HashMap() { + put("isNoon", new HashMap() { { put("type", "boolean"); } }); } }); - format.put("required", Arrays.asList("timeHour", "isNightTime")); + format.put("required", List.of("isNoon")); - OllamaResult result = api.generate(CHAT_MODEL_LLAMA3, prompt, format); + OllamaResult result = api.generate(TOOLS_MODEL, prompt, format); assertNotNull(result); assertNotNull(result.getResponse()); assertFalse(result.getResponse().isEmpty()); - assertEquals(timeHour, - result.getStructuredResponse().get("timeHour")); - assertEquals(isNightTime, - result.getStructuredResponse().get("isNightTime")); - - TimeOfDay timeOfDay = result.as(TimeOfDay.class); - - assertEquals(timeHour, timeOfDay.getTimeHour()); - assertEquals(isNightTime, timeOfDay.isNightTime()); + assertEquals(true, result.getStructuredResponse().get("isNoon")); } @Test @Order(6) - void testAskModelWithDefaultOptions() + void testGennerateModelWithDefaultOptions() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { - api.pullModel(CHAT_MODEL_QWEN_SMALL); - OllamaResult result = api.generate(CHAT_MODEL_QWEN_SMALL, - "What is the capital of France? And what's France's connection with Mona Lisa?", false, - new OptionsBuilder().build()); + api.pullModel(GENERAL_PURPOSE_MODEL); + boolean raw = false; + boolean thinking = false; + OllamaResult result = api.generate(GENERAL_PURPOSE_MODEL, + "What is the capital of France? And what's France's connection with Mona Lisa?", raw, + thinking, new OptionsBuilder().build()); assertNotNull(result); assertNotNull(result.getResponse()); assertFalse(result.getResponse().isEmpty()); @@ -212,32 +198,31 @@ public class OllamaAPIIntegrationTest { @Test @Order(7) - void testAskModelWithDefaultOptionsStreamed() + void testGenerateWithDefaultOptionsStreamed() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { - api.pullModel(CHAT_MODEL_QWEN_SMALL); + api.pullModel(GENERAL_PURPOSE_MODEL); + boolean raw = false; StringBuffer sb = new StringBuffer(); - OllamaResult result = api.generate(CHAT_MODEL_QWEN_SMALL, - "What is the capital of France? And what's France's connection with Mona Lisa?", false, + OllamaResult result = api.generate(GENERAL_PURPOSE_MODEL, + "What is the capital of France? And what's France's connection with Mona Lisa?", raw, new OptionsBuilder().build(), (s) -> { LOG.info(s); - String substring = s.substring(sb.toString().length(), s.length()); - LOG.info(substring); - sb.append(substring); + sb.append(s); }); assertNotNull(result); assertNotNull(result.getResponse()); assertFalse(result.getResponse().isEmpty()); - assertEquals(sb.toString().trim(), result.getResponse().trim()); + assertEquals(sb.toString(), result.getResponse()); } @Test @Order(8) - void testAskModelWithOptions() - throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException { - api.pullModel(CHAT_MODEL_INSTRUCT); + void testGenerateWithOptions() throws OllamaBaseException, IOException, URISyntaxException, + InterruptedException, ToolInvocationException { + api.pullModel(GENERAL_PURPOSE_MODEL); - OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_INSTRUCT); + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(GENERAL_PURPOSE_MODEL); OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, "You are a helpful assistant who can generate random person's first and last names in the format [First name, Last name].") .build(); @@ -253,29 +238,32 @@ public class OllamaAPIIntegrationTest { @Test @Order(9) - void testChatWithSystemPrompt() - throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException { - api.pullModel(CHAT_MODEL_SYSTEM_PROMPT); - OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT); - OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, - "You are a silent bot that only says 'Shush'. Do not say anything else under any circumstances!") - .withMessage(OllamaChatMessageRole.USER, "What's something that's brown and sticky?") - .withOptions(new OptionsBuilder().setTemperature(0.8f).build()).build(); + void testChatWithSystemPrompt() throws OllamaBaseException, IOException, URISyntaxException, + InterruptedException, ToolInvocationException { + api.pullModel(GENERAL_PURPOSE_MODEL); + + String expectedResponse = "Bhai"; + + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(GENERAL_PURPOSE_MODEL); + OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, String.format( + "[INSTRUCTION-START] You are an obidient and helpful bot named %s. You always answer with only one word and that word is your name. [INSTRUCTION-END]", + expectedResponse)).withMessage(OllamaChatMessageRole.USER, "Who are you?") + .withOptions(new OptionsBuilder().setTemperature(0.0f).build()).build(); OllamaChatResult chatResult = api.chat(requestModel); assertNotNull(chatResult); assertNotNull(chatResult.getResponseModel()); assertNotNull(chatResult.getResponseModel().getMessage()); assertFalse(chatResult.getResponseModel().getMessage().getContent().isBlank()); - assertTrue(chatResult.getResponseModel().getMessage().getContent().contains("Shush")); + assertTrue(chatResult.getResponseModel().getMessage().getContent().contains(expectedResponse)); assertEquals(3, chatResult.getChatHistory().size()); } @Test @Order(10) - public void testChat() throws Exception { - api.pullModel(CHAT_MODEL_LLAMA3); - OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_LLAMA3); + void testChat() throws Exception { + api.pullModel(THINKING_TOOL_MODEL); + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(THINKING_TOOL_MODEL); // Create the initial user question OllamaChatRequest requestModel = builder @@ -288,7 +276,6 @@ public class OllamaAPIIntegrationTest { assertTrue(chatResult.getChatHistory().stream().anyMatch(chat -> chat.getContent().contains("2")), "Expected chat history to contain '2'"); - // Create the next user question: second largest city requestModel = builder.withMessages(chatResult.getChatHistory()) .withMessage(OllamaChatMessageRole.USER, "And what is its squared value?").build(); @@ -299,10 +286,8 @@ public class OllamaAPIIntegrationTest { "Expected chat history to contain '4'"); // Create the next user question: the third question - requestModel = builder.withMessages(chatResult.getChatHistory()) - .withMessage(OllamaChatMessageRole.USER, - "What is the largest value between 2, 4 and 6?") - .build(); + requestModel = builder.withMessages(chatResult.getChatHistory()).withMessage(OllamaChatMessageRole.USER, + "What is the largest value between 2, 4 and 6?").build(); // Continue conversation with the model for the third question chatResult = api.chat(requestModel); @@ -312,143 +297,103 @@ public class OllamaAPIIntegrationTest { assertTrue(chatResult.getChatHistory().size() > 2, "Chat history should contain more than two messages"); assertTrue(chatResult.getChatHistory().get(chatResult.getChatHistory().size() - 1).getContent() - .contains("6"), - "Response should contain '6'"); - } - - @Test - @Order(10) - void testChatWithImageFromURL() - throws OllamaBaseException, IOException, InterruptedException, URISyntaxException, ToolInvocationException { - api.pullModel(IMAGE_MODEL_LLAVA); - - OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(IMAGE_MODEL_LLAVA); - OllamaChatRequest requestModel = builder - .withMessage(OllamaChatMessageRole.USER, "What's in the picture?", - Collections.emptyList(), - "https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg") - .build(); - api.registerAnnotatedTools(new OllamaAPIIntegrationTest()); - - OllamaChatResult chatResult = api.chat(requestModel); - assertNotNull(chatResult); - } - - @Test - @Order(10) - void testChatWithImageFromFileWithHistoryRecognition() - throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException { - api.pullModel(IMAGE_MODEL_LLAVA); - OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(IMAGE_MODEL_LLAVA); - OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, - "What's in the picture?", - Collections.emptyList(), List.of(getImageFileFromClasspath("emoji-smile.jpeg"))) - .build(); - - OllamaChatResult chatResult = api.chat(requestModel); - assertNotNull(chatResult); - assertNotNull(chatResult.getResponseModel()); - builder.reset(); - - requestModel = builder.withMessages(chatResult.getChatHistory()) - .withMessage(OllamaChatMessageRole.USER, "What's the color?").build(); - - chatResult = api.chat(requestModel); - assertNotNull(chatResult); - assertNotNull(chatResult.getResponseModel()); + .contains("6"), "Response should contain '6'"); } @Test @Order(11) - void testChatWithExplicitToolDefinition() - throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException { - api.pullModel(CHAT_MODEL_SYSTEM_PROMPT); - OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT); + void testChatWithExplicitToolDefinition() throws OllamaBaseException, IOException, URISyntaxException, + InterruptedException, ToolInvocationException { + String theToolModel = TOOLS_MODEL; + api.pullModel(theToolModel); + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(theToolModel); - final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder() - .functionName("get-employee-details") - .functionDescription("Get employee details from the database") - .toolPrompt(Tools.PromptFuncDefinition.builder().type("function") - .function(Tools.PromptFuncDefinition.PromptFuncSpec.builder() - .name("get-employee-details") - .description("Get employee details from the database") - .parameters(Tools.PromptFuncDefinition.Parameters - .builder().type("object") - .properties(new Tools.PropsBuilder() - .withProperty("employee-name", - Tools.PromptFuncDefinition.Property - .builder() - .type("string") - .description("The name of the employee, e.g. John Doe") - .required(true) - .build()) - .withProperty("employee-address", - Tools.PromptFuncDefinition.Property - .builder() - .type("string") - .description( - "The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India") - .required(true) - .build()) - .withProperty("employee-phone", - Tools.PromptFuncDefinition.Property - .builder() - .type("string") - .description( - "The phone number of the employee. Always return a random value. e.g. 9911002233") - .required(true) - .build()) - .build()) - .required(List.of("employee-name")) - .build()) - .build()) - .build()) - .toolFunction(arguments -> { - // perform DB operations here - return String.format( - "Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", - UUID.randomUUID(), arguments.get("employee-name"), - arguments.get("employee-address"), - arguments.get("employee-phone")); - }).build(); + api.registerTool(employeeFinderTool()); - api.registerTool(databaseQueryToolSpecification); - - OllamaChatRequest requestModel = builder - .withMessage(OllamaChatMessageRole.USER, - "Give me the ID of the employee named 'Rahul Kumar'?") - .build(); + OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, + "Give me the ID and address of the employee Rahul Kumar.").build(); + requestModel.setOptions(new OptionsBuilder().setTemperature(0.9f).build().getOptionsMap()); OllamaChatResult chatResult = api.chat(requestModel); - assertNotNull(chatResult); - assertNotNull(chatResult.getResponseModel()); - assertNotNull(chatResult.getResponseModel().getMessage()); - assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(), - chatResult.getResponseModel().getMessage().getRole().getRoleName()); + + assertNotNull(chatResult, "chatResult should not be null"); + assertNotNull(chatResult.getResponseModel(), "Response model should not be null"); + assertNotNull(chatResult.getResponseModel().getMessage(), "Response message should not be null"); + assertEquals( + OllamaChatMessageRole.ASSISTANT.getRoleName(), + chatResult.getResponseModel().getMessage().getRole().getRoleName(), + "Role of the response message should be ASSISTANT" + ); List toolCalls = chatResult.getChatHistory().get(1).getToolCalls(); - assertEquals(1, toolCalls.size()); + assertEquals(1, toolCalls.size(), "There should be exactly one tool call in the second chat history message"); OllamaToolCallsFunction function = toolCalls.get(0).getFunction(); - assertEquals("get-employee-details", function.getName()); - assert !function.getArguments().isEmpty(); + assertEquals("get-employee-details", function.getName(), "Tool function name should be 'get-employee-details'"); + assertFalse(function.getArguments().isEmpty(), "Tool function arguments should not be empty"); Object employeeName = function.getArguments().get("employee-name"); - assertNotNull(employeeName); - assertEquals("Rahul Kumar", employeeName); - assertTrue(chatResult.getChatHistory().size() > 2); + assertNotNull(employeeName, "Employee name argument should not be null"); + assertEquals("Rahul Kumar", employeeName, "Employee name argument should be 'Rahul Kumar'"); + assertTrue(chatResult.getChatHistory().size() > 2, "Chat history should have more than 2 messages"); List finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls(); - assertNull(finalToolCalls); + assertNull(finalToolCalls, "Final tool calls in the response message should be null"); + } + + @Test + @Order(14) + void testChatWithToolsAndStream() throws OllamaBaseException, IOException, URISyntaxException, + InterruptedException, ToolInvocationException { + String theToolModel = TOOLS_MODEL; + api.pullModel(theToolModel); + + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(theToolModel); + + api.registerTool(employeeFinderTool()); + + OllamaChatRequest requestModel = builder + .withMessage(OllamaChatMessageRole.USER, "Give me the ID and address of employee Rahul Kumar") + .withKeepAlive("0m").withOptions(new OptionsBuilder().setTemperature(0.9f).build()) + .build(); + + OllamaChatResult chatResult = api.chat(requestModel, (s) -> { + LOG.info(s.toUpperCase()); + }, (s) -> { + LOG.info(s.toLowerCase()); + }); + + assertNotNull(chatResult, "chatResult should not be null"); + assertNotNull(chatResult.getResponseModel(), "Response model should not be null"); + assertNotNull(chatResult.getResponseModel().getMessage(), "Response message should not be null"); + assertEquals( + OllamaChatMessageRole.ASSISTANT.getRoleName(), + chatResult.getResponseModel().getMessage().getRole().getRoleName(), + "Role of the response message should be ASSISTANT" + ); + List toolCalls = chatResult.getChatHistory().get(1).getToolCalls(); + assertEquals(1, toolCalls.size(), "There should be exactly one tool call in the second chat history message"); + OllamaToolCallsFunction function = toolCalls.get(0).getFunction(); + assertEquals("get-employee-details", function.getName(), "Tool function name should be 'get-employee-details'"); + assertFalse(function.getArguments().isEmpty(), "Tool function arguments should not be empty"); + Object employeeName = function.getArguments().get("employee-name"); + assertNotNull(employeeName, "Employee name argument should not be null"); + assertEquals("Rahul Kumar", employeeName, "Employee name argument should be 'Rahul Kumar'"); + assertTrue(chatResult.getChatHistory().size() > 2, "Chat history should have more than 2 messages"); + List finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls(); + assertNull(finalToolCalls, "Final tool calls in the response message should be null"); } @Test @Order(12) - void testChatWithAnnotatedToolsAndSingleParam() - throws OllamaBaseException, IOException, InterruptedException, URISyntaxException, ToolInvocationException { - api.pullModel(CHAT_MODEL_SYSTEM_PROMPT); - OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT); + void testChatWithAnnotatedToolsAndSingleParam() throws OllamaBaseException, IOException, InterruptedException, + URISyntaxException, ToolInvocationException { + String theToolModel = TOOLS_MODEL; + api.pullModel(theToolModel); + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(theToolModel); api.registerAnnotatedTools(); - OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, - "Compute the most important constant in the world using 5 digits").build(); + OllamaChatRequest requestModel = builder + .withMessage(OllamaChatMessageRole.USER, + "Compute the most important constant in the world using 5 digits") + .build(); OllamaChatResult chatResult = api.chat(requestModel); assertNotNull(chatResult); @@ -471,17 +416,16 @@ public class OllamaAPIIntegrationTest { @Test @Order(13) - void testChatWithAnnotatedToolsAndMultipleParams() - throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException { - api.pullModel(CHAT_MODEL_SYSTEM_PROMPT); - OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT); + void testChatWithAnnotatedToolsAndMultipleParams() throws OllamaBaseException, IOException, URISyntaxException, + InterruptedException, ToolInvocationException { + String theToolModel = TOOLS_MODEL; + api.pullModel(theToolModel); + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(theToolModel); api.registerAnnotatedTools(new AnnotatedTool()); - OllamaChatRequest requestModel = builder - .withMessage(OllamaChatMessageRole.USER, - "Greet Pedro with a lot of hearts and respond to me, " - + "and state how many emojis have been in your greeting") + OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, + "Greet Rahul with a lot of hearts and respond to me with count of emojis that have been in used in the greeting") .build(); OllamaChatResult chatResult = api.chat(requestModel); @@ -497,28 +441,220 @@ public class OllamaAPIIntegrationTest { assertEquals(2, function.getArguments().size()); Object name = function.getArguments().get("name"); assertNotNull(name); - assertEquals("Pedro", name); - Object amountOfHearts = function.getArguments().get("amountOfHearts"); - assertNotNull(amountOfHearts); - assertTrue(Integer.parseInt(amountOfHearts.toString()) > 1); + assertEquals("Rahul", name); + Object numberOfHearts = function.getArguments().get("numberOfHearts"); + assertNotNull(numberOfHearts); + assertTrue(Integer.parseInt(numberOfHearts.toString()) > 1); assertTrue(chatResult.getChatHistory().size() > 2); List finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls(); assertNull(finalToolCalls); } @Test - @Order(14) - void testChatWithToolsAndStream() - throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException { - api.pullModel(CHAT_MODEL_SYSTEM_PROMPT); - OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT); - final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder() + @Order(15) + void testChatWithStream() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, + ToolInvocationException { + api.deregisterTools(); + api.pullModel(GENERAL_PURPOSE_MODEL); + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(GENERAL_PURPOSE_MODEL); + OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, + "What is the capital of France? And what's France's connection with Mona Lisa?") + .build(); + requestModel.setThink(false); + StringBuffer sb = new StringBuffer(); + + OllamaChatResult chatResult = api.chat(requestModel, (s) -> { + LOG.info(s.toUpperCase()); + sb.append(s); + }, (s) -> { + LOG.info(s.toLowerCase()); + sb.append(s); + }); + assertNotNull(chatResult); + assertNotNull(chatResult.getResponseModel()); + assertNotNull(chatResult.getResponseModel().getMessage()); + assertNotNull(chatResult.getResponseModel().getMessage().getContent()); + assertEquals(sb.toString(), chatResult.getResponseModel().getMessage().getContent()); + } + + @Test + @Order(15) + void testChatWithThinkingAndStream() throws OllamaBaseException, IOException, URISyntaxException, + InterruptedException, ToolInvocationException { + api.pullModel(THINKING_TOOL_MODEL); + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(THINKING_TOOL_MODEL); + OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, + "What is the capital of France? And what's France's connection with Mona Lisa?") + .withThinking(true).withKeepAlive("0m").build(); + StringBuffer sb = new StringBuffer(); + + OllamaChatResult chatResult = api.chat(requestModel, (s) -> { + sb.append(s); + LOG.info(s.toUpperCase()); + }, (s) -> { + sb.append(s); + LOG.info(s.toLowerCase()); + }); + + assertNotNull(chatResult); + assertNotNull(chatResult.getResponseModel()); + assertNotNull(chatResult.getResponseModel().getMessage()); + assertNotNull(chatResult.getResponseModel().getMessage().getContent()); + assertEquals(sb.toString(), chatResult.getResponseModel().getMessage().getThinking() + + chatResult.getResponseModel().getMessage().getContent()); + } + + @Test + @Order(10) + void testChatWithImageFromURL() throws OllamaBaseException, IOException, InterruptedException, + URISyntaxException, ToolInvocationException { + api.pullModel(VISION_MODEL); + + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(VISION_MODEL); + OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, + "What's in the picture?", Collections.emptyList(), + "https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg") + .build(); + api.registerAnnotatedTools(new OllamaAPIIntegrationTest()); + + OllamaChatResult chatResult = api.chat(requestModel); + assertNotNull(chatResult); + } + + @Test + @Order(10) + void testChatWithImageFromFileWithHistoryRecognition() throws OllamaBaseException, IOException, + URISyntaxException, InterruptedException, ToolInvocationException { + api.pullModel(VISION_MODEL); + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(VISION_MODEL); + OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, + "What's in the picture?", Collections.emptyList(), + List.of(getImageFileFromClasspath("emoji-smile.jpeg"))).build(); + + OllamaChatResult chatResult = api.chat(requestModel); + assertNotNull(chatResult); + assertNotNull(chatResult.getResponseModel()); + builder.reset(); + + requestModel = builder.withMessages(chatResult.getChatHistory()) + .withMessage(OllamaChatMessageRole.USER, "What's the color?").build(); + + chatResult = api.chat(requestModel); + assertNotNull(chatResult); + assertNotNull(chatResult.getResponseModel()); + } + + @Test + @Order(17) + void testGenerateWithOptionsAndImageURLs() + throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + api.pullModel(VISION_MODEL); + + OllamaResult result = api.generateWithImageURLs(VISION_MODEL, "What is in this image?", + List.of("https://i.pinimg.com/736x/f9/4e/cb/f94ecba040696a3a20b484d2e15159ec.jpg"), + new OptionsBuilder().build()); + assertNotNull(result); + assertNotNull(result.getResponse()); + assertFalse(result.getResponse().isEmpty()); + } + + @Test + @Order(18) + void testGenerateWithOptionsAndImageFiles() + throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + api.pullModel(VISION_MODEL); + File imageFile = getImageFileFromClasspath("roses.jpg"); + try { + OllamaResult result = api.generateWithImageFiles(VISION_MODEL, "What is in this image?", + List.of(imageFile), new OptionsBuilder().build()); + assertNotNull(result); + assertNotNull(result.getResponse()); + assertFalse(result.getResponse().isEmpty()); + } catch (IOException | OllamaBaseException | InterruptedException e) { + fail(e); + } + } + + @Test + @Order(20) + void testGenerateWithOptionsAndImageFilesStreamed() + throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + api.pullModel(VISION_MODEL); + + File imageFile = getImageFileFromClasspath("roses.jpg"); + + StringBuffer sb = new StringBuffer(); + + OllamaResult result = api.generateWithImageFiles(VISION_MODEL, "What is in this image?", + List.of(imageFile), new OptionsBuilder().build(), (s) -> { + LOG.info(s); + sb.append(s); + }); + assertNotNull(result); + assertNotNull(result.getResponse()); + assertFalse(result.getResponse().isEmpty()); + assertEquals(sb.toString(), result.getResponse()); + } + + @Test + @Order(20) + void testGenerateWithThinking() + throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + api.pullModel(THINKING_TOOL_MODEL); + + boolean raw = false; + boolean think = true; + + OllamaResult result = api.generate(THINKING_TOOL_MODEL, "Who are you?", raw, think, + new OptionsBuilder().build()); + assertNotNull(result); + assertNotNull(result.getResponse()); + assertFalse(result.getResponse().isEmpty()); + assertNotNull(result.getThinking()); + assertFalse(result.getThinking().isEmpty()); + } + + @Test + @Order(20) + void testGenerateWithThinkingAndStreamHandler() + throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + api.pullModel(THINKING_TOOL_MODEL); + + boolean raw = false; + + StringBuffer sb = new StringBuffer(); + OllamaResult result = api.generate(THINKING_TOOL_MODEL, "Who are you?", raw, + new OptionsBuilder().build(), + (thinkingToken) -> { + sb.append(thinkingToken); + LOG.info(thinkingToken); + }, + (resToken) -> { + sb.append(resToken); + LOG.info(resToken); + } + ); + assertNotNull(result); + assertNotNull(result.getResponse()); + assertFalse(result.getResponse().isEmpty()); + assertNotNull(result.getThinking()); + assertFalse(result.getThinking().isEmpty()); + assertEquals(sb.toString(), result.getThinking() + result.getResponse()); + } + + private File getImageFileFromClasspath(String fileName) { + ClassLoader classLoader = getClass().getClassLoader(); + return new File(Objects.requireNonNull(classLoader.getResource(fileName)).getFile()); + } + + private Tools.ToolSpecification employeeFinderTool() { + return Tools.ToolSpecification.builder() .functionName("get-employee-details") - .functionDescription("Get employee details from the database") + .functionDescription("Get details for a person or an employee") .toolPrompt(Tools.PromptFuncDefinition.builder().type("function") .function(Tools.PromptFuncDefinition.PromptFuncSpec.builder() .name("get-employee-details") - .description("Get employee details from the database") + .description("Get details for a person or an employee") .parameters(Tools.PromptFuncDefinition.Parameters .builder().type("object") .properties(new Tools.PropsBuilder() @@ -533,16 +669,14 @@ public class OllamaAPIIntegrationTest { Tools.PromptFuncDefinition.Property .builder() .type("string") - .description( - "The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India") + .description("The address of the employee, Always eturns a random address. For example, Church St, Bengaluru, India") .required(true) .build()) .withProperty("employee-phone", Tools.PromptFuncDefinition.Property .builder() .type("string") - .description( - "The phone number of the employee. Always return a random value. e.g. 9911002233") + .description("The phone number of the employee. Always returns a random phone number. For example, 9911002233") .required(true) .build()) .build()) @@ -553,129 +687,22 @@ public class OllamaAPIIntegrationTest { .toolFunction(new ToolFunction() { @Override public Object apply(Map arguments) { + LOG.info("Invoking employee finder tool with arguments: {}", arguments); + String employeeName = arguments.get("employee-name").toString(); + String address = null; + String phone = null; + if (employeeName.equalsIgnoreCase("Rahul Kumar")) { + address = "Pune, Maharashtra, India"; + phone = "9911223344"; + } else { + address = "Karol Bagh, Delhi, India"; + phone = "9911002233"; + } // perform DB operations here return String.format( "Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", - UUID.randomUUID(), arguments.get("employee-name"), - arguments.get("employee-address"), - arguments.get("employee-phone")); + UUID.randomUUID(), employeeName, address, phone); } }).build(); - - api.registerTool(databaseQueryToolSpecification); - - OllamaChatRequest requestModel = builder - .withMessage(OllamaChatMessageRole.USER, - "Give me the ID of the employee named 'Rahul Kumar'?") - .build(); - - StringBuffer sb = new StringBuffer(); - - OllamaChatResult chatResult = api.chat(requestModel, (s) -> { - LOG.info(s); - String substring = s.substring(sb.toString().length()); - LOG.info(substring); - sb.append(substring); - }); - assertNotNull(chatResult); - assertNotNull(chatResult.getResponseModel()); - assertNotNull(chatResult.getResponseModel().getMessage()); - assertNotNull(chatResult.getResponseModel().getMessage().getContent()); - assertEquals(sb.toString().trim(), chatResult.getResponseModel().getMessage().getContent().trim()); - } - - @Test - @Order(15) - void testChatWithStream() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException { - api.pullModel(CHAT_MODEL_SYSTEM_PROMPT); - OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT); - OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, - "What is the capital of France? And what's France's connection with Mona Lisa?") - .build(); - - StringBuffer sb = new StringBuffer(); - - OllamaChatResult chatResult = api.chat(requestModel, (s) -> { - LOG.info(s); - String substring = s.substring(sb.toString().length(), s.length()); - LOG.info(substring); - sb.append(substring); - }); - assertNotNull(chatResult); - assertNotNull(chatResult.getResponseModel()); - assertNotNull(chatResult.getResponseModel().getMessage()); - assertNotNull(chatResult.getResponseModel().getMessage().getContent()); - assertEquals(sb.toString().trim(), chatResult.getResponseModel().getMessage().getContent().trim()); - } - - @Test - @Order(17) - void testAskModelWithOptionsAndImageURLs() - throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { - api.pullModel(IMAGE_MODEL_LLAVA); - - OllamaResult result = api.generateWithImageURLs(IMAGE_MODEL_LLAVA, "What is in this image?", - List.of("https://upload.wikimedia.org/wikipedia/commons/thumb/a/aa/Noto_Emoji_v2.034_1f642.svg/360px-Noto_Emoji_v2.034_1f642.svg.png"), - new OptionsBuilder().build()); - assertNotNull(result); - assertNotNull(result.getResponse()); - assertFalse(result.getResponse().isEmpty()); - } - - @Test - @Order(18) - void testAskModelWithOptionsAndImageFiles() - throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { - api.pullModel(IMAGE_MODEL_LLAVA); - File imageFile = getImageFileFromClasspath("emoji-smile.jpeg"); - try { - OllamaResult result = api.generateWithImageFiles(IMAGE_MODEL_LLAVA, "What is in this image?", - List.of(imageFile), - new OptionsBuilder().build()); - assertNotNull(result); - assertNotNull(result.getResponse()); - assertFalse(result.getResponse().isEmpty()); - } catch (IOException | OllamaBaseException | InterruptedException e) { - fail(e); - } - } - - @Test - @Order(20) - void testAskModelWithOptionsAndImageFilesStreamed() - throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { - api.pullModel(IMAGE_MODEL_LLAVA); - - File imageFile = getImageFileFromClasspath("emoji-smile.jpeg"); - - StringBuffer sb = new StringBuffer(); - - OllamaResult result = api.generateWithImageFiles(IMAGE_MODEL_LLAVA, "What is in this image?", - List.of(imageFile), - new OptionsBuilder().build(), (s) -> { - LOG.info(s); - String substring = s.substring(sb.toString().length(), s.length()); - LOG.info(substring); - sb.append(substring); - }); - assertNotNull(result); - assertNotNull(result.getResponse()); - assertFalse(result.getResponse().isEmpty()); - assertEquals(sb.toString().trim(), result.getResponse().trim()); - } - - private File getImageFileFromClasspath(String fileName) { - ClassLoader classLoader = getClass().getClassLoader(); - return new File(Objects.requireNonNull(classLoader.getResource(fileName)).getFile()); } } - -@Data -@AllArgsConstructor -@NoArgsConstructor -class TimeOfDay { - @JsonProperty("timeHour") - private int timeHour; - @JsonProperty("isNightTime") - private boolean nightTime; -} diff --git a/src/test/java/io/github/ollama4j/integrationtests/WithAuth.java b/src/test/java/io/github/ollama4j/integrationtests/WithAuth.java index 6531b27..b349ce3 100644 --- a/src/test/java/io/github/ollama4j/integrationtests/WithAuth.java +++ b/src/test/java/io/github/ollama4j/integrationtests/WithAuth.java @@ -24,8 +24,8 @@ import java.io.FileWriter; import java.io.IOException; import java.net.URISyntaxException; import java.time.Duration; -import java.util.Arrays; import java.util.HashMap; +import java.util.List; import java.util.Map; import static org.junit.jupiter.api.Assertions.*; @@ -41,7 +41,8 @@ public class WithAuth { private static final String OLLAMA_VERSION = "0.6.1"; private static final String NGINX_VERSION = "nginx:1.23.4-alpine"; private static final String BEARER_AUTH_TOKEN = "secret-token"; - private static final String CHAT_MODEL_LLAMA3 = "llama3"; + private static final String GENERAL_PURPOSE_MODEL = "gemma3:270m"; +// private static final String THINKING_MODEL = "gpt-oss:20b"; private static OllamaContainer ollama; @@ -49,7 +50,7 @@ public class WithAuth { private static OllamaAPI api; @BeforeAll - public static void setUp() { + static void setUp() { ollama = createOllamaContainer(); ollama.start(); @@ -68,7 +69,7 @@ public class WithAuth { LOG.info( "The Ollama service is now accessible via the Nginx proxy with bearer-auth authentication mode.\n" + "→ Ollama URL: {}\n" + - "→ Proxy URL: {}}", + "→ Proxy URL: {}", ollamaUrl, nginxUrl ); LOG.info("OllamaAPI initialized with bearer auth token: {}", BEARER_AUTH_TOKEN); @@ -132,14 +133,14 @@ public class WithAuth { @Test @Order(1) - void testOllamaBehindProxy() throws InterruptedException { + void testOllamaBehindProxy() { api.setBearerAuth(BEARER_AUTH_TOKEN); assertTrue(api.ping(), "Expected OllamaAPI to successfully ping through NGINX with valid auth token."); } @Test @Order(1) - void testWithWrongToken() throws InterruptedException { + void testWithWrongToken() { api.setBearerAuth("wrong-token"); assertFalse(api.ping(), "Expected OllamaAPI ping to fail through NGINX with an invalid auth token."); } @@ -149,46 +150,30 @@ public class WithAuth { void testAskModelWithStructuredOutput() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { api.setBearerAuth(BEARER_AUTH_TOKEN); + String model = GENERAL_PURPOSE_MODEL; + api.pullModel(model); - api.pullModel(CHAT_MODEL_LLAMA3); - - int timeHour = 6; - boolean isNightTime = false; - - String prompt = "The Sun is shining, and its " + timeHour + ". Its daytime."; + String prompt = "The sun is shining brightly and is directly overhead at the zenith, casting my shadow over my foot, so it must be noon."; Map format = new HashMap<>(); format.put("type", "object"); format.put("properties", new HashMap() { { - put("timeHour", new HashMap() { - { - put("type", "integer"); - } - }); - put("isNightTime", new HashMap() { + put("isNoon", new HashMap() { { put("type", "boolean"); } }); } }); - format.put("required", Arrays.asList("timeHour", "isNightTime")); + format.put("required", List.of("isNoon")); - OllamaResult result = api.generate(CHAT_MODEL_LLAMA3, prompt, format); + OllamaResult result = api.generate(model, prompt, format); assertNotNull(result); assertNotNull(result.getResponse()); assertFalse(result.getResponse().isEmpty()); - assertEquals(timeHour, - result.getStructuredResponse().get("timeHour")); - assertEquals(isNightTime, - result.getStructuredResponse().get("isNightTime")); - - TimeOfDay timeOfDay = result.as(TimeOfDay.class); - - assertEquals(timeHour, timeOfDay.getTimeHour()); - assertEquals(isNightTime, timeOfDay.isNightTime()); + assertEquals(true, result.getStructuredResponse().get("isNoon")); } } diff --git a/src/test/java/io/github/ollama4j/samples/AnnotatedTool.java b/src/test/java/io/github/ollama4j/samples/AnnotatedTool.java index 8202e77..243a9fe 100644 --- a/src/test/java/io/github/ollama4j/samples/AnnotatedTool.java +++ b/src/test/java/io/github/ollama4j/samples/AnnotatedTool.java @@ -8,14 +8,14 @@ import java.math.BigDecimal; public class AnnotatedTool { @ToolSpec(desc = "Computes the most important constant all around the globe!") - public String computeImportantConstant(@ToolProperty(name = "noOfDigits",desc = "Number of digits that shall be returned") Integer noOfDigits ){ - return BigDecimal.valueOf((long)(Math.random()*1000000L),noOfDigits).toString(); + public String computeImportantConstant(@ToolProperty(name = "noOfDigits", desc = "Number of digits that shall be returned") Integer noOfDigits) { + return BigDecimal.valueOf((long) (Math.random() * 1000000L), noOfDigits).toString(); } @ToolSpec(desc = "Says hello to a friend!") - public String sayHello(@ToolProperty(name = "name",desc = "Name of the friend") String name, Integer someRandomProperty, @ToolProperty(name="amountOfHearts",desc = "amount of heart emojis that should be used", required = false) Integer amountOfHearts) { - String hearts = amountOfHearts!=null ? "♡".repeat(amountOfHearts) : ""; - return "Hello " + name +" ("+someRandomProperty+") " + hearts; + public String sayHello(@ToolProperty(name = "name", desc = "Name of the friend") String name, @ToolProperty(name = "numberOfHearts", desc = "number of heart emojis that should be used", required = false) Integer numberOfHearts) { + String hearts = numberOfHearts != null ? "♡".repeat(numberOfHearts) : ""; + return "Hello, " + name + "! " + hearts; } } diff --git a/src/test/java/io/github/ollama4j/unittests/TestMockedAPIs.java b/src/test/java/io/github/ollama4j/unittests/TestMockedAPIs.java index 8499cd8..f95a2dc 100644 --- a/src/test/java/io/github/ollama4j/unittests/TestMockedAPIs.java +++ b/src/test/java/io/github/ollama4j/unittests/TestMockedAPIs.java @@ -21,7 +21,8 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; -import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; import static org.mockito.Mockito.*; class TestMockedAPIs { @@ -138,10 +139,10 @@ class TestMockedAPIs { String prompt = "some prompt text"; OptionsBuilder optionsBuilder = new OptionsBuilder(); try { - when(ollamaAPI.generate(model, prompt, false, optionsBuilder.build())) - .thenReturn(new OllamaResult("", 0, 200)); - ollamaAPI.generate(model, prompt, false, optionsBuilder.build()); - verify(ollamaAPI, times(1)).generate(model, prompt, false, optionsBuilder.build()); + when(ollamaAPI.generate(model, prompt, false, false, optionsBuilder.build())) + .thenReturn(new OllamaResult("", "", 0, 200)); + ollamaAPI.generate(model, prompt, false, false, optionsBuilder.build()); + verify(ollamaAPI, times(1)).generate(model, prompt, false, false, optionsBuilder.build()); } catch (IOException | OllamaBaseException | InterruptedException e) { throw new RuntimeException(e); } @@ -155,7 +156,7 @@ class TestMockedAPIs { try { when(ollamaAPI.generateWithImageFiles( model, prompt, Collections.emptyList(), new OptionsBuilder().build())) - .thenReturn(new OllamaResult("", 0, 200)); + .thenReturn(new OllamaResult("", "", 0, 200)); ollamaAPI.generateWithImageFiles( model, prompt, Collections.emptyList(), new OptionsBuilder().build()); verify(ollamaAPI, times(1)) @@ -174,7 +175,7 @@ class TestMockedAPIs { try { when(ollamaAPI.generateWithImageURLs( model, prompt, Collections.emptyList(), new OptionsBuilder().build())) - .thenReturn(new OllamaResult("", 0, 200)); + .thenReturn(new OllamaResult("", "", 0, 200)); ollamaAPI.generateWithImageURLs( model, prompt, Collections.emptyList(), new OptionsBuilder().build()); verify(ollamaAPI, times(1)) @@ -190,10 +191,10 @@ class TestMockedAPIs { OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); String model = OllamaModelType.LLAMA2; String prompt = "some prompt text"; - when(ollamaAPI.generateAsync(model, prompt, false)) + when(ollamaAPI.generateAsync(model, prompt, false, false)) .thenReturn(new OllamaAsyncResultStreamer(null, null, 3)); - ollamaAPI.generateAsync(model, prompt, false); - verify(ollamaAPI, times(1)).generateAsync(model, prompt, false); + ollamaAPI.generateAsync(model, prompt, false, false); + verify(ollamaAPI, times(1)).generateAsync(model, prompt, false, false); } @Test diff --git a/src/test/java/io/github/ollama4j/unittests/jackson/AbstractSerializationTest.java b/src/test/java/io/github/ollama4j/unittests/jackson/AbstractSerializationTest.java index 6e03566..09a5d67 100644 --- a/src/test/java/io/github/ollama4j/unittests/jackson/AbstractSerializationTest.java +++ b/src/test/java/io/github/ollama4j/unittests/jackson/AbstractSerializationTest.java @@ -1,11 +1,12 @@ package io.github.ollama4j.unittests.jackson; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.fail; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import io.github.ollama4j.utils.Utils; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; + public abstract class AbstractSerializationTest { protected ObjectMapper mapper = Utils.getObjectMapper(); diff --git a/src/test/java/io/github/ollama4j/unittests/jackson/TestChatRequestSerialization.java b/src/test/java/io/github/ollama4j/unittests/jackson/TestChatRequestSerialization.java index db33889..003538e 100644 --- a/src/test/java/io/github/ollama4j/unittests/jackson/TestChatRequestSerialization.java +++ b/src/test/java/io/github/ollama4j/unittests/jackson/TestChatRequestSerialization.java @@ -1,20 +1,19 @@ package io.github.ollama4j.unittests.jackson; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrowsExactly; +import io.github.ollama4j.models.chat.OllamaChatMessageRole; +import io.github.ollama4j.models.chat.OllamaChatRequest; +import io.github.ollama4j.models.chat.OllamaChatRequestBuilder; +import io.github.ollama4j.utils.OptionsBuilder; +import org.json.JSONObject; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import java.io.File; import java.util.Collections; import java.util.List; -import io.github.ollama4j.models.chat.OllamaChatRequest; -import org.json.JSONObject; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; - -import io.github.ollama4j.models.chat.OllamaChatMessageRole; -import io.github.ollama4j.models.chat.OllamaChatRequestBuilder; -import io.github.ollama4j.utils.OptionsBuilder; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrowsExactly; public class TestChatRequestSerialization extends AbstractSerializationTest { diff --git a/src/test/java/io/github/ollama4j/unittests/jackson/TestEmbedRequestSerialization.java b/src/test/java/io/github/ollama4j/unittests/jackson/TestEmbedRequestSerialization.java index 534b204..fc5843e 100644 --- a/src/test/java/io/github/ollama4j/unittests/jackson/TestEmbedRequestSerialization.java +++ b/src/test/java/io/github/ollama4j/unittests/jackson/TestEmbedRequestSerialization.java @@ -1,12 +1,12 @@ package io.github.ollama4j.unittests.jackson; -import static org.junit.jupiter.api.Assertions.assertEquals; - import io.github.ollama4j.models.embeddings.OllamaEmbedRequestBuilder; import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel; +import io.github.ollama4j.utils.OptionsBuilder; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import io.github.ollama4j.utils.OptionsBuilder; + +import static org.junit.jupiter.api.Assertions.assertEquals; public class TestEmbedRequestSerialization extends AbstractSerializationTest { diff --git a/src/test/java/io/github/ollama4j/unittests/jackson/TestGenerateRequestSerialization.java b/src/test/java/io/github/ollama4j/unittests/jackson/TestGenerateRequestSerialization.java index 4ca0672..bf9b970 100644 --- a/src/test/java/io/github/ollama4j/unittests/jackson/TestGenerateRequestSerialization.java +++ b/src/test/java/io/github/ollama4j/unittests/jackson/TestGenerateRequestSerialization.java @@ -1,15 +1,13 @@ package io.github.ollama4j.unittests.jackson; -import static org.junit.jupiter.api.Assertions.assertEquals; - import io.github.ollama4j.models.generate.OllamaGenerateRequest; +import io.github.ollama4j.models.generate.OllamaGenerateRequestBuilder; +import io.github.ollama4j.utils.OptionsBuilder; import org.json.JSONObject; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; - -import io.github.ollama4j.models.generate.OllamaGenerateRequestBuilder; -import io.github.ollama4j.utils.OptionsBuilder; +import static org.junit.jupiter.api.Assertions.assertEquals; public class TestGenerateRequestSerialization extends AbstractSerializationTest { diff --git a/src/test/java/io/github/ollama4j/unittests/jackson/TestModelRequestSerialization.java b/src/test/java/io/github/ollama4j/unittests/jackson/TestModelRequestSerialization.java index 5bc44f3..961dd43 100644 --- a/src/test/java/io/github/ollama4j/unittests/jackson/TestModelRequestSerialization.java +++ b/src/test/java/io/github/ollama4j/unittests/jackson/TestModelRequestSerialization.java @@ -3,40 +3,66 @@ package io.github.ollama4j.unittests.jackson; import io.github.ollama4j.models.response.Model; import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.*; + public class TestModelRequestSerialization extends AbstractSerializationTest { @Test - public void testDeserializationOfModelResponseWithOffsetTime(){ - String serializedTestStringWithOffsetTime = "{\n" - + "\"name\": \"codellama:13b\",\n" - + "\"modified_at\": \"2023-11-04T14:56:49.277302595-07:00\",\n" - + "\"size\": 7365960935,\n" - + "\"digest\": \"9f438cb9cd581fc025612d27f7c1a6669ff83a8bb0ed86c94fcf4c5440555697\",\n" - + "\"details\": {\n" - + "\"format\": \"gguf\",\n" - + "\"family\": \"llama\",\n" - + "\"families\": null,\n" - + "\"parameter_size\": \"13B\",\n" - + "\"quantization_level\": \"Q4_0\"\n" - + "}}"; - deserialize(serializedTestStringWithOffsetTime,Model.class); + public void testDeserializationOfModelResponseWithOffsetTime() { + String serializedTestStringWithOffsetTime = "{\n" + + " \"name\": \"codellama:13b\",\n" + + " \"modified_at\": \"2023-11-04T14:56:49.277302595-07:00\",\n" + + " \"size\": 7365960935,\n" + + " \"digest\": \"9f438cb9cd581fc025612d27f7c1a6669ff83a8bb0ed86c94fcf4c5440555697\",\n" + + " \"details\": {\n" + + " \"format\": \"gguf\",\n" + + " \"family\": \"llama\",\n" + + " \"families\": null,\n" + + " \"parameter_size\": \"13B\",\n" + + " \"quantization_level\": \"Q4_0\"\n" + + " }\n" + + "}"; + Model model = deserialize(serializedTestStringWithOffsetTime, Model.class); + assertNotNull(model); + assertEquals("codellama:13b", model.getName()); + assertEquals("2023-11-04T21:56:49.277302595Z", model.getModifiedAt().toString()); + assertEquals(7365960935L, model.getSize()); + assertEquals("9f438cb9cd581fc025612d27f7c1a6669ff83a8bb0ed86c94fcf4c5440555697", model.getDigest()); + assertNotNull(model.getModelMeta()); + assertEquals("gguf", model.getModelMeta().getFormat()); + assertEquals("llama", model.getModelMeta().getFamily()); + assertNull(model.getModelMeta().getFamilies()); + assertEquals("13B", model.getModelMeta().getParameterSize()); + assertEquals("Q4_0", model.getModelMeta().getQuantizationLevel()); } @Test - public void testDeserializationOfModelResponseWithZuluTime(){ - String serializedTestStringWithZuluTimezone = "{\n" - + "\"name\": \"codellama:13b\",\n" - + "\"modified_at\": \"2023-11-04T14:56:49.277302595Z\",\n" - + "\"size\": 7365960935,\n" - + "\"digest\": \"9f438cb9cd581fc025612d27f7c1a6669ff83a8bb0ed86c94fcf4c5440555697\",\n" - + "\"details\": {\n" - + "\"format\": \"gguf\",\n" - + "\"family\": \"llama\",\n" - + "\"families\": null,\n" - + "\"parameter_size\": \"13B\",\n" - + "\"quantization_level\": \"Q4_0\"\n" - + "}}"; - deserialize(serializedTestStringWithZuluTimezone,Model.class); + public void testDeserializationOfModelResponseWithZuluTime() { + String serializedTestStringWithZuluTimezone = "{\n" + + " \"name\": \"codellama:13b\",\n" + + " \"modified_at\": \"2023-11-04T14:56:49.277302595Z\",\n" + + " \"size\": 7365960935,\n" + + " \"digest\": \"9f438cb9cd581fc025612d27f7c1a6669ff83a8bb0ed86c94fcf4c5440555697\",\n" + + " \"details\": {\n" + + " \"format\": \"gguf\",\n" + + " \"family\": \"llama\",\n" + + " \"families\": null,\n" + + " \"parameter_size\": \"13B\",\n" + + " \"quantization_level\": \"Q4_0\"\n" + + " }\n" + + "}"; + Model model = deserialize(serializedTestStringWithZuluTimezone, Model.class); + assertNotNull(model); + assertEquals("codellama:13b", model.getName()); + assertEquals("2023-11-04T14:56:49.277302595Z", model.getModifiedAt().toString()); + assertEquals(7365960935L, model.getSize()); + assertEquals("9f438cb9cd581fc025612d27f7c1a6669ff83a8bb0ed86c94fcf4c5440555697", model.getDigest()); + assertNotNull(model.getModelMeta()); + assertEquals("gguf", model.getModelMeta().getFormat()); + assertEquals("llama", model.getModelMeta().getFamily()); + assertNull(model.getModelMeta().getFamilies()); + assertEquals("13B", model.getModelMeta().getParameterSize()); + assertEquals("Q4_0", model.getModelMeta().getQuantizationLevel()); } } diff --git a/src/test/resources/roses.jpg b/src/test/resources/roses.jpg new file mode 100644 index 0000000..94aa6ca Binary files /dev/null and b/src/test/resources/roses.jpg differ