From b9b18271a1bf3025fc12f53c85fb5b71a9eaed30 Mon Sep 17 00:00:00 2001 From: amithkoujalgi Date: Mon, 24 Mar 2025 00:25:20 +0530 Subject: [PATCH] Support for structured output Added support for structured output --- Makefile | 8 +- README.md | 3 + docs/docs/apis-generate/generate.md | 128 +- .../java/io/github/ollama4j/OllamaAPI.java | 546 +++++--- .../models/response/OllamaResult.java | 80 +- .../response/OllamaStructuredResult.java | 77 ++ .../OllamaAPIIntegrationTest.java | 1177 +++++++++-------- 7 files changed, 1304 insertions(+), 715 deletions(-) create mode 100644 src/main/java/io/github/ollama4j/models/response/OllamaStructuredResult.java diff --git a/Makefile b/Makefile index 34972dd..b5af010 100644 --- a/Makefile +++ b/Makefile @@ -7,13 +7,19 @@ dev: pre-commit install --install-hooks build: + mvn -B clean install -Dgpg.skip=true + +full-build: mvn -B clean install unit-tests: mvn clean test -Punit-tests integration-tests: - mvn clean verify -Pintegration-tests + export USE_EXTERNAL_OLLAMA_HOST=false && mvn clean verify -Pintegration-tests + +integration-tests-local: + export USE_EXTERNAL_OLLAMA_HOST=true && export OLLAMA_HOST=http://localhost:11434 && mvn clean verify -Pintegration-tests -Dgpg.skip=true doxygen: doxygen Doxyfile diff --git a/README.md b/README.md index 33ac281..b405125 100644 --- a/README.md +++ b/README.md @@ -209,6 +209,9 @@ pip install pre-commit #### Setup dev environment +> **Note** +> If you're on Windows, install [Chocolatey Package Manager for Windows](https://chocolatey.org/install) and then install `make` by running `choco install make`. Just a little tip - run the command with administrator privileges if installation faiils. 
+ ```shell make dev ``` diff --git a/docs/docs/apis-generate/generate.md b/docs/docs/apis-generate/generate.md index 1cd6a47..ee428f8 100644 --- a/docs/docs/apis-generate/generate.md +++ b/docs/docs/apis-generate/generate.md @@ -13,7 +13,7 @@ with [extra parameters](https://github.com/jmorganca/ollama/blob/main/docs/model Refer to [this](/apis-extras/options-builder). -## Try asking a question about the model. +## Try asking a question about the model ```java import io.github.ollama4j.OllamaAPI; @@ -87,7 +87,7 @@ You will get a response similar to: > The capital of France is Paris. > Full response: The capital of France is Paris. -## Try asking a question from general topics. +## Try asking a question from general topics ```java import io.github.ollama4j.OllamaAPI; @@ -135,7 +135,7 @@ You'd then get a response from the model: > semi-finals. The tournament was > won by the England cricket team, who defeated New Zealand in the final. -## Try asking for a Database query for your data schema. 
+## Try asking for a Database query for your data schema ```java import io.github.ollama4j.OllamaAPI; @@ -161,6 +161,7 @@ public class Main { ``` + _Note: Here I've used a [sample prompt](https://github.com/ollama4j/ollama4j/blob/main/src/main/resources/sample-db-prompt-template.txt) containing a database schema from within this library for demonstration purposes._ @@ -172,4 +173,125 @@ SELECT customers.name FROM sales JOIN customers ON sales.customer_id = customers.customer_id GROUP BY customers.name; +``` + + +## Generate structured output + +### With response as a `Map` + +```java +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + +import io.github.ollama4j.OllamaAPI; +import io.github.ollama4j.utils.Utilities; +import io.github.ollama4j.models.chat.OllamaChatMessageRole; +import io.github.ollama4j.models.chat.OllamaChatRequest; +import io.github.ollama4j.models.chat.OllamaChatRequestBuilder; +import io.github.ollama4j.models.chat.OllamaChatResult; +import io.github.ollama4j.models.response.OllamaResult; +import io.github.ollama4j.types.OllamaModelType; + +public class StructuredOutput { + + public static void main(String[] args) throws Exception { + String host = "http://localhost:11434/"; + + OllamaAPI api = new OllamaAPI(host); + + String chatModel = "qwen2.5:0.5b"; + api.pullModel(chatModel); + + String prompt = "Ollama is 22 years old and is busy saving the world. 
Respond using JSON"; + Map format = new HashMap<>(); + format.put("type", "object"); + format.put("properties", new HashMap() { + { + put("age", new HashMap() { + { + put("type", "integer"); + } + }); + put("available", new HashMap() { + { + put("type", "boolean"); + } + }); + } + }); + format.put("required", Arrays.asList("age", "available")); + + OllamaResult result = api.generate(chatModel, prompt, format); + System.out.println(result); + } +} +``` + +### With response mapped to specified class type + +```java +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + +import io.github.ollama4j.OllamaAPI; +import io.github.ollama4j.utils.Utilities; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; +import io.github.ollama4j.models.chat.OllamaChatMessageRole; +import io.github.ollama4j.models.chat.OllamaChatRequest; +import io.github.ollama4j.models.chat.OllamaChatRequestBuilder; +import io.github.ollama4j.models.chat.OllamaChatResult; +import io.github.ollama4j.models.response.OllamaResult; +import io.github.ollama4j.types.OllamaModelType; + +public class StructuredOutput { + + public static void main(String[] args) throws Exception { + String host = Utilities.getFromConfig("host"); + + OllamaAPI api = new OllamaAPI(host); + + int age = 28; + boolean available = false; + + String prompt = "Batman is " + age + " years old and is " + (available ? "available" : "not available") + + " because he is busy saving Gotham City. 
Respond using JSON"; + + Map format = new HashMap<>(); + format.put("type", "object"); + format.put("properties", new HashMap() { + { + put("age", new HashMap() { + { + put("type", "integer"); + } + }); + put("available", new HashMap() { + { + put("type", "boolean"); + } + }); + } + }); + format.put("required", Arrays.asList("age", "available")); + + OllamaResult result = api.generate(CHAT_MODEL_QWEN_SMALL, prompt, format); + + Person person = result.as(Person.class); + System.out.println(person.getAge()); + System.out.println(person.getAvailable()); + } +} + +@Data +@AllArgsConstructor +@NoArgsConstructor +class Person { + private int age; + private boolean available; +} ``` \ No newline at end of file diff --git a/src/main/java/io/github/ollama4j/OllamaAPI.java b/src/main/java/io/github/ollama4j/OllamaAPI.java index b0d21c4..1c0abd5 100644 --- a/src/main/java/io/github/ollama4j/OllamaAPI.java +++ b/src/main/java/io/github/ollama4j/OllamaAPI.java @@ -51,7 +51,7 @@ import java.util.stream.Collectors; /** * The base Ollama API class. */ -@SuppressWarnings({"DuplicatedCode", "resource"}) +@SuppressWarnings({ "DuplicatedCode", "resource" }) public class OllamaAPI { private static final Logger logger = LoggerFactory.getLogger(OllamaAPI.class); @@ -74,10 +74,17 @@ public class OllamaAPI { private Auth auth; + private int numberOfRetriesForModelPull = 0; + + public void setNumberOfRetriesForModelPull(int numberOfRetriesForModelPull) { + this.numberOfRetriesForModelPull = numberOfRetriesForModelPull; + } + private final ToolRegistry toolRegistry = new ToolRegistry(); /** - * Instantiates the Ollama API with default Ollama host: http://localhost:11434 + * Instantiates the Ollama API with default Ollama host: + * http://localhost:11434 **/ public OllamaAPI() { this.host = "http://localhost:11434"; @@ -100,7 +107,8 @@ public class OllamaAPI { } /** - * Set basic authentication for accessing Ollama server that's behind a reverse-proxy/gateway. 
+ * Set basic authentication for accessing Ollama server that's behind a + * reverse-proxy/gateway. * * @param username the username * @param password the password @@ -110,7 +118,8 @@ public class OllamaAPI { } /** - * Set Bearer authentication for accessing Ollama server that's behind a reverse-proxy/gateway. + * Set Bearer authentication for accessing Ollama server that's behind a + * reverse-proxy/gateway. * * @param bearerToken the Bearer authentication token to provide */ @@ -128,7 +137,8 @@ public class OllamaAPI { HttpClient httpClient = HttpClient.newHttpClient(); HttpRequest httpRequest = null; try { - httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build(); + httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") + .header("Content-type", "application/json").GET().build(); } catch (URISyntaxException e) { throw new RuntimeException(e); } @@ -145,7 +155,8 @@ public class OllamaAPI { } /** - * Provides a list of running models and details about each model currently loaded into memory. + * Provides a list of running models and details about each model currently + * loaded into memory. 
* * @return ModelsProcessResponse containing details about the running models * @throws IOException if an I/O error occurs during the HTTP request @@ -157,7 +168,8 @@ public class OllamaAPI { HttpClient httpClient = HttpClient.newHttpClient(); HttpRequest httpRequest = null; try { - httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build(); + httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") + .header("Content-type", "application/json").GET().build(); } catch (URISyntaxException e) { throw new RuntimeException(e); } @@ -184,7 +196,8 @@ public class OllamaAPI { public List listModels() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { String url = this.host + "/api/tags"; HttpClient httpClient = HttpClient.newHttpClient(); - HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build(); + HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") + .header("Content-type", "application/json").GET().build(); HttpResponse response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); String responseString = response.body(); @@ -196,19 +209,28 @@ public class OllamaAPI { } /** - * Retrieves a list of models from the Ollama library. This method fetches the available models directly from Ollama - * library page, including model details such as the name, pull count, popular tags, tag count, and the time when model was updated. + * Retrieves a list of models from the Ollama library. This method fetches the + * available models directly from Ollama + * library page, including model details such as the name, pull count, popular + * tags, tag count, and the time when model was updated. 
* - * @return A list of {@link LibraryModel} objects representing the models available in the Ollama library. - * @throws OllamaBaseException If the HTTP request fails or the response is not successful (non-200 status code). - * @throws IOException If an I/O error occurs during the HTTP request or response processing. - * @throws InterruptedException If the thread executing the request is interrupted. - * @throws URISyntaxException If there is an error creating the URI for the HTTP request. + * @return A list of {@link LibraryModel} objects representing the models + * available in the Ollama library. + * @throws OllamaBaseException If the HTTP request fails or the response is not + * successful (non-200 status code). + * @throws IOException If an I/O error occurs during the HTTP request + * or response processing. + * @throws InterruptedException If the thread executing the request is + * interrupted. + * @throws URISyntaxException If there is an error creating the URI for the + * HTTP request. 
*/ - public List listModelsFromLibrary() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { + public List listModelsFromLibrary() + throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { String url = "https://ollama.com/library"; HttpClient httpClient = HttpClient.newHttpClient(); - HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build(); + HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") + .header("Content-type", "application/json").GET().build(); HttpResponse response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); String responseString = response.body(); @@ -223,7 +245,8 @@ public class OllamaAPI { Elements pullCounts = e.select("div:nth-of-type(2) > p > span:first-of-type > span:first-of-type"); Elements popularTags = e.select("div > div > span"); Elements totalTags = e.select("div:nth-of-type(2) > p > span:nth-of-type(2) > span:first-of-type"); - Elements lastUpdatedTime = e.select("div:nth-of-type(2) > p > span:nth-of-type(3) > span:nth-of-type(2)"); + Elements lastUpdatedTime = e + .select("div:nth-of-type(2) > p > span:nth-of-type(3) > span:nth-of-type(2)"); if (names.first() == null || names.isEmpty()) { // if name cannot be extracted, skip. 
@@ -231,9 +254,12 @@ public class OllamaAPI { } Optional.ofNullable(names.first()).map(Element::text).ifPresent(model::setName); model.setDescription(Optional.ofNullable(desc.first()).map(Element::text).orElse("")); - model.setPopularTags(Optional.of(popularTags).map(tags -> tags.stream().map(Element::text).collect(Collectors.toList())).orElse(new ArrayList<>())); + model.setPopularTags(Optional.of(popularTags) + .map(tags -> tags.stream().map(Element::text).collect(Collectors.toList())) + .orElse(new ArrayList<>())); model.setPullCount(Optional.ofNullable(pullCounts.first()).map(Element::text).orElse("")); - model.setTotalTags(Optional.ofNullable(totalTags.first()).map(Element::text).map(Integer::parseInt).orElse(0)); + model.setTotalTags( + Optional.ofNullable(totalTags.first()).map(Element::text).map(Integer::parseInt).orElse(0)); model.setLastUpdated(Optional.ofNullable(lastUpdatedTime.first()).map(Element::text).orElse("")); models.add(model); @@ -246,22 +272,32 @@ public class OllamaAPI { /** * Fetches the tags associated with a specific model from Ollama library. - * This method fetches the available model tags directly from Ollama library model page, including model tag name, size and time when model was last updated + * This method fetches the available model tags directly from Ollama library + * model page, including model tag name, size and time when model was last + * updated * into a list of {@link LibraryModelTag} objects. * - * @param libraryModel the {@link LibraryModel} object which contains the name of the library model + * @param libraryModel the {@link LibraryModel} object which contains the name + * of the library model * for which the tags need to be fetched. - * @return a list of {@link LibraryModelTag} objects containing the extracted tags and their associated metadata. 
- * @throws OllamaBaseException if the HTTP response status code indicates an error (i.e., not 200 OK), - * or if there is any other issue during the request or response processing. - * @throws IOException if an input/output exception occurs during the HTTP request or response handling. - * @throws InterruptedException if the thread is interrupted while waiting for the HTTP response. + * @return a list of {@link LibraryModelTag} objects containing the extracted + * tags and their associated metadata. + * @throws OllamaBaseException if the HTTP response status code indicates an + * error (i.e., not 200 OK), + * or if there is any other issue during the + * request or response processing. + * @throws IOException if an input/output exception occurs during the + * HTTP request or response handling. + * @throws InterruptedException if the thread is interrupted while waiting for + * the HTTP response. * @throws URISyntaxException if the URI format is incorrect or invalid. */ - public LibraryModelDetail getLibraryModelDetails(LibraryModel libraryModel) throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { + public LibraryModelDetail getLibraryModelDetails(LibraryModel libraryModel) + throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { String url = String.format("https://ollama.com/library/%s/tags", libraryModel.getName()); HttpClient httpClient = HttpClient.newHttpClient(); - HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build(); + HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") + .header("Content-type", "application/json").GET().build(); HttpResponse response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); String responseString = response.body(); @@ -269,7 +305,8 @@ public class OllamaAPI { 
List libraryModelTags = new ArrayList<>(); if (statusCode == 200) { Document doc = Jsoup.parse(responseString); - Elements tagSections = doc.select("html > body > main > div > section > div > div > div:nth-child(n+2) > div"); + Elements tagSections = doc + .select("html > body > main > div > section > div > div > div:nth-child(n+2) > div"); for (Element e : tagSections) { Elements tags = e.select("div > a > div"); Elements tagsMetas = e.select("div > span"); @@ -282,8 +319,11 @@ public class OllamaAPI { } libraryModelTag.setName(libraryModel.getName()); Optional.ofNullable(tags.first()).map(Element::text).ifPresent(libraryModelTag::setTag); - libraryModelTag.setSize(Optional.ofNullable(tagsMetas.first()).map(element -> element.text().split("•")).filter(parts -> parts.length > 1).map(parts -> parts[1].trim()).orElse("")); - libraryModelTag.setLastUpdated(Optional.ofNullable(tagsMetas.first()).map(element -> element.text().split("•")).filter(parts -> parts.length > 1).map(parts -> parts[2].trim()).orElse("")); + libraryModelTag.setSize(Optional.ofNullable(tagsMetas.first()).map(element -> element.text().split("•")) + .filter(parts -> parts.length > 1).map(parts -> parts[1].trim()).orElse("")); + libraryModelTag + .setLastUpdated(Optional.ofNullable(tagsMetas.first()).map(element -> element.text().split("•")) + .filter(parts -> parts.length > 1).map(parts -> parts[2].trim()).orElse("")); libraryModelTags.add(libraryModelTag); } LibraryModelDetail libraryModelDetail = new LibraryModelDetail(); @@ -298,24 +338,35 @@ public class OllamaAPI { /** * Finds a specific model using model name and tag from Ollama library. *

- * This method retrieves the model from the Ollama library by its name, then fetches its tags. - * It searches through the tags of the model to find one that matches the specified tag name. - * If the model or the tag is not found, it throws a {@link NoSuchElementException}. + * This method retrieves the model from the Ollama library by its name, then + * fetches its tags. + * It searches through the tags of the model to find one that matches the + * specified tag name. + * If the model or the tag is not found, it throws a + * {@link NoSuchElementException}. * * @param modelName The name of the model to search for in the library. * @param tag The tag name to search for within the specified model. - * @return The {@link LibraryModelTag} associated with the specified model and tag. - * @throws OllamaBaseException If there is a problem with the Ollama library operations. + * @return The {@link LibraryModelTag} associated with the specified model and + * tag. + * @throws OllamaBaseException If there is a problem with the Ollama library + * operations. * @throws IOException If an I/O error occurs during the operation. * @throws URISyntaxException If there is an error with the URI syntax. * @throws InterruptedException If the operation is interrupted. * @throws NoSuchElementException If the model or the tag is not found. 
*/ - public LibraryModelTag findModelTagFromLibrary(String modelName, String tag) throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + public LibraryModelTag findModelTagFromLibrary(String modelName, String tag) + throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { List libraryModels = this.listModelsFromLibrary(); - LibraryModel libraryModel = libraryModels.stream().filter(model -> model.getName().equals(modelName)).findFirst().orElseThrow(() -> new NoSuchElementException(String.format("Model by name '%s' not found", modelName))); + LibraryModel libraryModel = libraryModels.stream().filter(model -> model.getName().equals(modelName)) + .findFirst().orElseThrow( + () -> new NoSuchElementException(String.format("Model by name '%s' not found", modelName))); LibraryModelDetail libraryModelDetail = this.getLibraryModelDetails(libraryModel); - LibraryModelTag libraryModelTag = libraryModelDetail.getTags().stream().filter(tagName -> tagName.getTag().equals(tag)).findFirst().orElseThrow(() -> new NoSuchElementException(String.format("Tag '%s' for model '%s' not found", tag, modelName))); + LibraryModelTag libraryModelTag = libraryModelDetail.getTags().stream() + .filter(tagName -> tagName.getTag().equals(tag)).findFirst() + .orElseThrow(() -> new NoSuchElementException( + String.format("Tag '%s' for model '%s' not found", tag, modelName))); return libraryModelTag; } @@ -329,7 +380,28 @@ public class OllamaAPI { * @throws InterruptedException if the operation is interrupted * @throws URISyntaxException if the URI for the request is malformed */ - public void pullModel(String modelName) throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + public void pullModel(String modelName) + throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + if (numberOfRetriesForModelPull == 0) { + this.doPullModel(modelName); + } else { + int numberOfRetries = 0; + while 
(numberOfRetries < numberOfRetriesForModelPull) { + try { + this.doPullModel(modelName); + return; + } catch (OllamaBaseException e) { + logger.error("Failed to pull model " + modelName + ", retrying..."); + numberOfRetries++; + } + } + throw new OllamaBaseException( + "Failed to pull model " + modelName + " after " + numberOfRetriesForModelPull + " retries"); + } + } + + private void doPullModel(String modelName) + throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { String url = this.host + "/api/pull"; String jsonData = new ModelRequest(modelName).toString(); HttpRequest request = getRequestBuilderDefault(new URI(url)) @@ -343,7 +415,8 @@ public class OllamaAPI { InputStream responseBodyStream = response.body(); String responseString = ""; boolean success = false; // Flag to check the pull success. - try (BufferedReader reader = new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { + try (BufferedReader reader = new BufferedReader( + new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { String line; while ((line = reader.readLine()) != null) { ModelPullResponse modelPullResponse = Utils.getObjectMapper().readValue(line, ModelPullResponse.class); @@ -369,11 +442,11 @@ public class OllamaAPI { } } - public String getVersion() throws URISyntaxException, IOException, InterruptedException, OllamaBaseException { String url = this.host + "/api/version"; HttpClient httpClient = HttpClient.newHttpClient(); - HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build(); + HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") + .header("Content-type", "application/json").GET().build(); HttpResponse response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); String responseString = 
response.body(); @@ -386,17 +459,20 @@ public class OllamaAPI { /** * Pulls a model using the specified Ollama library model tag. - * The model is identified by a name and a tag, which are combined into a single identifier + * The model is identified by a name and a tag, which are combined into a single + * identifier * in the format "name:tag" to pull the corresponding model. * - * @param libraryModelTag the {@link LibraryModelTag} object containing the name and tag + * @param libraryModelTag the {@link LibraryModelTag} object containing the name + * and tag * of the model to be pulled. * @throws OllamaBaseException if the response indicates an error status * @throws IOException if an I/O error occurs during the HTTP request * @throws InterruptedException if the operation is interrupted * @throws URISyntaxException if the URI for the request is malformed */ - public void pullModel(LibraryModelTag libraryModelTag) throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + public void pullModel(LibraryModelTag libraryModelTag) + throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { String tagToPull = String.format("%s:%s", libraryModelTag.getName(), libraryModelTag.getTag()); pullModel(tagToPull); } @@ -411,10 +487,12 @@ public class OllamaAPI { * @throws InterruptedException if the operation is interrupted * @throws URISyntaxException if the URI for the request is malformed */ - public ModelDetail getModelDetails(String modelName) throws IOException, OllamaBaseException, InterruptedException, URISyntaxException { + public ModelDetail getModelDetails(String modelName) + throws IOException, OllamaBaseException, InterruptedException, URISyntaxException { String url = this.host + "/api/show"; String jsonData = new ModelRequest(modelName).toString(); - HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", 
"application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build(); + HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") + .header("Content-type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build(); HttpClient client = HttpClient.newHttpClient(); HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); @@ -427,8 +505,10 @@ public class OllamaAPI { } /** - * Create a custom model from a model file. Read more about custom model file creation here. + * Create a custom model from a model file. Read more about custom model file + * creation here. * * @param modelName the name of the custom model to be created. * @param modelFilePath the path to model file that exists on the Ollama server. @@ -438,10 +518,13 @@ public class OllamaAPI { * @throws URISyntaxException if the URI for the request is malformed */ @Deprecated - public void createModelWithFilePath(String modelName, String modelFilePath) throws IOException, InterruptedException, OllamaBaseException, URISyntaxException { + public void createModelWithFilePath(String modelName, String modelFilePath) + throws IOException, InterruptedException, OllamaBaseException, URISyntaxException { String url = this.host + "/api/create"; String jsonData = new CustomModelFilePathRequest(modelName, modelFilePath).toString(); - HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-Type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build(); + HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") + .header("Content-Type", "application/json") + .POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build(); HttpClient client = HttpClient.newHttpClient(); HttpResponse response = client.send(request, 
HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); @@ -449,7 +532,8 @@ public class OllamaAPI { if (statusCode != 200) { throw new OllamaBaseException(statusCode + " - " + responseString); } - // FIXME: Ollama API returns HTTP status code 200 for model creation failure cases. Correct this + // FIXME: Ollama API returns HTTP status code 200 for model creation failure + // cases. Correct this // if the issue is fixed in the Ollama API server. if (responseString.contains("error")) { throw new OllamaBaseException(responseString); @@ -460,21 +544,27 @@ public class OllamaAPI { } /** - * Create a custom model from a model file. Read more about custom model file creation here. + * Create a custom model from a model file. Read more about custom model file + * creation here. * * @param modelName the name of the custom model to be created. - * @param modelFileContents the path to model file that exists on the Ollama server. + * @param modelFileContents the path to model file that exists on the Ollama + * server. 
* @throws OllamaBaseException if the response indicates an error status * @throws IOException if an I/O error occurs during the HTTP request * @throws InterruptedException if the operation is interrupted * @throws URISyntaxException if the URI for the request is malformed */ @Deprecated - public void createModelWithModelFileContents(String modelName, String modelFileContents) throws IOException, InterruptedException, OllamaBaseException, URISyntaxException { + public void createModelWithModelFileContents(String modelName, String modelFileContents) + throws IOException, InterruptedException, OllamaBaseException, URISyntaxException { String url = this.host + "/api/create"; String jsonData = new CustomModelFileContentsRequest(modelName, modelFileContents).toString(); - HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-Type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build(); + HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") + .header("Content-Type", "application/json") + .POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build(); HttpClient client = HttpClient.newHttpClient(); HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); @@ -492,7 +582,8 @@ public class OllamaAPI { /** * Create a custom model. Read more about custom model creation here. + * href= + * "https://github.com/ollama/ollama/blob/main/docs/api.md#create-a-model">here. 
* * @param customModelRequest custom model spec * @throws OllamaBaseException if the response indicates an error status @@ -500,10 +591,13 @@ public class OllamaAPI { * @throws InterruptedException if the operation is interrupted * @throws URISyntaxException if the URI for the request is malformed */ - public void createModel(CustomModelRequest customModelRequest) throws IOException, InterruptedException, OllamaBaseException, URISyntaxException { + public void createModel(CustomModelRequest customModelRequest) + throws IOException, InterruptedException, OllamaBaseException, URISyntaxException { String url = this.host + "/api/create"; String jsonData = customModelRequest.toString(); - HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-Type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build(); + HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") + .header("Content-Type", "application/json") + .POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build(); HttpClient client = HttpClient.newHttpClient(); HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); @@ -523,16 +617,20 @@ public class OllamaAPI { * Delete a model from Ollama server. * * @param modelName the name of the model to be deleted. - * @param ignoreIfNotPresent ignore errors if the specified model is not present on Ollama server. + * @param ignoreIfNotPresent ignore errors if the specified model is not present + * on Ollama server. 
* @throws OllamaBaseException if the response indicates an error status * @throws IOException if an I/O error occurs during the HTTP request * @throws InterruptedException if the operation is interrupted * @throws URISyntaxException if the URI for the request is malformed */ - public void deleteModel(String modelName, boolean ignoreIfNotPresent) throws IOException, InterruptedException, OllamaBaseException, URISyntaxException { + public void deleteModel(String modelName, boolean ignoreIfNotPresent) + throws IOException, InterruptedException, OllamaBaseException, URISyntaxException { String url = this.host + "/api/delete"; String jsonData = new ModelRequest(modelName).toString(); - HttpRequest request = getRequestBuilderDefault(new URI(url)).method("DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).header("Accept", "application/json").header("Content-type", "application/json").build(); + HttpRequest request = getRequestBuilderDefault(new URI(url)) + .method("DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)) + .header("Accept", "application/json").header("Content-type", "application/json").build(); HttpClient client = HttpClient.newHttpClient(); HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); @@ -557,7 +655,8 @@ public class OllamaAPI { * @deprecated Use {@link #embed(String, List)} instead. */ @Deprecated - public List generateEmbeddings(String model, String prompt) throws IOException, InterruptedException, OllamaBaseException { + public List generateEmbeddings(String model, String prompt) + throws IOException, InterruptedException, OllamaBaseException { return generateEmbeddings(new OllamaEmbeddingsRequestModel(model, prompt)); } @@ -572,17 +671,20 @@ public class OllamaAPI { * @deprecated Use {@link #embed(OllamaEmbedRequestModel)} instead. 
*/ @Deprecated - public List generateEmbeddings(OllamaEmbeddingsRequestModel modelRequest) throws IOException, InterruptedException, OllamaBaseException { + public List generateEmbeddings(OllamaEmbeddingsRequestModel modelRequest) + throws IOException, InterruptedException, OllamaBaseException { URI uri = URI.create(this.host + "/api/embeddings"); String jsonData = modelRequest.toString(); HttpClient httpClient = HttpClient.newHttpClient(); - HttpRequest.Builder requestBuilder = getRequestBuilderDefault(uri).header("Accept", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)); + HttpRequest.Builder requestBuilder = getRequestBuilderDefault(uri).header("Accept", "application/json") + .POST(HttpRequest.BodyPublishers.ofString(jsonData)); HttpRequest request = requestBuilder.build(); HttpResponse response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); String responseBody = response.body(); if (statusCode == 200) { - OllamaEmbeddingResponseModel embeddingResponse = Utils.getObjectMapper().readValue(responseBody, OllamaEmbeddingResponseModel.class); + OllamaEmbeddingResponseModel embeddingResponse = Utils.getObjectMapper().readValue(responseBody, + OllamaEmbeddingResponseModel.class); return embeddingResponse.getEmbedding(); } else { throw new OllamaBaseException(statusCode + " - " + responseBody); @@ -599,7 +701,8 @@ public class OllamaAPI { * @throws IOException if an I/O error occurs during the HTTP request * @throws InterruptedException if the operation is interrupted */ - public OllamaEmbedResponseModel embed(String model, List inputs) throws IOException, InterruptedException, OllamaBaseException { + public OllamaEmbedResponseModel embed(String model, List inputs) + throws IOException, InterruptedException, OllamaBaseException { return embed(new OllamaEmbedRequestModel(model, inputs)); } @@ -612,12 +715,14 @@ public class OllamaAPI { * @throws IOException if an I/O error occurs during 
the HTTP request * @throws InterruptedException if the operation is interrupted */ - public OllamaEmbedResponseModel embed(OllamaEmbedRequestModel modelRequest) throws IOException, InterruptedException, OllamaBaseException { + public OllamaEmbedResponseModel embed(OllamaEmbedRequestModel modelRequest) + throws IOException, InterruptedException, OllamaBaseException { URI uri = URI.create(this.host + "/api/embed"); String jsonData = Utils.getObjectMapper().writeValueAsString(modelRequest); HttpClient httpClient = HttpClient.newHttpClient(); - HttpRequest request = HttpRequest.newBuilder(uri).header("Accept", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build(); + HttpRequest request = HttpRequest.newBuilder(uri).header("Accept", "application/json") + .POST(HttpRequest.BodyPublishers.ofString(jsonData)).build(); HttpResponse response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); @@ -631,21 +736,26 @@ public class OllamaAPI { } /** - * Generate response for a question to a model running on Ollama server. This is a sync/blocking + * Generate response for a question to a model running on Ollama server. This is + * a sync/blocking * call. * * @param model the ollama model to ask the question to * @param prompt the prompt/question text * @param options the Options object - More + * href= + * "https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values">More * details on the options - * @param streamHandler optional callback consumer that will be applied every time a streamed response is received. If not set, the stream parameter of the request is set to false. + * @param streamHandler optional callback consumer that will be applied every + * time a streamed response is received. If not set, the + * stream parameter of the request is set to false. 
* @return OllamaResult that includes response text and time taken for response * @throws OllamaBaseException if the response indicates an error status * @throws IOException if an I/O error occurs during the HTTP request * @throws InterruptedException if the operation is interrupted */ - public OllamaResult generate(String model, String prompt, boolean raw, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { + public OllamaResult generate(String model, String prompt, boolean raw, Options options, + OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt); ollamaRequestModel.setRaw(raw); ollamaRequestModel.setOptions(options.getOptionsMap()); @@ -653,36 +763,96 @@ public class OllamaAPI { } /** - * Generates response using the specified AI model and prompt (in blocking mode). + * Generates structured output from the specified AI model and prompt. + * + * @param model The name or identifier of the AI model to use for generating + * the response. + * @param prompt The input text or prompt to provide to the AI model. + * @param format A map containing the format specification for the structured + * output. + * @return An instance of {@link OllamaResult} containing the structured + * response. + * @throws OllamaBaseException if the response indicates an error status. + * @throws IOException if an I/O error occurs during the HTTP request. + * @throws InterruptedException if the operation is interrupted. 
+ */ + public OllamaResult generate(String model, String prompt, Map format) + throws OllamaBaseException, IOException, InterruptedException { + URI uri = URI.create(this.host + "/api/generate"); + + Map requestBody = new HashMap<>(); + requestBody.put("model", model); + requestBody.put("prompt", prompt); + requestBody.put("stream", false); + requestBody.put("format", format); + + String jsonData = Utils.getObjectMapper().writeValueAsString(requestBody); + HttpClient httpClient = HttpClient.newHttpClient(); + + HttpRequest request = HttpRequest.newBuilder(uri) + .header("Content-Type", "application/json") + .POST(HttpRequest.BodyPublishers.ofString(jsonData)) + .build(); + + HttpResponse response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); + int statusCode = response.statusCode(); + String responseBody = response.body(); + + if (statusCode == 200) { + OllamaStructuredResult structuredResult = Utils.getObjectMapper().readValue(responseBody, + OllamaStructuredResult.class); + OllamaResult ollamaResult = new OllamaResult(structuredResult.getResponse(), + structuredResult.getResponseTime(), statusCode); + return ollamaResult; + } else { + throw new OllamaBaseException(statusCode + " - " + responseBody); + } + } + + /** + * Generates response using the specified AI model and prompt (in blocking + * mode). *

* Uses {@link #generate(String, String, boolean, Options, OllamaStreamHandler)} * - * @param model The name or identifier of the AI model to use for generating the response. + * @param model The name or identifier of the AI model to use for generating + * the response. * @param prompt The input text or prompt to provide to the AI model. - * @param raw In some cases, you may wish to bypass the templating system and provide a full prompt. In this case, you can use the raw parameter to disable templating. Also note that raw mode will not return a context. - * @param options Additional options or configurations to use when generating the response. + * @param raw In some cases, you may wish to bypass the templating system + * and provide a full prompt. In this case, you can use the raw + * parameter to disable templating. Also note that raw mode will + * not return a context. + * @param options Additional options or configurations to use when generating + * the response. * @return {@link OllamaResult} * @throws OllamaBaseException if the response indicates an error status * @throws IOException if an I/O error occurs during the HTTP request * @throws InterruptedException if the operation is interrupted */ - public OllamaResult generate(String model, String prompt, boolean raw, Options options) throws OllamaBaseException, IOException, InterruptedException { + public OllamaResult generate(String model, String prompt, boolean raw, Options options) + throws OllamaBaseException, IOException, InterruptedException { return generate(model, prompt, raw, options, null); } /** - * Generates response using the specified AI model and prompt (in blocking mode), and then invokes a set of tools + * Generates response using the specified AI model and prompt (in blocking + * mode), and then invokes a set of tools * on the generated response. * - * @param model The name or identifier of the AI model to use for generating the response. 
+ * @param model The name or identifier of the AI model to use for generating + * the response. * @param prompt The input text or prompt to provide to the AI model. - * @param options Additional options or configurations to use when generating the response. - * @return {@link OllamaToolsResult} An OllamaToolsResult object containing the response from the AI model and the results of invoking the tools on that output. + * @param options Additional options or configurations to use when generating + * the response. + * @return {@link OllamaToolsResult} An OllamaToolsResult object containing the + * response from the AI model and the results of invoking the tools on + * that output. * @throws OllamaBaseException if the response indicates an error status * @throws IOException if an I/O error occurs during the HTTP request * @throws InterruptedException if the operation is interrupted */ - public OllamaToolsResult generateWithTools(String model, String prompt, Options options) throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException { + public OllamaToolsResult generateWithTools(String model, String prompt, Options options) + throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException { boolean raw = true; OllamaToolsResult toolResult = new OllamaToolsResult(); Map toolResults = new HashMap<>(); @@ -717,8 +887,7 @@ public class OllamaAPI { } toolFunctionCallSpecs = objectMapper.readValue( toolsResponse, - objectMapper.getTypeFactory().constructCollectionType(List.class, ToolFunctionCallSpec.class) - ); + objectMapper.getTypeFactory().constructCollectionType(List.class, ToolFunctionCallSpec.class)); } for (ToolFunctionCallSpec toolFunctionCallSpec : toolFunctionCallSpecs) { toolResults.put(toolFunctionCallSpec, invokeTool(toolFunctionCallSpec)); @@ -728,8 +897,10 @@ public class OllamaAPI { } /** - * Generate response for a question to a model running on Ollama server and get a callback handle - * that can be used to 
check for status and get the response from the model later. This would be + * Generate response for a question to a model running on Ollama server and get + * a callback handle + * that can be used to check for status and get the response from the model + * later. This would be * an async/non-blocking call. * * @param model the ollama model to ask the question to @@ -740,28 +911,34 @@ public class OllamaAPI { OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt); ollamaRequestModel.setRaw(raw); URI uri = URI.create(this.host + "/api/generate"); - OllamaAsyncResultStreamer ollamaAsyncResultStreamer = new OllamaAsyncResultStreamer(getRequestBuilderDefault(uri), ollamaRequestModel, requestTimeoutSeconds); + OllamaAsyncResultStreamer ollamaAsyncResultStreamer = new OllamaAsyncResultStreamer( + getRequestBuilderDefault(uri), ollamaRequestModel, requestTimeoutSeconds); ollamaAsyncResultStreamer.start(); return ollamaAsyncResultStreamer; } /** - * With one or more image files, ask a question to a model running on Ollama server. This is a + * With one or more image files, ask a question to a model running on Ollama + * server. This is a * sync/blocking call. * * @param model the ollama model to ask the question to * @param prompt the prompt/question text * @param imageFiles the list of image files to use for the question * @param options the Options object - More + * href= + * "https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values">More * details on the options - * @param streamHandler optional callback consumer that will be applied every time a streamed response is received. If not set, the stream parameter of the request is set to false. + * @param streamHandler optional callback consumer that will be applied every + * time a streamed response is received. If not set, the + * stream parameter of the request is set to false. 
* @return OllamaResult that includes response text and time taken for response * @throws OllamaBaseException if the response indicates an error status * @throws IOException if an I/O error occurs during the HTTP request * @throws InterruptedException if the operation is interrupted */ - public OllamaResult generateWithImageFiles(String model, String prompt, List imageFiles, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { + public OllamaResult generateWithImageFiles(String model, String prompt, List imageFiles, Options options, + OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { List images = new ArrayList<>(); for (File imageFile : imageFiles) { images.add(encodeFileToBase64(imageFile)); @@ -774,34 +951,42 @@ public class OllamaAPI { /** * Convenience method to call Ollama API without streaming responses. *

- * Uses {@link #generateWithImageFiles(String, String, List, Options, OllamaStreamHandler)} + * Uses + * {@link #generateWithImageFiles(String, String, List, Options, OllamaStreamHandler)} * * @throws OllamaBaseException if the response indicates an error status * @throws IOException if an I/O error occurs during the HTTP request * @throws InterruptedException if the operation is interrupted */ - public OllamaResult generateWithImageFiles(String model, String prompt, List imageFiles, Options options) throws OllamaBaseException, IOException, InterruptedException { + public OllamaResult generateWithImageFiles(String model, String prompt, List imageFiles, Options options) + throws OllamaBaseException, IOException, InterruptedException { return generateWithImageFiles(model, prompt, imageFiles, options, null); } /** - * With one or more image URLs, ask a question to a model running on Ollama server. This is a + * With one or more image URLs, ask a question to a model running on Ollama + * server. This is a * sync/blocking call. * * @param model the ollama model to ask the question to * @param prompt the prompt/question text * @param imageURLs the list of image URLs to use for the question * @param options the Options object - More + * href= + * "https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values">More * details on the options - * @param streamHandler optional callback consumer that will be applied every time a streamed response is received. If not set, the stream parameter of the request is set to false. + * @param streamHandler optional callback consumer that will be applied every + * time a streamed response is received. If not set, the + * stream parameter of the request is set to false. 
* @return OllamaResult that includes response text and time taken for response * @throws OllamaBaseException if the response indicates an error status * @throws IOException if an I/O error occurs during the HTTP request * @throws InterruptedException if the operation is interrupted * @throws URISyntaxException if the URI for the request is malformed */ - public OllamaResult generateWithImageURLs(String model, String prompt, List imageURLs, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { + public OllamaResult generateWithImageURLs(String model, String prompt, List imageURLs, Options options, + OllamaStreamHandler streamHandler) + throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { List images = new ArrayList<>(); for (String imageURL : imageURLs) { images.add(encodeByteArrayToBase64(Utils.loadImageBytesFromUrl(imageURL))); @@ -814,38 +999,45 @@ public class OllamaAPI { /** * Convenience method to call Ollama API without streaming responses. *

- * Uses {@link #generateWithImageURLs(String, String, List, Options, OllamaStreamHandler)} + * Uses + * {@link #generateWithImageURLs(String, String, List, Options, OllamaStreamHandler)} * * @throws OllamaBaseException if the response indicates an error status * @throws IOException if an I/O error occurs during the HTTP request * @throws InterruptedException if the operation is interrupted * @throws URISyntaxException if the URI for the request is malformed */ - public OllamaResult generateWithImageURLs(String model, String prompt, List imageURLs, Options options) throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { + public OllamaResult generateWithImageURLs(String model, String prompt, List imageURLs, Options options) + throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { return generateWithImageURLs(model, prompt, imageURLs, options, null); } /** - * Ask a question to a model based on a given message stack (i.e. a chat history). Creates a synchronous call to the api + * Ask a question to a model based on a given message stack (i.e. a chat + * history). Creates a synchronous call to the api * 'api/chat'. * * @param model the ollama model to ask the question to * @param messages chat history / message stack to send to the model - * @return {@link OllamaChatResult} containing the api response and the message history including the newly aqcuired assistant response. + * @return {@link OllamaChatResult} containing the api response and the message + * history including the newly acquired assistant response. 
* @throws OllamaBaseException any response code than 200 has been returned * @throws IOException in case the responseStream can not be read - * @throws InterruptedException in case the server is not reachable or network issues happen + * @throws InterruptedException in case the server is not reachable or network + * issues happen * @throws OllamaBaseException if the response indicates an error status * @throws IOException if an I/O error occurs during the HTTP request * @throws InterruptedException if the operation is interrupted */ - public OllamaChatResult chat(String model, List messages) throws OllamaBaseException, IOException, InterruptedException { + public OllamaChatResult chat(String model, List messages) + throws OllamaBaseException, IOException, InterruptedException { OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(model); return chat(builder.withMessages(messages).build()); } /** - * Ask a question to a model using an {@link OllamaChatRequest}. This can be constructed using an {@link OllamaChatRequestBuilder}. + * Ask a question to a model using an {@link OllamaChatRequest}. This can be + * constructed using an {@link OllamaChatRequestBuilder}. *

* Hint: the OllamaChatRequestModel#getStream() property is not implemented. * @@ -853,55 +1045,69 @@ public class OllamaAPI { * @return {@link OllamaChatResult} * @throws OllamaBaseException any response code than 200 has been returned * @throws IOException in case the responseStream can not be read - * @throws InterruptedException in case the server is not reachable or network issues happen + * @throws InterruptedException in case the server is not reachable or network + * issues happen * @throws OllamaBaseException if the response indicates an error status * @throws IOException if an I/O error occurs during the HTTP request * @throws InterruptedException if the operation is interrupted */ - public OllamaChatResult chat(OllamaChatRequest request) throws OllamaBaseException, IOException, InterruptedException { + public OllamaChatResult chat(OllamaChatRequest request) + throws OllamaBaseException, IOException, InterruptedException { return chat(request, null); } /** - * Ask a question to a model using an {@link OllamaChatRequest}. This can be constructed using an {@link OllamaChatRequestBuilder}. + * Ask a question to a model using an {@link OllamaChatRequest}. This can be + * constructed using an {@link OllamaChatRequestBuilder}. *

* Hint: the OllamaChatRequestModel#getStream() property is not implemented. * * @param request request object to be sent to the server - * @param streamHandler callback handler to handle the last message from stream (caution: all previous messages from stream will be concatenated) + * @param streamHandler callback handler to handle the last message from stream + * (caution: all previous messages from stream will be + * concatenated) * @return {@link OllamaChatResult} * @throws OllamaBaseException any response code than 200 has been returned * @throws IOException in case the responseStream can not be read - * @throws InterruptedException in case the server is not reachable or network issues happen + * @throws InterruptedException in case the server is not reachable or network + * issues happen * @throws OllamaBaseException if the response indicates an error status * @throws IOException if an I/O error occurs during the HTTP request * @throws InterruptedException if the operation is interrupted */ - public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { + public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler) + throws OllamaBaseException, IOException, InterruptedException { return chatStreaming(request, new OllamaChatStreamObserver(streamHandler)); } /** - * Ask a question to a model using an {@link OllamaChatRequest}. This can be constructed using an {@link OllamaChatRequestBuilder}. + * Ask a question to a model using an {@link OllamaChatRequest}. This can be + * constructed using an {@link OllamaChatRequestBuilder}. *

* Hint: the OllamaChatRequestModel#getStream() property is not implemented. * * @param request request object to be sent to the server - * @param tokenHandler callback handler to handle the last token from stream (caution: all previous messages from stream will be concatenated) + * @param tokenHandler callback handler to handle the last token from stream + * (caution: all previous messages from stream will be + * concatenated) * @return {@link OllamaChatResult} * @throws OllamaBaseException any response code than 200 has been returned * @throws IOException in case the responseStream can not be read - * @throws InterruptedException in case the server is not reachable or network issues happen + * @throws InterruptedException in case the server is not reachable or network + * issues happen * @throws OllamaBaseException if the response indicates an error status * @throws IOException if an I/O error occurs during the HTTP request * @throws InterruptedException if the operation is interrupted */ - public OllamaChatResult chatStreaming(OllamaChatRequest request, OllamaTokenHandler tokenHandler) throws OllamaBaseException, IOException, InterruptedException { - OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, auth, requestTimeoutSeconds, verbose); + public OllamaChatResult chatStreaming(OllamaChatRequest request, OllamaTokenHandler tokenHandler) + throws OllamaBaseException, IOException, InterruptedException { + OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, auth, requestTimeoutSeconds, + verbose); OllamaChatResult result; // add all registered tools to Request - request.setTools(toolRegistry.getRegisteredSpecs().stream().map(Tools.ToolSpecification::getToolPrompt).collect(Collectors.toList())); + request.setTools(toolRegistry.getRegisteredSpecs().stream().map(Tools.ToolSpecification::getToolPrompt) + .collect(Collectors.toList())); if (tokenHandler != null) { request.setStream(true); @@ -919,7 +1125,8 @@ public class 
OllamaAPI { ToolFunction toolFunction = toolRegistry.getToolFunction(toolName); Map arguments = toolCall.getFunction().getArguments(); Object res = toolFunction.apply(arguments); - request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL, "[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() + ") : " + res + "[/TOOL_RESULTS]")); + request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL, + "[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() + ") : " + res + "[/TOOL_RESULTS]")); } if (tokenHandler != null) { @@ -935,9 +1142,11 @@ public class OllamaAPI { } /** - * Registers a single tool in the tool registry using the provided tool specification. + * Registers a single tool in the tool registry using the provided tool + * specification. * - * @param toolSpecification the specification of the tool to register. It contains the + * @param toolSpecification the specification of the tool to register. It + * contains the * tool's function name and other relevant information. */ public void registerTool(Tools.ToolSpecification toolSpecification) { @@ -948,11 +1157,14 @@ public class OllamaAPI { } /** - * Registers multiple tools in the tool registry using a list of tool specifications. + * Registers multiple tools in the tool registry using a list of tool + * specifications. * Iterates over the list and adds each tool specification to the registry. * - * @param toolSpecifications a list of tool specifications to register. Each specification - * contains information about a tool, such as its function name. + * @param toolSpecifications a list of tool specifications to register. Each + * specification + * contains information about a tool, such as its + * function name. 
*/ public void registerTools(List toolSpecifications) { for (Tools.ToolSpecification toolSpecification : toolSpecifications) { @@ -961,12 +1173,16 @@ public class OllamaAPI { } /** - * Registers tools based on the annotations found on the methods of the caller's class and its providers. - * This method scans the caller's class for the {@link OllamaToolService} annotation and recursively registers + * Registers tools based on the annotations found on the methods of the caller's + * class and its providers. + * This method scans the caller's class for the {@link OllamaToolService} + * annotation and recursively registers * annotated tools from all the providers specified in the annotation. * - * @throws IllegalStateException if the caller's class is not annotated with {@link OllamaToolService}. - * @throws RuntimeException if any reflection-based instantiation or invocation fails. + * @throws IllegalStateException if the caller's class is not annotated with + * {@link OllamaToolService}. + * @throws RuntimeException if any reflection-based instantiation or + * invocation fails. */ public void registerAnnotatedTools() { try { @@ -986,19 +1202,24 @@ public class OllamaAPI { for (Class provider : providers) { registerAnnotatedTools(provider.getDeclaredConstructor().newInstance()); } - } catch (InstantiationException | NoSuchMethodException | IllegalAccessException | - InvocationTargetException e) { + } catch (InstantiationException | NoSuchMethodException | IllegalAccessException + | InvocationTargetException e) { throw new RuntimeException(e); } } /** - * Registers tools based on the annotations found on the methods of the provided object. - * This method scans the methods of the given object and registers tools using the {@link ToolSpec} annotation - * and associated {@link ToolProperty} annotations. It constructs tool specifications and stores them in a tool registry. + * Registers tools based on the annotations found on the methods of the provided + * object. 
+ * This method scans the methods of the given object and registers tools using + * the {@link ToolSpec} annotation + * and associated {@link ToolProperty} annotations. It constructs tool + * specifications and stores them in a tool registry. * - * @param object the object whose methods are to be inspected for annotated tools. - * @throws RuntimeException if any reflection-based instantiation or invocation fails. + * @param object the object whose methods are to be inspected for annotated + * tools. + * @throws RuntimeException if any reflection-based instantiation or invocation + * fails. */ public void registerAnnotatedTools(Object object) { Class objectClass = object.getClass(); @@ -1022,12 +1243,22 @@ public class OllamaAPI { } String propName = !toolPropertyAnn.name().isBlank() ? toolPropertyAnn.name() : parameter.getName(); methodParams.put(propName, propType); - propsBuilder.withProperty(propName, Tools.PromptFuncDefinition.Property.builder().type(propType).description(toolPropertyAnn.desc()).required(toolPropertyAnn.required()).build()); + propsBuilder.withProperty(propName, Tools.PromptFuncDefinition.Property.builder().type(propType) + .description(toolPropertyAnn.desc()).required(toolPropertyAnn.required()).build()); } final Map params = propsBuilder.build(); - List reqProps = params.entrySet().stream().filter(e -> e.getValue().isRequired()).map(Map.Entry::getKey).collect(Collectors.toList()); + List reqProps = params.entrySet().stream().filter(e -> e.getValue().isRequired()) + .map(Map.Entry::getKey).collect(Collectors.toList()); - Tools.ToolSpecification toolSpecification = 
Tools.ToolSpecification.builder().functionName(operationName).functionDescription(operationDesc).toolPrompt(Tools.PromptFuncDefinition.builder().type("function").function(Tools.PromptFuncDefinition.PromptFuncSpec.builder().name(operationName).description(operationDesc).parameters(Tools.PromptFuncDefinition.Parameters.builder().type("object").properties(params).required(reqProps).build()).build()).build()).build(); + Tools.ToolSpecification toolSpecification = Tools.ToolSpecification.builder().functionName(operationName) + .functionDescription(operationDesc) + .toolPrompt(Tools.PromptFuncDefinition.builder().type("function") + .function(Tools.PromptFuncDefinition.PromptFuncSpec.builder().name(operationName) + .description(operationDesc).parameters(Tools.PromptFuncDefinition.Parameters + .builder().type("object").properties(params).required(reqProps).build()) + .build()) + .build()) + .build(); ReflectionalToolFunction reflectionalToolFunction = new ReflectionalToolFunction(object, m, methodParams); toolSpecification.setToolFunction(reflectionalToolFunction); @@ -1060,13 +1291,13 @@ public class OllamaAPI { * * @param roleName the name of the role to retrieve * @return the OllamaChatMessageRole associated with the given name - * @throws RoleNotFoundException if the role with the specified name does not exist + * @throws RoleNotFoundException if the role with the specified name does not + * exist */ public OllamaChatMessageRole getRole(String roleName) throws RoleNotFoundException { return OllamaChatMessageRole.getRole(roleName); } - // technical private methods // /** @@ -1092,18 +1323,26 @@ public class OllamaAPI { /** * Generates a request for the Ollama API and returns the result. - * This method synchronously calls the Ollama API. If a stream handler is provided, - * the request will be streamed; otherwise, a regular synchronous request will be made. + * This method synchronously calls the Ollama API. 
If a stream handler is + * provided, + * the request will be streamed; otherwise, a regular synchronous request will + * be made. * - * @param ollamaRequestModel the request model containing necessary parameters for the Ollama API request. - * @param streamHandler the stream handler to process streaming responses, or null for non-streaming requests. + * @param ollamaRequestModel the request model containing necessary parameters + * for the Ollama API request. + * @param streamHandler the stream handler to process streaming responses, + * or null for non-streaming requests. * @return the result of the Ollama API request. - * @throws OllamaBaseException if the request fails due to an issue with the Ollama API. - * @throws IOException if an I/O error occurs during the request process. + * @throws OllamaBaseException if the request fails due to an issue with the + * Ollama API. + * @throws IOException if an I/O error occurs during the request + * process. * @throws InterruptedException if the thread is interrupted during the request. */ - private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequest ollamaRequestModel, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { - OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, auth, requestTimeoutSeconds, verbose); + private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequest ollamaRequestModel, + OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { + OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, auth, requestTimeoutSeconds, + verbose); OllamaResult result; if (streamHandler != null) { ollamaRequestModel.setStream(true); @@ -1114,7 +1353,6 @@ public class OllamaAPI { return result; } - /** * Get default request builder. 
* @@ -1122,7 +1360,8 @@ public class OllamaAPI { * @return HttpRequest.Builder */ private HttpRequest.Builder getRequestBuilderDefault(URI uri) { - HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(uri).header("Content-Type", "application/json").timeout(Duration.ofSeconds(requestTimeoutSeconds)); + HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(uri).header("Content-Type", "application/json") + .timeout(Duration.ofSeconds(requestTimeoutSeconds)); if (isBasicAuthCredentialsSet()) { requestBuilder.header("Authorization", auth.getAuthHeaderValue()); } @@ -1147,7 +1386,8 @@ public class OllamaAPI { logger.debug("Invoking function {} with arguments {}", methodName, arguments); } if (function == null) { - throw new ToolNotFoundException("No such tool: " + methodName + ". Please register the tool before invoking it."); + throw new ToolNotFoundException( + "No such tool: " + methodName + ". Please register the tool before invoking it."); } return function.apply(arguments); } catch (Exception e) { diff --git a/src/main/java/io/github/ollama4j/models/response/OllamaResult.java b/src/main/java/io/github/ollama4j/models/response/OllamaResult.java index beb01ec..4b538f9 100644 --- a/src/main/java/io/github/ollama4j/models/response/OllamaResult.java +++ b/src/main/java/io/github/ollama4j/models/response/OllamaResult.java @@ -1,19 +1,26 @@ package io.github.ollama4j.models.response; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; + +import lombok.Data; +import lombok.Getter; + import static io.github.ollama4j.utils.Utils.getObjectMapper; -import com.fasterxml.jackson.core.JsonProcessingException; -import lombok.Data; -import lombok.Getter; +import java.util.HashMap; +import java.util.Map; /** The type Ollama result. 
*/ @Getter @SuppressWarnings("unused") @Data +@JsonIgnoreProperties(ignoreUnknown = true) public class OllamaResult { /** * -- GETTER -- - * Get the completion/response text + * Get the completion/response text * * @return String completion/response text */ @@ -21,7 +28,7 @@ public class OllamaResult { /** * -- GETTER -- - * Get the response status code. + * Get the response status code. * * @return int - response status code */ @@ -29,7 +36,7 @@ public class OllamaResult { /** * -- GETTER -- - * Get the response time in milliseconds. + * Get the response time in milliseconds. * * @return long - response time in milliseconds */ @@ -44,9 +51,68 @@ public class OllamaResult { @Override public String toString() { try { - return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this); + Map responseMap = new HashMap<>(); + responseMap.put("response", this.response); + responseMap.put("httpStatusCode", this.httpStatusCode); + responseMap.put("responseTime", this.responseTime); + return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(responseMap); } catch (JsonProcessingException e) { throw new RuntimeException(e); } } + + /** + * Get the structured response if the response is a JSON object. 
+ * + * @return Map - structured response + * @throws IllegalArgumentException if the response is not a valid JSON object + */ + public Map getStructuredResponse() { + String responseStr = this.getResponse(); + if (responseStr == null || responseStr.trim().isEmpty()) { + throw new IllegalArgumentException("Response is empty or null"); + } + + try { + // Check if the response is a valid JSON + if ((!responseStr.trim().startsWith("{") && !responseStr.trim().startsWith("[")) || + (!responseStr.trim().endsWith("}") && !responseStr.trim().endsWith("]"))) { + throw new IllegalArgumentException("Response is not a valid JSON object"); + } + + Map response = getObjectMapper().readValue(responseStr, + new TypeReference>() { + }); + return response; + } catch (JsonProcessingException e) { + throw new IllegalArgumentException("Failed to parse response as JSON: " + e.getMessage(), e); + } + } + + /** + * Get the structured response mapped to a specific class type. + * + * @param The type of class to map the response to + * @param clazz The class to map the response to + * @return An instance of the specified class with the response data + * @throws IllegalArgumentException if the response is not a valid JSON or is empty + * @throws RuntimeException if there is an error mapping the response + */ + public T as(Class clazz) { + String responseStr = this.getResponse(); + if (responseStr == null || responseStr.trim().isEmpty()) { + throw new IllegalArgumentException("Response is empty or null"); + } + + try { + // Check if the response is a valid JSON + if ((!responseStr.trim().startsWith("{") && !responseStr.trim().startsWith("[")) || + (!responseStr.trim().endsWith("}") && !responseStr.trim().endsWith("]"))) { + throw new IllegalArgumentException("Response is not a valid JSON object"); + } + return getObjectMapper().readValue(responseStr, clazz); + } catch (JsonProcessingException e) { + throw new IllegalArgumentException("Failed to parse response as JSON: " + e.getMessage(), e); 
+ } + } } diff --git a/src/main/java/io/github/ollama4j/models/response/OllamaStructuredResult.java b/src/main/java/io/github/ollama4j/models/response/OllamaStructuredResult.java new file mode 100644 index 0000000..9ae3e71 --- /dev/null +++ b/src/main/java/io/github/ollama4j/models/response/OllamaStructuredResult.java @@ -0,0 +1,77 @@ +package io.github.ollama4j.models.response; + +import static io.github.ollama4j.utils.Utils.getObjectMapper; + +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; + +import lombok.Data; +import lombok.Getter; +import lombok.NoArgsConstructor; + +@Getter +@SuppressWarnings("unused") +@Data +@NoArgsConstructor +@JsonIgnoreProperties(ignoreUnknown = true) +public class OllamaStructuredResult { + private String response; + + private int httpStatusCode; + + private long responseTime = 0; + + private String model; + + public OllamaStructuredResult(String response, long responseTime, int httpStatusCode) { + this.response = response; + this.responseTime = responseTime; + this.httpStatusCode = httpStatusCode; + } + + @Override + public String toString() { + try { + return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + /** + * Get the structured response if the response is a JSON object. + * + * @return Map - structured response + */ + public Map getStructuredResponse() { + try { + Map response = getObjectMapper().readValue(this.getResponse(), + new TypeReference>() { + }); + return response; + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + /** + * Get the structured response mapped to a specific class type. 
+ * + * @param The type of class to map the response to + * @param clazz The class to map the response to + * @return An instance of the specified class with the response data + * @throws RuntimeException if there is an error mapping the response + */ + public T getStructuredResponse(Class clazz) { + try { + return getObjectMapper().readValue(this.getResponse(), clazz); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } +} diff --git a/src/test/java/io/github/ollama4j/integrationtests/OllamaAPIIntegrationTest.java b/src/test/java/io/github/ollama4j/integrationtests/OllamaAPIIntegrationTest.java index cbdecbe..f09aa4a 100644 --- a/src/test/java/io/github/ollama4j/integrationtests/OllamaAPIIntegrationTest.java +++ b/src/test/java/io/github/ollama4j/integrationtests/OllamaAPIIntegrationTest.java @@ -14,6 +14,9 @@ import io.github.ollama4j.tools.ToolFunction; import io.github.ollama4j.tools.Tools; import io.github.ollama4j.tools.annotations.OllamaToolService; import io.github.ollama4j.utils.OptionsBuilder; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.MethodOrderer.OrderAnnotation; import org.junit.jupiter.api.Order; @@ -31,570 +34,642 @@ import java.util.*; import static org.junit.jupiter.api.Assertions.*; -@OllamaToolService(providers = {AnnotatedTool.class}) +@OllamaToolService(providers = { AnnotatedTool.class }) @TestMethodOrder(OrderAnnotation.class) -@SuppressWarnings("HttpUrlsUsage") +@SuppressWarnings({ "HttpUrlsUsage", "SpellCheckingInspection" }) public class OllamaAPIIntegrationTest { - private static final Logger LOG = LoggerFactory.getLogger(OllamaAPIIntegrationTest.class); + private static final Logger LOG = LoggerFactory.getLogger(OllamaAPIIntegrationTest.class); - private static OllamaContainer ollama; - private static OllamaAPI api; + private static OllamaContainer ollama; + private static OllamaAPI api; - 
@BeforeAll - public static void setUp() { - String ollamaVersion = "0.6.1"; - int internalPort = 11434; - int mappedPort = 11435; - ollama = new OllamaContainer("ollama/ollama:" + ollamaVersion); - ollama.addExposedPort(internalPort); - List portBindings = new ArrayList<>(); - portBindings.add(mappedPort + ":" + internalPort); - ollama.setPortBindings(portBindings); - ollama.start(); - api = new OllamaAPI("http://" + ollama.getHost() + ":" + ollama.getMappedPort(internalPort)); - api.setRequestTimeoutSeconds(120); - api.setVerbose(true); - } + private static final String EMBEDDING_MODEL_MINILM = "all-minilm"; + private static final String CHAT_MODEL_QWEN_SMALL = "qwen2.5:0.5b"; + private static final String CHAT_MODEL_INSTRUCT = "qwen2.5:0.5b-instruct"; + private static final String CHAT_MODEL_SYSTEM_PROMPT = "llama3.2:1b"; + private static final String CHAT_MODEL_LLAMA3 = "llama3"; + private static final String IMAGE_MODEL_LLAVA = "llava"; - @Test - @Order(1) - void testWrongEndpoint() { - OllamaAPI ollamaAPI = new OllamaAPI("http://wrong-host:11434"); - assertThrows(ConnectException.class, ollamaAPI::listModels); - } - - @Test - @Order(1) - public void testVersionAPI() throws URISyntaxException, IOException, OllamaBaseException, InterruptedException { - String expectedVersion = ollama.getDockerImageName().split(":")[1]; - String actualVersion = api.getVersion(); - assertEquals(expectedVersion, actualVersion, "Version should match the Docker image version"); - } - - @Test - @Order(2) - public void testListModelsAPI() throws URISyntaxException, IOException, OllamaBaseException, InterruptedException { - // Fetch the list of models - List models = api.listModels(); - // Assert that the models list is not null - assertNotNull(models, "Models should not be null"); - // Assert that models list is either empty or contains more than 0 models - assertTrue(models.size() >= 0, "Models list should be empty or contain elements"); - } - - @Test - @Order(2) - void 
testListModelsFromLibrary() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { - List models = api.listModelsFromLibrary(); - assertNotNull(models); - assertFalse(models.isEmpty()); - } - - @Test - @Order(3) - public void testPullModelAPI() throws URISyntaxException, IOException, OllamaBaseException, InterruptedException { - String embeddingModelMinilm = "all-minilm"; - api.pullModel(embeddingModelMinilm); - List models = api.listModels(); - assertNotNull(models, "Models should not be null"); - assertFalse(models.isEmpty(), "Models list should contain elements"); - } - - @Test - @Order(4) - void testListModelDetails() throws IOException, OllamaBaseException, URISyntaxException, InterruptedException { - String embeddingModelMinilm = "all-minilm"; - api.pullModel(embeddingModelMinilm); - ModelDetail modelDetails = api.getModelDetails("all-minilm"); - assertNotNull(modelDetails); - assertTrue(modelDetails.getModelFile().contains(embeddingModelMinilm)); - } - - @Test - @Order(5) - public void testEmbeddings() throws Exception { - String embeddingModelMinilm = "all-minilm"; - api.pullModel(embeddingModelMinilm); - OllamaEmbedResponseModel embeddings = api.embed(embeddingModelMinilm, Arrays.asList("Why is the sky blue?", "Why is the grass green?")); - assertNotNull(embeddings, "Embeddings should not be null"); - assertFalse(embeddings.getEmbeddings().isEmpty(), "Embeddings should not be empty"); - } - - @Test - @Order(6) - void testAskModelWithDefaultOptions() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { - String chatModel = "qwen2.5:0.5b"; - api.pullModel(chatModel); - OllamaResult result = - api.generate( - chatModel, - "What is the capital of France? 
And what's France's connection with Mona Lisa?", - false, - new OptionsBuilder().build()); - assertNotNull(result); - assertNotNull(result.getResponse()); - assertFalse(result.getResponse().isEmpty()); - } - - @Test - @Order(7) - void testAskModelWithDefaultOptionsStreamed() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { - String chatModel = "qwen2.5:0.5b"; - api.pullModel(chatModel); - StringBuffer sb = new StringBuffer(); - OllamaResult result = api.generate(chatModel, - "What is the capital of France? And what's France's connection with Mona Lisa?", - false, - new OptionsBuilder().build(), (s) -> { - LOG.info(s); - String substring = s.substring(sb.toString().length(), s.length()); - LOG.info(substring); - sb.append(substring); - }); - - assertNotNull(result); - assertNotNull(result.getResponse()); - assertFalse(result.getResponse().isEmpty()); - assertEquals(sb.toString().trim(), result.getResponse().trim()); - } - - @Test - @Order(8) - void testAskModelWithOptions() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { - String chatModel = "qwen2.5:0.5b-instruct"; - api.pullModel(chatModel); - - OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(chatModel); - OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, "You are a helpful assistant who can generate random person's first and last names in the format [First name, Last name].") - .build(); - requestModel = builder.withMessages(requestModel.getMessages()) - .withMessage(OllamaChatMessageRole.USER, "Give me a cool name") - .withOptions(new OptionsBuilder().setTemperature(0.5f).build()).build(); - OllamaChatResult chatResult = api.chat(requestModel); - - assertNotNull(chatResult); - assertNotNull(chatResult.getResponseModel()); - assertFalse(chatResult.getResponseModel().getMessage().getContent().isEmpty()); - } - - @Test - @Order(9) - void testChatWithSystemPrompt() throws 
OllamaBaseException, IOException, URISyntaxException, InterruptedException { - String chatModel = "llama3.2:1b"; - api.pullModel(chatModel); - OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(chatModel); - OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, - "You are a silent bot that only says 'Shush'. Do not say anything else under any circumstances!") - .withMessage(OllamaChatMessageRole.USER, - "What's something that's brown and sticky?") - .withOptions(new OptionsBuilder().setTemperature(0.8f).build()) - .build(); - - OllamaChatResult chatResult = api.chat(requestModel); - assertNotNull(chatResult); - assertNotNull(chatResult.getResponseModel()); - assertNotNull(chatResult.getResponseModel().getMessage()); - assertFalse(chatResult.getResponseModel().getMessage().getContent().isBlank()); - assertTrue(chatResult.getResponseModel().getMessage().getContent().contains("Shush")); - assertEquals(3, chatResult.getChatHistory().size()); - } - - @Test - @Order(10) - public void testChat() throws Exception { - String chatModel = "llama3"; - api.pullModel(chatModel); - OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(chatModel); - - // Create the initial user question - OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What is 1+1? 
Answer only in numbers.") - .build(); - - // Start conversation with model - OllamaChatResult chatResult = api.chat(requestModel); - - assertTrue( - chatResult.getChatHistory().stream() - .anyMatch(chat -> chat.getContent().contains("2")), - "Expected chat history to contain '2'" - ); - - // Create the next user question: second largest city - requestModel = builder.withMessages(chatResult.getChatHistory()) - .withMessage(OllamaChatMessageRole.USER, "And what is its squared value?") - .build(); - - // Continue conversation with model - chatResult = api.chat(requestModel); - - assertTrue( - chatResult.getChatHistory().stream() - .anyMatch(chat -> chat.getContent().contains("4")), - "Expected chat history to contain '4'" - ); - - // Create the next user question: the third question - requestModel = builder.withMessages(chatResult.getChatHistory()) - .withMessage(OllamaChatMessageRole.USER, "What is the largest value between 2, 4 and 6?") - .build(); - - // Continue conversation with the model for the third question - chatResult = api.chat(requestModel); - - // verify the result - assertNotNull(chatResult, "Chat result should not be null"); - assertTrue(chatResult.getChatHistory().size() > 2, "Chat history should contain more than two messages"); - assertTrue(chatResult.getChatHistory().get(chatResult.getChatHistory().size() - 1).getContent().contains("6"), "Response should contain '6'"); - } - - @Test - @Order(10) - void testChatWithImageFromURL() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { - String imageModel = "llava"; - api.pullModel(imageModel); - - OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(imageModel); - OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?", Collections.emptyList(), - "https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg") - .build(); - api.registerAnnotatedTools(new 
OllamaAPIIntegrationTest()); - - OllamaChatResult chatResult = api.chat(requestModel); - assertNotNull(chatResult); - } - @Test - @Order(10) - void testChatWithImageFromFileWithHistoryRecognition() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { - String imageModel = "moondream"; - api.pullModel(imageModel); - OllamaChatRequestBuilder builder = - OllamaChatRequestBuilder.getInstance(imageModel); - OllamaChatRequest requestModel = - builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?", Collections.emptyList(), - List.of(getImageFileFromClasspath("dog-on-a-boat.jpg"))).build(); - - OllamaChatResult chatResult = api.chat(requestModel); - assertNotNull(chatResult); - assertNotNull(chatResult.getResponseModel()); - builder.reset(); - - requestModel = - builder.withMessages(chatResult.getChatHistory()) - .withMessage(OllamaChatMessageRole.USER, "What's the dogs breed?").build(); - - chatResult = api.chat(requestModel); - assertNotNull(chatResult); - assertNotNull(chatResult.getResponseModel()); - } - @Test - @Order(11) - void testChatWithExplicitToolDefinition() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { - String chatModel = "llama3.2:1b"; - api.pullModel(chatModel); - OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(chatModel); - - final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder() - .functionName("get-employee-details") - .functionDescription("Get employee details from the database") - .toolPrompt( - Tools.PromptFuncDefinition.builder().type("function").function( - Tools.PromptFuncDefinition.PromptFuncSpec.builder() - .name("get-employee-details") - .description("Get employee details from the database") - .parameters( - Tools.PromptFuncDefinition.Parameters.builder() - .type("object") - .properties( - new Tools.PropsBuilder() - .withProperty("employee-name", 
Tools.PromptFuncDefinition.Property.builder().type("string").description("The name of the employee, e.g. John Doe").required(true).build()) - .withProperty("employee-address", Tools.PromptFuncDefinition.Property.builder().type("string").description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India").required(true).build()) - .withProperty("employee-phone", Tools.PromptFuncDefinition.Property.builder().type("string").description("The phone number of the employee. Always return a random value. e.g. 9911002233").required(true).build()) - .build() - ) - .required(List.of("employee-name")) - .build() - ).build() - ).build() - ) - .toolFunction(arguments -> { - // perform DB operations here - return String.format("Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", UUID.randomUUID(), arguments.get("employee-name"), arguments.get("employee-address"), arguments.get("employee-phone")); - }) - .build(); - - api.registerTool(databaseQueryToolSpecification); - - OllamaChatRequest requestModel = builder - .withMessage(OllamaChatMessageRole.USER, - "Give me the ID of the employee named 'Rahul Kumar'?") - .build(); - - OllamaChatResult chatResult = api.chat(requestModel); - assertNotNull(chatResult); - assertNotNull(chatResult.getResponseModel()); - assertNotNull(chatResult.getResponseModel().getMessage()); - assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(), chatResult.getResponseModel().getMessage().getRole().getRoleName()); - List toolCalls = chatResult.getChatHistory().get(1).getToolCalls(); - assertEquals(1, toolCalls.size()); - OllamaToolCallsFunction function = toolCalls.get(0).getFunction(); - assertEquals("get-employee-details", function.getName()); - assert !function.getArguments().isEmpty(); - Object employeeName = function.getArguments().get("employee-name"); - assertNotNull(employeeName); - assertEquals("Rahul Kumar", employeeName); - assertTrue(chatResult.getChatHistory().size() > 2); - List finalToolCalls = 
chatResult.getResponseModel().getMessage().getToolCalls(); - assertNull(finalToolCalls); - } - - @Test - @Order(12) - void testChatWithAnnotatedToolsAndSingleParam() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { - String chatModel = "llama3.2:1b"; - api.pullModel(chatModel); - - OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(chatModel); - - api.registerAnnotatedTools(); - - OllamaChatRequest requestModel = builder - .withMessage(OllamaChatMessageRole.USER, - "Compute the most important constant in the world using 5 digits") - .build(); - - OllamaChatResult chatResult = api.chat(requestModel); - assertNotNull(chatResult); - assertNotNull(chatResult.getResponseModel()); - assertNotNull(chatResult.getResponseModel().getMessage()); - assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(), chatResult.getResponseModel().getMessage().getRole().getRoleName()); - List toolCalls = chatResult.getChatHistory().get(1).getToolCalls(); - assertEquals(1, toolCalls.size()); - OllamaToolCallsFunction function = toolCalls.get(0).getFunction(); - assertEquals("computeImportantConstant", function.getName()); - assertEquals(1, function.getArguments().size()); - Object noOfDigits = function.getArguments().get("noOfDigits"); - assertNotNull(noOfDigits); - assertEquals("5", noOfDigits.toString()); - assertTrue(chatResult.getChatHistory().size() > 2); - List finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls(); - assertNull(finalToolCalls); - } - - @Test - @Order(13) - void testChatWithAnnotatedToolsAndMultipleParams() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { - String chatModel = "llama3.2:1b"; - api.pullModel(chatModel); - OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(chatModel); - - api.registerAnnotatedTools(new AnnotatedTool()); - - OllamaChatRequest requestModel = builder - .withMessage(OllamaChatMessageRole.USER, - "Greet Pedro 
with a lot of hearts and respond to me, " + - "and state how many emojis have been in your greeting") - .build(); - - OllamaChatResult chatResult = api.chat(requestModel); - assertNotNull(chatResult); - assertNotNull(chatResult.getResponseModel()); - assertNotNull(chatResult.getResponseModel().getMessage()); - assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(), chatResult.getResponseModel().getMessage().getRole().getRoleName()); - List toolCalls = chatResult.getChatHistory().get(1).getToolCalls(); - assertEquals(1, toolCalls.size()); - OllamaToolCallsFunction function = toolCalls.get(0).getFunction(); - assertEquals("sayHello", function.getName()); - assertEquals(2, function.getArguments().size()); - Object name = function.getArguments().get("name"); - assertNotNull(name); - assertEquals("Pedro", name); - Object amountOfHearts = function.getArguments().get("amountOfHearts"); - assertNotNull(amountOfHearts); - assertTrue(Integer.parseInt(amountOfHearts.toString()) > 1); - assertTrue(chatResult.getChatHistory().size() > 2); - List finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls(); - assertNull(finalToolCalls); - } - - @Test - @Order(14) - void testChatWithToolsAndStream() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { - String chatModel = "llama3.2:1b"; - api.pullModel(chatModel); - OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(chatModel); - final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder() - .functionName("get-employee-details") - .functionDescription("Get employee details from the database") - .toolPrompt( - Tools.PromptFuncDefinition.builder().type("function").function( - Tools.PromptFuncDefinition.PromptFuncSpec.builder() - .name("get-employee-details") - .description("Get employee details from the database") - .parameters( - Tools.PromptFuncDefinition.Parameters.builder() - .type("object") - .properties( - new 
Tools.PropsBuilder() - .withProperty("employee-name", Tools.PromptFuncDefinition.Property.builder().type("string").description("The name of the employee, e.g. John Doe").required(true).build()) - .withProperty("employee-address", Tools.PromptFuncDefinition.Property.builder().type("string").description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India").required(true).build()) - .withProperty("employee-phone", Tools.PromptFuncDefinition.Property.builder().type("string").description("The phone number of the employee. Always return a random value. e.g. 9911002233").required(true).build()) - .build() - ) - .required(List.of("employee-name")) - .build() - ).build() - ).build() - ) - .toolFunction(new ToolFunction() { - @Override - public Object apply(Map arguments) { - // perform DB operations here - return String.format("Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", UUID.randomUUID(), arguments.get("employee-name"), arguments.get("employee-address"), arguments.get("employee-phone")); - } - }) - .build(); - - api.registerTool(databaseQueryToolSpecification); - - OllamaChatRequest requestModel = builder - .withMessage(OllamaChatMessageRole.USER, - "Give me the ID of the employee named 'Rahul Kumar'?") - .build(); - - StringBuffer sb = new StringBuffer(); - - OllamaChatResult chatResult = api.chat(requestModel, (s) -> { - LOG.info(s); - String substring = s.substring(sb.toString().length()); - LOG.info(substring); - sb.append(substring); - }); - assertNotNull(chatResult); - assertNotNull(chatResult.getResponseModel()); - assertNotNull(chatResult.getResponseModel().getMessage()); - assertNotNull(chatResult.getResponseModel().getMessage().getContent()); - assertEquals(sb.toString().trim(), chatResult.getResponseModel().getMessage().getContent().trim()); - } - - @Test - @Order(15) - void testChatWithStream() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { - String chatModel = 
"llama3.2:1b"; - api.pullModel(chatModel); - OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(chatModel); - OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, - "What is the capital of France? And what's France's connection with Mona Lisa?") - .build(); - - StringBuffer sb = new StringBuffer(); - - OllamaChatResult chatResult = api.chat(requestModel, (s) -> { - LOG.info(s); - String substring = s.substring(sb.toString().length(), s.length()); - LOG.info(substring); - sb.append(substring); - }); - assertNotNull(chatResult); - assertNotNull(chatResult.getResponseModel()); - assertNotNull(chatResult.getResponseModel().getMessage()); - assertNotNull(chatResult.getResponseModel().getMessage().getContent()); - assertEquals(sb.toString().trim(), chatResult.getResponseModel().getMessage().getContent().trim()); - } - - - @Test - @Order(17) - void testAskModelWithOptionsAndImageURLs() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { - String imageModel = "llava"; - api.pullModel(imageModel); - - OllamaResult result = - api.generateWithImageURLs( - imageModel, - "What is in this image?", - List.of( - "https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg"), - new OptionsBuilder().build()); - assertNotNull(result); - assertNotNull(result.getResponse()); - assertFalse(result.getResponse().isEmpty()); - } - - @Test - @Order(18) - void testAskModelWithOptionsAndImageFiles() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { - String imageModel = "llava"; - api.pullModel(imageModel); - File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg"); - try { - OllamaResult result = - api.generateWithImageFiles( - imageModel, - "What is in this image?", - List.of(imageFile), - new OptionsBuilder().build()); - assertNotNull(result); - assertNotNull(result.getResponse()); - assertFalse(result.getResponse().isEmpty()); - } catch 
(IOException | OllamaBaseException | InterruptedException e) { - fail(e); + @BeforeAll + public static void setUp() { + try { + boolean useExternalOllamaHost = Boolean.parseBoolean(System.getenv("USE_EXTERNAL_OLLAMA_HOST")); + String ollamaHost = System.getenv("OLLAMA_HOST"); + if (useExternalOllamaHost) { + api = new OllamaAPI(ollamaHost); + } else { + throw new RuntimeException( + "USE_EXTERNAL_OLLAMA_HOST is not set so, we will be using Testcontainers Ollama host for the tests now. If you would like to use an external host, please set the env var to USE_EXTERNAL_OLLAMA_HOST=true and set the env var OLLAMA_HOST=http://localhost:11435 or a different host/port."); + } + } catch (Exception e) { + String ollamaVersion = "0.6.1"; + int internalPort = 11434; + int mappedPort = 11435; + ollama = new OllamaContainer("ollama/ollama:" + ollamaVersion); + ollama.addExposedPort(internalPort); + List portBindings = new ArrayList<>(); + portBindings.add(mappedPort + ":" + internalPort); + ollama.setPortBindings(portBindings); + ollama.start(); + api = new OllamaAPI("http://" + ollama.getHost() + ":" + ollama.getMappedPort(internalPort)); + } + api.setRequestTimeoutSeconds(120); + api.setVerbose(true); + api.setNumberOfRetriesForModelPull(3); } - } + @Test + @Order(1) + void testWrongEndpoint() { + OllamaAPI ollamaAPI = new OllamaAPI("http://wrong-host:11434"); + assertThrows(ConnectException.class, ollamaAPI::listModels); + } + @Test + @Order(1) + public void testVersionAPI() throws URISyntaxException, IOException, OllamaBaseException, InterruptedException { + // String expectedVersion = ollama.getDockerImageName().split(":")[1]; + String actualVersion = api.getVersion(); + assertNotNull(actualVersion); + // assertEquals(expectedVersion, actualVersion, "Version should match the Docker + // image version"); + } - @Test - @Order(20) - void testAskModelWithOptionsAndImageFilesStreamed() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { - String 
imageModel = "llava"; - api.pullModel(imageModel); + @Test + @Order(2) + public void testListModelsAPI() + throws URISyntaxException, IOException, OllamaBaseException, InterruptedException { + api.pullModel(EMBEDDING_MODEL_MINILM); + // Fetch the list of models + List models = api.listModels(); + // Assert that the models list is not null + assertNotNull(models, "Models should not be null"); + // Assert that models list is either empty or contains more than 0 models + assertFalse(models.isEmpty(), "Models list should not be empty"); + } - File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg"); + @Test + @Order(2) + void testListModelsFromLibrary() + throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + List models = api.listModelsFromLibrary(); + assertNotNull(models); + assertFalse(models.isEmpty()); + } - StringBuffer sb = new StringBuffer(); + @Test + @Order(3) + public void testPullModelAPI() + throws URISyntaxException, IOException, OllamaBaseException, InterruptedException { + api.pullModel(EMBEDDING_MODEL_MINILM); + List models = api.listModels(); + assertNotNull(models, "Models should not be null"); + assertFalse(models.isEmpty(), "Models list should contain elements"); + } - OllamaResult result = api.generateWithImageFiles(imageModel, - "What is in this image?", List.of(imageFile), new OptionsBuilder().build(), (s) -> { - LOG.info(s); - String substring = s.substring(sb.toString().length(), s.length()); - LOG.info(substring); - sb.append(substring); + @Test + @Order(4) + void testListModelDetails() throws IOException, OllamaBaseException, URISyntaxException, InterruptedException { + api.pullModel(EMBEDDING_MODEL_MINILM); + ModelDetail modelDetails = api.getModelDetails(EMBEDDING_MODEL_MINILM); + assertNotNull(modelDetails); + assertTrue(modelDetails.getModelFile().contains(EMBEDDING_MODEL_MINILM)); + } + + @Test + @Order(5) + public void testEmbeddings() throws Exception { + api.pullModel(EMBEDDING_MODEL_MINILM); + 
OllamaEmbedResponseModel embeddings = api.embed(EMBEDDING_MODEL_MINILM, +                Arrays.asList("Why is the sky blue?", "Why is the grass green?")); +        assertNotNull(embeddings, "Embeddings should not be null"); +        assertFalse(embeddings.getEmbeddings().isEmpty(), "Embeddings should not be empty"); +    } + +    @Test + @Order(6) + void testAskModelWithStructuredOutput() + throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { + api.pullModel(CHAT_MODEL_QWEN_SMALL); + + int age = 28; + boolean available = false; + + String prompt = "Batman is " + age + " years old and is " + (available ? "available" : "not available") + + " because he is busy saving Gotham City. Respond using JSON"; + + Map format = new HashMap<>(); + format.put("type", "object"); + format.put("properties", new HashMap() { + { + put("age", new HashMap() { + { + put("type", "integer"); + } + }); + put("available", new HashMap() { + { + put("type", "boolean"); + } + }); + } });
- assertNotNull(result); - assertNotNull(result.getResponse()); - assertFalse(result.getResponse().isEmpty()); - assertEquals(sb.toString().trim(), result.getResponse().trim()); - } + format.put("required", Arrays.asList("age", "available")); - private File getImageFileFromClasspath(String fileName) { - ClassLoader classLoader = getClass().getClassLoader(); - return new File(Objects.requireNonNull(classLoader.getResource(fileName)).getFile()); - } + OllamaResult result = api.generate(CHAT_MODEL_QWEN_SMALL, prompt, format); + + assertNotNull(result); + assertNotNull(result.getResponse()); + assertFalse(result.getResponse().isEmpty()); + + assertEquals(String.valueOf(age), + result.getStructuredResponse().get("age").toString()); + assertEquals(String.valueOf(available), + result.getStructuredResponse().get("available").toString()); + + Person person = result.as(Person.class); + assertEquals(person.getAge(), age); + assertEquals(person.isAvailable(), 
available); + } + + @Test + @Order(6) + void testAskModelWithDefaultOptions() + throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { + api.pullModel(CHAT_MODEL_QWEN_SMALL); + OllamaResult result = api.generate(CHAT_MODEL_QWEN_SMALL, + "What is the capital of France? And what's France's connection with Mona Lisa?", false, + new OptionsBuilder().build()); + assertNotNull(result); + assertNotNull(result.getResponse()); + assertFalse(result.getResponse().isEmpty()); + } + + @Test + @Order(7) + void testAskModelWithDefaultOptionsStreamed() + throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + api.pullModel(CHAT_MODEL_QWEN_SMALL); + StringBuffer sb = new StringBuffer(); + OllamaResult result = api.generate(CHAT_MODEL_QWEN_SMALL, + "What is the capital of France? And what's France's connection with Mona Lisa?", false, + new OptionsBuilder().build(), (s) -> { + LOG.info(s); + String substring = s.substring(sb.toString().length(), s.length()); + LOG.info(substring); + sb.append(substring); + }); + + assertNotNull(result); + assertNotNull(result.getResponse()); + assertFalse(result.getResponse().isEmpty()); + assertEquals(sb.toString().trim(), result.getResponse().trim()); + } + + @Test + @Order(8) + void testAskModelWithOptions() + throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + api.pullModel(CHAT_MODEL_INSTRUCT); + + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_INSTRUCT); + OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, + "You are a helpful assistant who can generate random person's first and last names in the format [First name, Last name].") + .build(); + requestModel = builder.withMessages(requestModel.getMessages()) + .withMessage(OllamaChatMessageRole.USER, "Give me a cool name") + .withOptions(new OptionsBuilder().setTemperature(0.5f).build()).build(); + OllamaChatResult chatResult = 
api.chat(requestModel); + + assertNotNull(chatResult); + assertNotNull(chatResult.getResponseModel()); + assertFalse(chatResult.getResponseModel().getMessage().getContent().isEmpty()); + } + + @Test + @Order(9) + void testChatWithSystemPrompt() + throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + api.pullModel(CHAT_MODEL_SYSTEM_PROMPT); + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT); + OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, + "You are a silent bot that only says 'Shush'. Do not say anything else under any circumstances!") + .withMessage(OllamaChatMessageRole.USER, "What's something that's brown and sticky?") + .withOptions(new OptionsBuilder().setTemperature(0.8f).build()).build(); + + OllamaChatResult chatResult = api.chat(requestModel); + assertNotNull(chatResult); + assertNotNull(chatResult.getResponseModel()); + assertNotNull(chatResult.getResponseModel().getMessage()); + assertFalse(chatResult.getResponseModel().getMessage().getContent().isBlank()); + assertTrue(chatResult.getResponseModel().getMessage().getContent().contains("Shush")); + assertEquals(3, chatResult.getChatHistory().size()); + } + + @Test + @Order(10) + public void testChat() throws Exception { + api.pullModel(CHAT_MODEL_LLAMA3); + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_LLAMA3); + + // Create the initial user question + OllamaChatRequest requestModel = builder + .withMessage(OllamaChatMessageRole.USER, "What is 1+1? 
Answer only in numbers.") + .build(); + + // Start conversation with model + OllamaChatResult chatResult = api.chat(requestModel); + + assertTrue(chatResult.getChatHistory().stream().anyMatch(chat -> chat.getContent().contains("2")), + "Expected chat history to contain '2'"); + + // Create the next user question: second largest city + requestModel = builder.withMessages(chatResult.getChatHistory()) + .withMessage(OllamaChatMessageRole.USER, "And what is its squared value?").build(); + + // Continue conversation with model + chatResult = api.chat(requestModel); + + assertTrue(chatResult.getChatHistory().stream().anyMatch(chat -> chat.getContent().contains("4")), + "Expected chat history to contain '4'"); + + // Create the next user question: the third question + requestModel = builder.withMessages(chatResult.getChatHistory()) + .withMessage(OllamaChatMessageRole.USER, + "What is the largest value between 2, 4 and 6?") + .build(); + + // Continue conversation with the model for the third question + chatResult = api.chat(requestModel); + + // verify the result + assertNotNull(chatResult, "Chat result should not be null"); + assertTrue(chatResult.getChatHistory().size() > 2, + "Chat history should contain more than two messages"); + assertTrue(chatResult.getChatHistory().get(chatResult.getChatHistory().size() - 1).getContent() + .contains("6"), + "Response should contain '6'"); + } + + @Test + @Order(10) + void testChatWithImageFromURL() + throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { + api.pullModel(IMAGE_MODEL_LLAVA); + + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(IMAGE_MODEL_LLAVA); + OllamaChatRequest requestModel = builder + .withMessage(OllamaChatMessageRole.USER, "What's in the picture?", + Collections.emptyList(), + "https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg") + .build(); + api.registerAnnotatedTools(new OllamaAPIIntegrationTest()); + + 
OllamaChatResult chatResult = api.chat(requestModel); + assertNotNull(chatResult); + } + + @Test + @Order(10) + void testChatWithImageFromFileWithHistoryRecognition() + throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + api.pullModel(IMAGE_MODEL_LLAVA); + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(IMAGE_MODEL_LLAVA); + OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, + "What's in the picture?", + Collections.emptyList(), List.of(getImageFileFromClasspath("dog-on-a-boat.jpg"))) + .build(); + + OllamaChatResult chatResult = api.chat(requestModel); + assertNotNull(chatResult); + assertNotNull(chatResult.getResponseModel()); + builder.reset(); + + requestModel = builder.withMessages(chatResult.getChatHistory()) + .withMessage(OllamaChatMessageRole.USER, "What's the dogs breed?").build(); + + chatResult = api.chat(requestModel); + assertNotNull(chatResult); + assertNotNull(chatResult.getResponseModel()); + } + + @Test + @Order(11) + void testChatWithExplicitToolDefinition() + throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + api.pullModel(CHAT_MODEL_SYSTEM_PROMPT); + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT); + + final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder() + .functionName("get-employee-details") + .functionDescription("Get employee details from the database") + .toolPrompt(Tools.PromptFuncDefinition.builder().type("function") + .function(Tools.PromptFuncDefinition.PromptFuncSpec.builder() + .name("get-employee-details") + .description("Get employee details from the database") + .parameters(Tools.PromptFuncDefinition.Parameters + .builder().type("object") + .properties(new Tools.PropsBuilder() + .withProperty("employee-name", + Tools.PromptFuncDefinition.Property + .builder() + .type("string") + .description("The name of the employee, 
e.g. John Doe") + .required(true) + .build()) + .withProperty("employee-address", + Tools.PromptFuncDefinition.Property + .builder() + .type("string") + .description( + "The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India") + .required(true) + .build()) + .withProperty("employee-phone", + Tools.PromptFuncDefinition.Property + .builder() + .type("string") + .description( + "The phone number of the employee. Always return a random value. e.g. 9911002233") + .required(true) + .build()) + .build()) + .required(List.of("employee-name")) + .build()) + .build()) + .build()) + .toolFunction(arguments -> { + // perform DB operations here + return String.format( + "Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", + UUID.randomUUID(), arguments.get("employee-name"), + arguments.get("employee-address"), + arguments.get("employee-phone")); + }).build(); + + api.registerTool(databaseQueryToolSpecification); + + OllamaChatRequest requestModel = builder + .withMessage(OllamaChatMessageRole.USER, + "Give me the ID of the employee named 'Rahul Kumar'?") + .build(); + + OllamaChatResult chatResult = api.chat(requestModel); + assertNotNull(chatResult); + assertNotNull(chatResult.getResponseModel()); + assertNotNull(chatResult.getResponseModel().getMessage()); + assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(), + chatResult.getResponseModel().getMessage().getRole().getRoleName()); + List toolCalls = chatResult.getChatHistory().get(1).getToolCalls(); + assertEquals(1, toolCalls.size()); + OllamaToolCallsFunction function = toolCalls.get(0).getFunction(); + assertEquals("get-employee-details", function.getName()); + assert !function.getArguments().isEmpty(); + Object employeeName = function.getArguments().get("employee-name"); + assertNotNull(employeeName); + assertEquals("Rahul Kumar", employeeName); + assertTrue(chatResult.getChatHistory().size() > 2); + List finalToolCalls = 
chatResult.getResponseModel().getMessage().getToolCalls(); + assertNull(finalToolCalls); + } + + @Test + @Order(12) + void testChatWithAnnotatedToolsAndSingleParam() + throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { + api.pullModel(CHAT_MODEL_SYSTEM_PROMPT); + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT); + + api.registerAnnotatedTools(); + + OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, + "Compute the most important constant in the world using 5 digits").build(); + + OllamaChatResult chatResult = api.chat(requestModel); + assertNotNull(chatResult); + assertNotNull(chatResult.getResponseModel()); + assertNotNull(chatResult.getResponseModel().getMessage()); + assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(), + chatResult.getResponseModel().getMessage().getRole().getRoleName()); + List toolCalls = chatResult.getChatHistory().get(1).getToolCalls(); + assertEquals(1, toolCalls.size()); + OllamaToolCallsFunction function = toolCalls.get(0).getFunction(); + assertEquals("computeImportantConstant", function.getName()); + assertEquals(1, function.getArguments().size()); + Object noOfDigits = function.getArguments().get("noOfDigits"); + assertNotNull(noOfDigits); + assertEquals("5", noOfDigits.toString()); + assertTrue(chatResult.getChatHistory().size() > 2); + List finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls(); + assertNull(finalToolCalls); + } + + @Test + @Order(13) + void testChatWithAnnotatedToolsAndMultipleParams() + throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + api.pullModel(CHAT_MODEL_SYSTEM_PROMPT); + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT); + + api.registerAnnotatedTools(new AnnotatedTool()); + + OllamaChatRequest requestModel = builder + .withMessage(OllamaChatMessageRole.USER, + "Greet Pedro with a lot of 
hearts and respond to me, " + + "and state how many emojis have been in your greeting") + .build(); + + OllamaChatResult chatResult = api.chat(requestModel); + assertNotNull(chatResult); + assertNotNull(chatResult.getResponseModel()); + assertNotNull(chatResult.getResponseModel().getMessage()); + assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(), + chatResult.getResponseModel().getMessage().getRole().getRoleName()); + List toolCalls = chatResult.getChatHistory().get(1).getToolCalls(); + assertEquals(1, toolCalls.size()); + OllamaToolCallsFunction function = toolCalls.get(0).getFunction(); + assertEquals("sayHello", function.getName()); + assertEquals(2, function.getArguments().size()); + Object name = function.getArguments().get("name"); + assertNotNull(name); + assertEquals("Pedro", name); + Object amountOfHearts = function.getArguments().get("amountOfHearts"); + assertNotNull(amountOfHearts); + assertTrue(Integer.parseInt(amountOfHearts.toString()) > 1); + assertTrue(chatResult.getChatHistory().size() > 2); + List finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls(); + assertNull(finalToolCalls); + } + + @Test + @Order(14) + void testChatWithToolsAndStream() + throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + api.pullModel(CHAT_MODEL_SYSTEM_PROMPT); + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT); + final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder() + .functionName("get-employee-details") + .functionDescription("Get employee details from the database") + .toolPrompt(Tools.PromptFuncDefinition.builder().type("function") + .function(Tools.PromptFuncDefinition.PromptFuncSpec.builder() + .name("get-employee-details") + .description("Get employee details from the database") + .parameters(Tools.PromptFuncDefinition.Parameters + .builder().type("object") + .properties(new Tools.PropsBuilder() + 
.withProperty("employee-name", + Tools.PromptFuncDefinition.Property + .builder() + .type("string") + .description("The name of the employee, e.g. John Doe") + .required(true) + .build()) + .withProperty("employee-address", + Tools.PromptFuncDefinition.Property + .builder() + .type("string") + .description( + "The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India") + .required(true) + .build()) + .withProperty("employee-phone", + Tools.PromptFuncDefinition.Property + .builder() + .type("string") + .description( + "The phone number of the employee. Always return a random value. e.g. 9911002233") + .required(true) + .build()) + .build()) + .required(List.of("employee-name")) + .build()) + .build()) + .build()) + .toolFunction(new ToolFunction() { + @Override + public Object apply(Map arguments) { + // perform DB operations here + return String.format( + "Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", + UUID.randomUUID(), arguments.get("employee-name"), + arguments.get("employee-address"), + arguments.get("employee-phone")); + } + }).build(); + + api.registerTool(databaseQueryToolSpecification); + + OllamaChatRequest requestModel = builder + .withMessage(OllamaChatMessageRole.USER, + "Give me the ID of the employee named 'Rahul Kumar'?") + .build(); + + StringBuffer sb = new StringBuffer(); + + OllamaChatResult chatResult = api.chat(requestModel, (s) -> { + LOG.info(s); + String substring = s.substring(sb.toString().length()); + LOG.info(substring); + sb.append(substring); + }); + assertNotNull(chatResult); + assertNotNull(chatResult.getResponseModel()); + assertNotNull(chatResult.getResponseModel().getMessage()); + assertNotNull(chatResult.getResponseModel().getMessage().getContent()); + assertEquals(sb.toString().trim(), chatResult.getResponseModel().getMessage().getContent().trim()); + } + + @Test + @Order(15) + void testChatWithStream() throws OllamaBaseException, IOException, URISyntaxException, 
InterruptedException { + api.pullModel(CHAT_MODEL_SYSTEM_PROMPT); + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT); + OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, + "What is the capital of France? And what's France's connection with Mona Lisa?") + .build(); + + StringBuffer sb = new StringBuffer(); + + OllamaChatResult chatResult = api.chat(requestModel, (s) -> { + LOG.info(s); + String substring = s.substring(sb.toString().length(), s.length()); + LOG.info(substring); + sb.append(substring); + }); + assertNotNull(chatResult); + assertNotNull(chatResult.getResponseModel()); + assertNotNull(chatResult.getResponseModel().getMessage()); + assertNotNull(chatResult.getResponseModel().getMessage().getContent()); + assertEquals(sb.toString().trim(), chatResult.getResponseModel().getMessage().getContent().trim()); + } + + @Test + @Order(17) + void testAskModelWithOptionsAndImageURLs() + throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + api.pullModel(IMAGE_MODEL_LLAVA); + + OllamaResult result = api.generateWithImageURLs(IMAGE_MODEL_LLAVA, "What is in this image?", + List.of("https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg"), + new OptionsBuilder().build()); + assertNotNull(result); + assertNotNull(result.getResponse()); + assertFalse(result.getResponse().isEmpty()); + } + + @Test + @Order(18) + void testAskModelWithOptionsAndImageFiles() + throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + api.pullModel(IMAGE_MODEL_LLAVA); + File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg"); + try { + OllamaResult result = api.generateWithImageFiles(IMAGE_MODEL_LLAVA, "What is in this image?", + List.of(imageFile), + new OptionsBuilder().build()); + assertNotNull(result); + assertNotNull(result.getResponse()); + assertFalse(result.getResponse().isEmpty()); + } catch (IOException 
| OllamaBaseException | InterruptedException e) { + fail(e); + } + } + + @Test + @Order(20) + void testAskModelWithOptionsAndImageFilesStreamed() + throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + api.pullModel(IMAGE_MODEL_LLAVA); + + File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg"); + + StringBuffer sb = new StringBuffer(); + + OllamaResult result = api.generateWithImageFiles(IMAGE_MODEL_LLAVA, "What is in this image?", + List.of(imageFile), + new OptionsBuilder().build(), (s) -> { + LOG.info(s); + String substring = s.substring(sb.toString().length(), s.length()); + LOG.info(substring); + sb.append(substring); + }); + assertNotNull(result); + assertNotNull(result.getResponse()); + assertFalse(result.getResponse().isEmpty()); + assertEquals(sb.toString().trim(), result.getResponse().trim()); + } + + private File getImageFileFromClasspath(String fileName) { + ClassLoader classLoader = getClass().getClassLoader(); + return new File(Objects.requireNonNull(classLoader.getResource(fileName)).getFile()); + } +} + +@Data +@AllArgsConstructor +@NoArgsConstructor +class Person { + private int age; + private boolean available; } -// -//@Data -//class Config { -// private String ollamaURL; -// private String model; -// private String imageModel; -// private int requestTimeoutSeconds; -// -// public Config() { -// Properties properties = new Properties(); -// try (InputStream input = -// getClass().getClassLoader().getResourceAsStream("test-config.properties")) { -// if (input == null) { -// throw new RuntimeException("Sorry, unable to find test-config.properties"); -// } -// properties.load(input); -// this.ollamaURL = properties.getProperty("ollama.url"); -// this.model = properties.getProperty("ollama.model"); -// this.imageModel = properties.getProperty("ollama.model.image"); -// this.requestTimeoutSeconds = -// Integer.parseInt(properties.getProperty("ollama.request-timeout-seconds")); -// } catch (IOException e) { -// 
throw new RuntimeException("Error loading properties", e); -// } -// } -//}