From bc2a93158647109200c1d248693e1693c86b66fd Mon Sep 17 00:00:00 2001 From: Amith Koujalgi Date: Mon, 24 Mar 2025 21:40:20 +0530 Subject: [PATCH] Enhance OllamaAPI and OllamaResult for improved model pulling and structured responses - Added a retry mechanism in OllamaAPI for model pulling, allowing configurable retries. - Introduced new methods in OllamaResult for structured response handling, including parsing JSON responses into a Map or specific class types. - Updated integration tests to validate the new functionality and ensure robust testing of model interactions. - Improved code formatting and consistency across the OllamaAPI and integration test classes. --- .../java/io/github/ollama4j/OllamaAPI.java | 62 +- .../models/response/OllamaResult.java | 74 +- .../response/OllamaStructuredResult.java | 77 ++ .../OllamaAPIIntegrationTest.java | 1217 +++++++++-------- 4 files changed, 813 insertions(+), 617 deletions(-) create mode 100644 src/main/java/io/github/ollama4j/models/response/OllamaStructuredResult.java diff --git a/src/main/java/io/github/ollama4j/OllamaAPI.java b/src/main/java/io/github/ollama4j/OllamaAPI.java index 955d4dd..1c0abd5 100644 --- a/src/main/java/io/github/ollama4j/OllamaAPI.java +++ b/src/main/java/io/github/ollama4j/OllamaAPI.java @@ -51,7 +51,7 @@ import java.util.stream.Collectors; /** * The base Ollama API class. */ -@SuppressWarnings({"DuplicatedCode", "resource"}) +@SuppressWarnings({ "DuplicatedCode", "resource" }) public class OllamaAPI { private static final Logger logger = LoggerFactory.getLogger(OllamaAPI.class); @@ -74,6 +74,12 @@ public class OllamaAPI { private Auth auth; + private int numberOfRetriesForModelPull = 0; + + public void setNumberOfRetriesForModelPull(int numberOfRetriesForModelPull) { + this.numberOfRetriesForModelPull = numberOfRetriesForModelPull; + } + private final ToolRegistry toolRegistry = new ToolRegistry(); /** @@ -209,7 +215,7 @@ public class OllamaAPI { * tags, tag count, and the time when model was updated. * * @return A list of {@link LibraryModel} objects representing the models - * available in the Ollama library. + * available in the Ollama library. * @throws OllamaBaseException If the HTTP request fails or the response is not * successful (non-200 status code). * @throws IOException If an I/O error occurs during the HTTP request @@ -275,7 +281,7 @@ public class OllamaAPI { * of the library model * for which the tags need to be fetched. * @return a list of {@link LibraryModelTag} objects containing the extracted - * tags and their associated metadata. + * tags and their associated metadata. * @throws OllamaBaseException if the HTTP response status code indicates an * error (i.e., not 200 OK), * or if there is any other issue during the @@ -342,7 +348,7 @@ public class OllamaAPI { * @param modelName The name of the model to search for in the library. * @param tag The tag name to search for within the specified model. * @return The {@link LibraryModelTag} associated with the specified model and - * tag. + * tag. * @throws OllamaBaseException If there is a problem with the Ollama library * operations. * @throws IOException If an I/O error occurs during the operation. @@ -376,6 +382,26 @@ public class OllamaAPI { */ public void pullModel(String modelName) throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + if (numberOfRetriesForModelPull == 0) { + this.doPullModel(modelName); + } else { + int numberOfRetries = 0; + while (numberOfRetries < numberOfRetriesForModelPull) { + try { + this.doPullModel(modelName); + return; + } catch (OllamaBaseException e) { + logger.error("Failed to pull model " + modelName + ", retrying..."); + numberOfRetries++; + } + } + throw new OllamaBaseException( + "Failed to pull model " + modelName + " after " + numberOfRetriesForModelPull + " retries"); + } + } + + private void doPullModel(String modelName) + throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { String url = this.host + "/api/pull"; String jsonData = new ModelRequest(modelName).toString(); HttpRequest request = getRequestBuilderDefault(new URI(url)) @@ -729,7 +755,7 @@ public class OllamaAPI { * @throws InterruptedException if the operation is interrupted */ public OllamaResult generate(String model, String prompt, boolean raw, Options options, - OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { + OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt); ollamaRequestModel.setRaw(raw); ollamaRequestModel.setOptions(options.getOptionsMap()); @@ -742,8 +768,10 @@ public class OllamaAPI { * @param model The name or identifier of the AI model to use for generating * the response. * @param prompt The input text or prompt to provide to the AI model. - * @param format A map containing the format specification for the structured output. - * @return An instance of {@link OllamaResult} containing the structured response. + * @param format A map containing the format specification for the structured + * output. + * @return An instance of {@link OllamaResult} containing the structured + * response. * @throws OllamaBaseException if the response indicates an error status. * @throws IOException if an I/O error occurs during the HTTP request. * @throws InterruptedException if the operation is interrupted. @@ -771,7 +799,11 @@ public class OllamaAPI { String responseBody = response.body(); if (statusCode == 200) { - return Utils.getObjectMapper().readValue(responseBody, OllamaResult.class); + OllamaStructuredResult structuredResult = Utils.getObjectMapper().readValue(responseBody, + OllamaStructuredResult.class); + OllamaResult ollamaResult = new OllamaResult(structuredResult.getResponse(), + structuredResult.getResponseTime(), statusCode); + return ollamaResult; } else { throw new OllamaBaseException(statusCode + " - " + responseBody); } @@ -813,8 +845,8 @@ public class OllamaAPI { * @param options Additional options or configurations to use when generating * the response. * @return {@link OllamaToolsResult} An OllamaToolsResult object containing the - * response from the AI model and the results of invoking the tools on - * that output. + * response from the AI model and the results of invoking the tools on + * that output. * @throws OllamaBaseException if the response indicates an error status * @throws IOException if an I/O error occurs during the HTTP request * @throws InterruptedException if the operation is interrupted @@ -906,7 +938,7 @@ public class OllamaAPI { * @throws InterruptedException if the operation is interrupted */ public OllamaResult generateWithImageFiles(String model, String prompt, List imageFiles, Options options, - OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { + OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { List images = new ArrayList<>(); for (File imageFile : imageFiles) { images.add(encodeFileToBase64(imageFile)); @@ -953,7 +985,7 @@ public class OllamaAPI { * @throws URISyntaxException if the URI for the request is malformed */ public OllamaResult generateWithImageURLs(String model, String prompt, List imageURLs, Options options, - OllamaStreamHandler streamHandler) + OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { List images = new ArrayList<>(); for (String imageURL : imageURLs) { @@ -988,7 +1020,7 @@ public class OllamaAPI { * @param model the ollama model to ask the question to * @param messages chat history / message stack to send to the model * @return {@link OllamaChatResult} containing the api response and the message - * history including the newly aqcuired assistant response. + * history including the newly aqcuired assistant response. * @throws OllamaBaseException any response code than 200 has been returned * @throws IOException in case the responseStream can not be read * @throws InterruptedException in case the server is not reachable or network @@ -1171,7 +1203,7 @@ public class OllamaAPI { registerAnnotatedTools(provider.getDeclaredConstructor().newInstance()); } } catch (InstantiationException | NoSuchMethodException | IllegalAccessException - | InvocationTargetException e) { + | InvocationTargetException e) { throw new RuntimeException(e); } } @@ -1308,7 +1340,7 @@ public class OllamaAPI { * @throws InterruptedException if the thread is interrupted during the request. */ private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequest ollamaRequestModel, - OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { + OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, auth, requestTimeoutSeconds, verbose); OllamaResult result; diff --git a/src/main/java/io/github/ollama4j/models/response/OllamaResult.java b/src/main/java/io/github/ollama4j/models/response/OllamaResult.java index 2d340e0..4b538f9 100644 --- a/src/main/java/io/github/ollama4j/models/response/OllamaResult.java +++ b/src/main/java/io/github/ollama4j/models/response/OllamaResult.java @@ -1,19 +1,26 @@ package io.github.ollama4j.models.response; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; + import lombok.Data; import lombok.Getter; import static io.github.ollama4j.utils.Utils.getObjectMapper; +import java.util.HashMap; +import java.util.Map; + /** The type Ollama result. */ @Getter @SuppressWarnings("unused") @Data +@JsonIgnoreProperties(ignoreUnknown = true) public class OllamaResult { /** * -- GETTER -- - * Get the completion/response text + * Get the completion/response text * * @return String completion/response text */ @@ -21,7 +28,7 @@ public class OllamaResult { /** * -- GETTER -- - * Get the response status code. + * Get the response status code. * * @return int - response status code */ @@ -29,7 +36,7 @@ public class OllamaResult { /** * -- GETTER -- - * Get the response time in milliseconds. + * Get the response time in milliseconds. * * @return long - response time in milliseconds */ @@ -44,9 +51,68 @@ public class OllamaResult { @Override public String toString() { try { - return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this); + Map responseMap = new HashMap<>(); + responseMap.put("response", this.response); + responseMap.put("httpStatusCode", this.httpStatusCode); + responseMap.put("responseTime", this.responseTime); + return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(responseMap); } catch (JsonProcessingException e) { throw new RuntimeException(e); } } + + /** + * Get the structured response if the response is a JSON object. + * + * @return Map - structured response + * @throws IllegalArgumentException if the response is not a valid JSON object + */ + public Map getStructuredResponse() { + String responseStr = this.getResponse(); + if (responseStr == null || responseStr.trim().isEmpty()) { + throw new IllegalArgumentException("Response is empty or null"); + } + + try { + // Check if the response is a valid JSON + if ((!responseStr.trim().startsWith("{") && !responseStr.trim().startsWith("[")) || + (!responseStr.trim().endsWith("}") && !responseStr.trim().endsWith("]"))) { + throw new IllegalArgumentException("Response is not a valid JSON object"); + } + + Map response = getObjectMapper().readValue(responseStr, + new TypeReference>() { + }); + return response; + } catch (JsonProcessingException e) { + throw new IllegalArgumentException("Failed to parse response as JSON: " + e.getMessage(), e); + } + } + + /** + * Get the structured response mapped to a specific class type. + * + * @param The type of class to map the response to + * @param clazz The class to map the response to + * @return An instance of the specified class with the response data + * @throws IllegalArgumentException if the response is not a valid JSON or is empty + * @throws RuntimeException if there is an error mapping the response + */ + public T as(Class clazz) { + String responseStr = this.getResponse(); + if (responseStr == null || responseStr.trim().isEmpty()) { + throw new IllegalArgumentException("Response is empty or null"); + } + + try { + // Check if the response is a valid JSON + if ((!responseStr.trim().startsWith("{") && !responseStr.trim().startsWith("[")) || + (!responseStr.trim().endsWith("}") && !responseStr.trim().endsWith("]"))) { + throw new IllegalArgumentException("Response is not a valid JSON object"); + } + return getObjectMapper().readValue(responseStr, clazz); + } catch (JsonProcessingException e) { + throw new IllegalArgumentException("Failed to parse response as JSON: " + e.getMessage(), e); + } + } } diff --git a/src/main/java/io/github/ollama4j/models/response/OllamaStructuredResult.java b/src/main/java/io/github/ollama4j/models/response/OllamaStructuredResult.java new file mode 100644 index 0000000..9ae3e71 --- /dev/null +++ b/src/main/java/io/github/ollama4j/models/response/OllamaStructuredResult.java @@ -0,0 +1,77 @@ +package io.github.ollama4j.models.response; + +import static io.github.ollama4j.utils.Utils.getObjectMapper; + +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; + +import lombok.Data; +import lombok.Getter; +import lombok.NoArgsConstructor; + +@Getter +@SuppressWarnings("unused") +@Data +@NoArgsConstructor +@JsonIgnoreProperties(ignoreUnknown = true) +public class OllamaStructuredResult { + private String response; + + private int httpStatusCode; + + private long responseTime = 0; + + private String model; + + public OllamaStructuredResult(String response, long responseTime, int httpStatusCode) { + this.response = response; + this.responseTime = responseTime; + this.httpStatusCode = httpStatusCode; + } + + @Override + public String toString() { + try { + return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + /** + * Get the structured response if the response is a JSON object. + * + * @return Map - structured response + */ + public Map getStructuredResponse() { + try { + Map response = getObjectMapper().readValue(this.getResponse(), + new TypeReference>() { + }); + return response; + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + /** + * Get the structured response mapped to a specific class type. + * + * @param The type of class to map the response to + * @param clazz The class to map the response to + * @return An instance of the specified class with the response data + * @throws RuntimeException if there is an error mapping the response + */ + public T getStructuredResponse(Class clazz) { + try { + return getObjectMapper().readValue(this.getResponse(), clazz); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } +} diff --git a/src/test/java/io/github/ollama4j/integrationtests/OllamaAPIIntegrationTest.java b/src/test/java/io/github/ollama4j/integrationtests/OllamaAPIIntegrationTest.java index b9a9f48..7608610 100644 --- a/src/test/java/io/github/ollama4j/integrationtests/OllamaAPIIntegrationTest.java +++ b/src/test/java/io/github/ollama4j/integrationtests/OllamaAPIIntegrationTest.java @@ -35,623 +35,644 @@ import java.util.*; import static io.github.ollama4j.utils.Utils.getObjectMapper; import static org.junit.jupiter.api.Assertions.*; -@OllamaToolService(providers = {AnnotatedTool.class}) +@OllamaToolService(providers = { AnnotatedTool.class }) @TestMethodOrder(OrderAnnotation.class) -@SuppressWarnings({"HttpUrlsUsage", "SpellCheckingInspection"}) +@SuppressWarnings({ "HttpUrlsUsage", "SpellCheckingInspection" }) public class OllamaAPIIntegrationTest { - private static final Logger LOG = LoggerFactory.getLogger(OllamaAPIIntegrationTest.class); + private static final Logger LOG = LoggerFactory.getLogger(OllamaAPIIntegrationTest.class); - private static OllamaContainer ollama; - private static OllamaAPI api; + private static OllamaContainer ollama; + private static OllamaAPI api; - private static final String EMBEDDING_MODEL_MINILM = "all-minilm"; - private static final String CHAT_MODEL_DEFAULT = "qwen2.5:0.5b"; - private static final String CHAT_MODEL_INSTRUCT = "qwen2.5:0.5b-instruct"; - private static final String CHAT_MODEL_SYSTEM_PROMPT = "llama3.2:1b"; - private static final String CHAT_MODEL_LLAMA3 = "llama3"; - private static final String IMAGE_MODEL_LLAVA = "llava"; - private static final String IMAGE_MODEL_MOONDREAM = "moondream"; + private static final String EMBEDDING_MODEL_MINILM = "all-minilm"; + private static final String CHAT_MODEL_QWEN_SMALL = "qwen2.5:0.5b"; + private static final String CHAT_MODEL_INSTRUCT = "qwen2.5:0.5b-instruct"; + private static final String CHAT_MODEL_SYSTEM_PROMPT = "llama3.2:1b"; + private static final String CHAT_MODEL_LLAMA3 = "llama3"; + private static final String IMAGE_MODEL_LLAVA = "llava"; - @BeforeAll - public static void setUp() { - try { - boolean useExternalOllamaHost = Boolean.parseBoolean(System.getenv("USE_EXTERNAL_OLLAMA_HOST")); - String ollamaHost = System.getenv("OLLAMA_HOST"); - if (useExternalOllamaHost) { - api = new OllamaAPI(ollamaHost); - } else { - throw new RuntimeException("USE_EXTERNAL_OLLAMA_HOST is not set so, we will be using Testcontainers Ollama host for the tests now. If you would like to use an external host, please set the env var to USE_EXTERNAL_OLLAMA_HOST=true and set the env var OLLAMA_HOST=http://localhost:11435 or a different host/port."); - } - } catch (Exception e) { - String ollamaVersion = "0.6.1"; - int internalPort = 11434; - int mappedPort = 11435; - ollama = new OllamaContainer("ollama/ollama:" + ollamaVersion); - ollama.addExposedPort(internalPort); - List portBindings = new ArrayList<>(); - portBindings.add(mappedPort + ":" + internalPort); - ollama.setPortBindings(portBindings); - ollama.start(); - api = new OllamaAPI("http://" + ollama.getHost() + ":" + ollama.getMappedPort(internalPort)); + @BeforeAll + public static void setUp() { + try { + boolean useExternalOllamaHost = Boolean.parseBoolean(System.getenv("USE_EXTERNAL_OLLAMA_HOST")); + String ollamaHost = System.getenv("OLLAMA_HOST"); + if (useExternalOllamaHost) { + api = new OllamaAPI(ollamaHost); + } else { + throw new RuntimeException( + "USE_EXTERNAL_OLLAMA_HOST is not set so, we will be using Testcontainers Ollama host for the tests now. If you would like to use an external host, please set the env var to USE_EXTERNAL_OLLAMA_HOST=true and set the env var OLLAMA_HOST=http://localhost:11435 or a different host/port."); + } + } catch (Exception e) { + String ollamaVersion = "0.6.1"; + int internalPort = 11434; + int mappedPort = 11435; + ollama = new OllamaContainer("ollama/ollama:" + ollamaVersion); + ollama.addExposedPort(internalPort); + List portBindings = new ArrayList<>(); + portBindings.add(mappedPort + ":" + internalPort); + ollama.setPortBindings(portBindings); + ollama.start(); + api = new OllamaAPI("http://" + ollama.getHost() + ":" + ollama.getMappedPort(internalPort)); + } + api.setRequestTimeoutSeconds(120); + api.setVerbose(true); + api.setNumberOfRetriesForModelPull(3); } - api.setRequestTimeoutSeconds(120); - api.setVerbose(true); - } - @Test - @Order(1) - void testWrongEndpoint() { - OllamaAPI ollamaAPI = new OllamaAPI("http://wrong-host:11434"); - assertThrows(ConnectException.class, ollamaAPI::listModels); - } - - @Test - @Order(1) - public void testVersionAPI() throws URISyntaxException, IOException, OllamaBaseException, InterruptedException { - // String expectedVersion = ollama.getDockerImageName().split(":")[1]; - String actualVersion = api.getVersion(); - assertNotNull(actualVersion); - // assertEquals(expectedVersion, actualVersion, "Version should match the Docker image version"); - } - - @Test - @Order(2) - public void testListModelsAPI() throws URISyntaxException, IOException, OllamaBaseException, InterruptedException { - api.pullModel(EMBEDDING_MODEL_MINILM); - // Fetch the list of models - List models = api.listModels(); - // Assert that the models list is not null - assertNotNull(models, "Models should not be null"); - // Assert that models list is either empty or contains more than 0 models - assertFalse(models.isEmpty(), "Models list should not be empty"); - } - - @Test - @Order(2) - void testListModelsFromLibrary() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { - List models = api.listModelsFromLibrary(); - assertNotNull(models); - assertFalse(models.isEmpty()); - } - - @Test - @Order(3) - public void testPullModelAPI() throws URISyntaxException, IOException, OllamaBaseException, InterruptedException { - api.pullModel(EMBEDDING_MODEL_MINILM); - List models = api.listModels(); - assertNotNull(models, "Models should not be null"); - assertFalse(models.isEmpty(), "Models list should contain elements"); - } - - @Test - @Order(4) - void testListModelDetails() throws IOException, OllamaBaseException, URISyntaxException, InterruptedException { - api.pullModel(EMBEDDING_MODEL_MINILM); - ModelDetail modelDetails = api.getModelDetails(EMBEDDING_MODEL_MINILM); - assertNotNull(modelDetails); - assertTrue(modelDetails.getModelFile().contains(EMBEDDING_MODEL_MINILM)); - } - - @Test - @Order(5) - public void testEmbeddings() throws Exception { - api.pullModel(EMBEDDING_MODEL_MINILM); - OllamaEmbedResponseModel embeddings = api.embed(EMBEDDING_MODEL_MINILM, - Arrays.asList("Why is the sky blue?", "Why is the grass green?")); - assertNotNull(embeddings, "Embeddings should not be null"); - assertFalse(embeddings.getEmbeddings().isEmpty(), "Embeddings should not be empty"); - } - - @Test - @Order(6) - void testAskModelWithDefaultOptions() - throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { - api.pullModel(CHAT_MODEL_DEFAULT); - OllamaResult result = api.generate(CHAT_MODEL_DEFAULT, - "What is the capital of France? And what's France's connection with Mona Lisa?", false, - new OptionsBuilder().build()); - assertNotNull(result); - assertNotNull(result.getResponse()); - assertFalse(result.getResponse().isEmpty()); - } - - @Test - @Order(6) - void testAskModelWithStructuredOutput() - throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { - api.pullModel(CHAT_MODEL_DEFAULT); - - String prompt = "Ollama is 22 years old and is busy saving the world. Respond using JSON"; - Map format = new HashMap<>(); - format.put("type", "object"); - format.put("properties", new HashMap() { - { - put("age", new HashMap() { - { - put("type", "integer"); - } - }); - put("available", new HashMap() { - { - put("type", "boolean"); - } - }); - } - }); - format.put("required", Arrays.asList("age", "available")); - - OllamaResult result = api.generate(CHAT_MODEL_DEFAULT, prompt, format); - - assertNotNull(result); - assertNotNull(result.getResponse()); - assertFalse(result.getResponse().isEmpty()); - - Map actualResponse = getObjectMapper().readValue(result.getResponse(), new TypeReference<>() { - }); - - int age = 22; - boolean available = true; - String expectedResponseJson = "{\n \"age\": " + age + ",\n \"available\": " + available + "\n}"; - - Map expectedResponse = getObjectMapper().readValue(expectedResponseJson, - new TypeReference>() { - }); - assertEquals(actualResponse.get("age").toString(), expectedResponse.get("age").toString()); - assertEquals(actualResponse.get("available").toString(), expectedResponse.get("available").toString()); - - assertEquals(result.getStructuredResponse().get("age").toString(), - result.getStructuredResponse().get("age").toString()); - assertEquals(result.getStructuredResponse().get("available").toString(), - result.getStructuredResponse().get("available").toString()); - - Person person = result.getStructuredResponse(Person.class); - assertEquals(person.getAge(), age); - assertEquals(person.isAvailable(), available); - } - - @Test - @Order(7) - void testAskModelWithDefaultOptionsStreamed() - throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { - api.pullModel(CHAT_MODEL_DEFAULT); - StringBuffer sb = new StringBuffer(); - OllamaResult result = api.generate(CHAT_MODEL_DEFAULT, - "What is the capital of France? And what's France's connection with Mona Lisa?", false, - new OptionsBuilder().build(), (s) -> { - LOG.info(s); - String substring = s.substring(sb.toString().length(), s.length()); - LOG.info(substring); - sb.append(substring); - }); - - assertNotNull(result); - assertNotNull(result.getResponse()); - assertFalse(result.getResponse().isEmpty()); - assertEquals(sb.toString().trim(), result.getResponse().trim()); - } - - @Test - @Order(8) - void testAskModelWithOptions() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { - api.pullModel(CHAT_MODEL_INSTRUCT); - - OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_INSTRUCT); - OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, - "You are a helpful assistant who can generate random person's first and last names in the format [First name, Last name].") - .build(); - requestModel = builder.withMessages(requestModel.getMessages()) - .withMessage(OllamaChatMessageRole.USER, "Give me a cool name") - .withOptions(new OptionsBuilder().setTemperature(0.5f).build()).build(); - OllamaChatResult chatResult = api.chat(requestModel); - - assertNotNull(chatResult); - assertNotNull(chatResult.getResponseModel()); - assertFalse(chatResult.getResponseModel().getMessage().getContent().isEmpty()); - } - - @Test - @Order(9) - void testChatWithSystemPrompt() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { - api.pullModel(CHAT_MODEL_SYSTEM_PROMPT); - OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT); - OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, - "You are a silent bot that only says 'Shush'. Do not say anything else under any circumstances!") - .withMessage(OllamaChatMessageRole.USER, "What's something that's brown and sticky?") - .withOptions(new OptionsBuilder().setTemperature(0.8f).build()).build(); - - OllamaChatResult chatResult = api.chat(requestModel); - assertNotNull(chatResult); - assertNotNull(chatResult.getResponseModel()); - assertNotNull(chatResult.getResponseModel().getMessage()); - assertFalse(chatResult.getResponseModel().getMessage().getContent().isBlank()); - assertTrue(chatResult.getResponseModel().getMessage().getContent().contains("Shush")); - assertEquals(3, chatResult.getChatHistory().size()); - } - - @Test - @Order(10) - public void testChat() throws Exception { - api.pullModel(CHAT_MODEL_LLAMA3); - OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_LLAMA3); - - // Create the initial user question - OllamaChatRequest requestModel = builder - .withMessage(OllamaChatMessageRole.USER, "What is 1+1? Answer only in numbers.").build(); - - // Start conversation with model - OllamaChatResult chatResult = api.chat(requestModel); - - assertTrue(chatResult.getChatHistory().stream().anyMatch(chat -> chat.getContent().contains("2")), - "Expected chat history to contain '2'"); - - // Create the next user question: second largest city - requestModel = builder.withMessages(chatResult.getChatHistory()) - .withMessage(OllamaChatMessageRole.USER, "And what is its squared value?").build(); - - // Continue conversation with model - chatResult = api.chat(requestModel); - - assertTrue(chatResult.getChatHistory().stream().anyMatch(chat -> chat.getContent().contains("4")), - "Expected chat history to contain '4'"); - - // Create the next user question: the third question - requestModel = builder.withMessages(chatResult.getChatHistory()) - .withMessage(OllamaChatMessageRole.USER, "What is the largest value between 2, 4 and 6?").build(); - - // Continue conversation with the model for the third question - chatResult = api.chat(requestModel); - - // verify the result - assertNotNull(chatResult, "Chat result should not be null"); - assertTrue(chatResult.getChatHistory().size() > 2, "Chat history should contain more than two messages"); - assertTrue(chatResult.getChatHistory().get(chatResult.getChatHistory().size() - 1).getContent().contains("6"), - "Response should contain '6'"); - } - - @Test - @Order(10) - void testChatWithImageFromURL() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { - api.pullModel(IMAGE_MODEL_LLAVA); - - OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(IMAGE_MODEL_LLAVA); - OllamaChatRequest requestModel = builder - .withMessage(OllamaChatMessageRole.USER, "What's in the picture?", Collections.emptyList(), - "https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg") - .build(); - api.registerAnnotatedTools(new OllamaAPIIntegrationTest()); - - OllamaChatResult chatResult = api.chat(requestModel); - assertNotNull(chatResult); - } - - @Test - @Order(10) - void testChatWithImageFromFileWithHistoryRecognition() - throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { - api.pullModel(IMAGE_MODEL_MOONDREAM); - OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(IMAGE_MODEL_MOONDREAM); - OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?", - Collections.emptyList(), List.of(getImageFileFromClasspath("dog-on-a-boat.jpg"))).build(); - - OllamaChatResult chatResult = api.chat(requestModel); - assertNotNull(chatResult); - assertNotNull(chatResult.getResponseModel()); - builder.reset(); - - requestModel = builder.withMessages(chatResult.getChatHistory()) - .withMessage(OllamaChatMessageRole.USER, "What's the dogs breed?").build(); - - chatResult = api.chat(requestModel); - assertNotNull(chatResult); - assertNotNull(chatResult.getResponseModel()); - } - - @Test - @Order(11) - void testChatWithExplicitToolDefinition() - throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { - api.pullModel(CHAT_MODEL_SYSTEM_PROMPT); - OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT); - - final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder() - .functionName("get-employee-details").functionDescription("Get employee details from the database") - .toolPrompt(Tools.PromptFuncDefinition.builder().type("function") - .function(Tools.PromptFuncDefinition.PromptFuncSpec.builder().name("get-employee-details") - .description("Get employee details from the database") - .parameters(Tools.PromptFuncDefinition.Parameters.builder().type("object") - .properties(new Tools.PropsBuilder() - .withProperty("employee-name", - Tools.PromptFuncDefinition.Property.builder().type("string") - .description("The name of the employee, e.g. John Doe") - .required(true).build()) - .withProperty("employee-address", Tools.PromptFuncDefinition.Property - .builder().type("string") - .description( - "The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India") - .required(true).build()) - .withProperty("employee-phone", Tools.PromptFuncDefinition.Property - .builder().type("string") - .description( - "The phone number of the employee. Always return a random value. e.g. 9911002233") - .required(true).build()) - .build()) - .required(List.of("employee-name")).build()) - .build()) - .build()) - .toolFunction(arguments -> { - // perform DB operations here - return String.format("Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", - UUID.randomUUID(), arguments.get("employee-name"), arguments.get("employee-address"), - arguments.get("employee-phone")); - }).build(); - - api.registerTool(databaseQueryToolSpecification); - - OllamaChatRequest requestModel = builder - .withMessage(OllamaChatMessageRole.USER, "Give me the ID of the employee named 'Rahul Kumar'?").build(); - - OllamaChatResult chatResult = api.chat(requestModel); - assertNotNull(chatResult); - assertNotNull(chatResult.getResponseModel()); - assertNotNull(chatResult.getResponseModel().getMessage()); - assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(), - chatResult.getResponseModel().getMessage().getRole().getRoleName()); - List toolCalls = chatResult.getChatHistory().get(1).getToolCalls(); - assertEquals(1, toolCalls.size()); - OllamaToolCallsFunction function = toolCalls.get(0).getFunction(); - assertEquals("get-employee-details", function.getName()); - assert !function.getArguments().isEmpty(); - Object employeeName = function.getArguments().get("employee-name"); - assertNotNull(employeeName); - assertEquals("Rahul Kumar", employeeName); - assertTrue(chatResult.getChatHistory().size() > 2); - List finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls(); - assertNull(finalToolCalls); - } - - @Test - @Order(12) - void testChatWithAnnotatedToolsAndSingleParam() - throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { - api.pullModel(CHAT_MODEL_SYSTEM_PROMPT); - OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT); - - api.registerAnnotatedTools(); - - OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, - "Compute the most important constant in the world using 5 digits").build(); - - OllamaChatResult chatResult = api.chat(requestModel); - assertNotNull(chatResult); - assertNotNull(chatResult.getResponseModel()); - assertNotNull(chatResult.getResponseModel().getMessage()); - assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(), - chatResult.getResponseModel().getMessage().getRole().getRoleName()); - List toolCalls = chatResult.getChatHistory().get(1).getToolCalls(); - assertEquals(1, toolCalls.size()); - OllamaToolCallsFunction function = toolCalls.get(0).getFunction(); - assertEquals("computeImportantConstant", function.getName()); - assertEquals(1, function.getArguments().size()); - Object noOfDigits = function.getArguments().get("noOfDigits"); - assertNotNull(noOfDigits); - assertEquals("5", noOfDigits.toString()); - assertTrue(chatResult.getChatHistory().size() > 2); - List finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls(); - assertNull(finalToolCalls); - } - - @Test - @Order(13) - void testChatWithAnnotatedToolsAndMultipleParams() - throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { - api.pullModel(CHAT_MODEL_SYSTEM_PROMPT); - OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT); - - api.registerAnnotatedTools(new AnnotatedTool()); - - OllamaChatRequest requestModel = builder - .withMessage(OllamaChatMessageRole.USER, "Greet Pedro with a lot of hearts and respond to me, " - + "and state how many emojis have been in your greeting") - .build(); - - OllamaChatResult chatResult = api.chat(requestModel); - assertNotNull(chatResult); - assertNotNull(chatResult.getResponseModel()); - assertNotNull(chatResult.getResponseModel().getMessage()); - assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(), - chatResult.getResponseModel().getMessage().getRole().getRoleName()); - List toolCalls = chatResult.getChatHistory().get(1).getToolCalls(); - assertEquals(1, toolCalls.size()); - OllamaToolCallsFunction function = toolCalls.get(0).getFunction(); - assertEquals("sayHello", function.getName()); - assertEquals(2, function.getArguments().size()); - Object name = function.getArguments().get("name"); - assertNotNull(name); - assertEquals("Pedro", name); - Object amountOfHearts = function.getArguments().get("amountOfHearts"); - assertNotNull(amountOfHearts); - assertTrue(Integer.parseInt(amountOfHearts.toString()) > 1); - assertTrue(chatResult.getChatHistory().size() > 2); - List finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls(); - assertNull(finalToolCalls); - } - - @Test - @Order(14) - void testChatWithToolsAndStream() - throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { - api.pullModel(CHAT_MODEL_SYSTEM_PROMPT); - OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT); - final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder() - .functionName("get-employee-details").functionDescription("Get employee details from the database") - .toolPrompt(Tools.PromptFuncDefinition.builder().type("function") - .function(Tools.PromptFuncDefinition.PromptFuncSpec.builder().name("get-employee-details") - .description("Get employee details from the database") - .parameters(Tools.PromptFuncDefinition.Parameters.builder().type("object") - .properties(new Tools.PropsBuilder() - .withProperty("employee-name", - Tools.PromptFuncDefinition.Property.builder().type("string") - .description("The name of the employee, e.g. John Doe") - .required(true).build()) - .withProperty("employee-address", Tools.PromptFuncDefinition.Property - .builder().type("string") - .description( - "The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India") - .required(true).build()) - .withProperty("employee-phone", Tools.PromptFuncDefinition.Property - .builder().type("string") - .description( - "The phone number of the employee. Always return a random value. e.g. 9911002233") - .required(true).build()) - .build()) - .required(List.of("employee-name")).build()) - .build()) - .build()) - .toolFunction(new ToolFunction() { - @Override - public Object apply(Map arguments) { - // perform DB operations here - return String.format("Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", - UUID.randomUUID(), arguments.get("employee-name"), arguments.get("employee-address"), - arguments.get("employee-phone")); - } - }).build(); - - api.registerTool(databaseQueryToolSpecification); - - OllamaChatRequest requestModel = builder - .withMessage(OllamaChatMessageRole.USER, "Give me the ID of the employee named 'Rahul Kumar'?").build(); - - StringBuffer sb = new StringBuffer(); - - OllamaChatResult chatResult = api.chat(requestModel, (s) -> { - LOG.info(s); - String substring = s.substring(sb.toString().length()); - LOG.info(substring); - sb.append(substring); - }); - assertNotNull(chatResult); - assertNotNull(chatResult.getResponseModel()); - assertNotNull(chatResult.getResponseModel().getMessage()); - assertNotNull(chatResult.getResponseModel().getMessage().getContent()); - assertEquals(sb.toString().trim(), chatResult.getResponseModel().getMessage().getContent().trim()); - } - - @Test - @Order(15) - void testChatWithStream() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { - api.pullModel(CHAT_MODEL_SYSTEM_PROMPT); - OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT); - OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, - "What is the capital of France? And what's France's connection with Mona Lisa?").build(); - - StringBuffer sb = new StringBuffer(); - - OllamaChatResult chatResult = api.chat(requestModel, (s) -> { - LOG.info(s); - String substring = s.substring(sb.toString().length(), s.length()); - LOG.info(substring); - sb.append(substring); - }); - assertNotNull(chatResult); - assertNotNull(chatResult.getResponseModel()); - assertNotNull(chatResult.getResponseModel().getMessage()); - assertNotNull(chatResult.getResponseModel().getMessage().getContent()); - assertEquals(sb.toString().trim(), chatResult.getResponseModel().getMessage().getContent().trim()); - } - - @Test - @Order(17) - void testAskModelWithOptionsAndImageURLs() - throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { - api.pullModel(IMAGE_MODEL_LLAVA); - - OllamaResult result = api.generateWithImageURLs(IMAGE_MODEL_LLAVA, "What is in this image?", - List.of("https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg"), - new OptionsBuilder().build()); - assertNotNull(result); - assertNotNull(result.getResponse()); - assertFalse(result.getResponse().isEmpty()); - } - - @Test - @Order(18) - void testAskModelWithOptionsAndImageFiles() - throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { - api.pullModel(IMAGE_MODEL_LLAVA); - File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg"); - try { - OllamaResult result = api.generateWithImageFiles(IMAGE_MODEL_LLAVA, "What is in this image?", List.of(imageFile), - new OptionsBuilder().build()); - assertNotNull(result); - assertNotNull(result.getResponse()); - assertFalse(result.getResponse().isEmpty()); - } catch (IOException | OllamaBaseException | InterruptedException e) { - fail(e); + @Test + @Order(1) + void testWrongEndpoint() { + OllamaAPI ollamaAPI = new OllamaAPI("http://wrong-host:11434"); + assertThrows(ConnectException.class, ollamaAPI::listModels); } - } - @Test - @Order(20) - void testAskModelWithOptionsAndImageFilesStreamed() - throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { - api.pullModel(IMAGE_MODEL_LLAVA); + @Test + @Order(1) + public void testVersionAPI() throws URISyntaxException, IOException, OllamaBaseException, InterruptedException { + // String expectedVersion = ollama.getDockerImageName().split(":")[1]; + String actualVersion = api.getVersion(); + assertNotNull(actualVersion); + // assertEquals(expectedVersion, actualVersion, "Version should match the Docker + // image version"); + } - File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg"); + @Test + @Order(2) + public void testListModelsAPI() + throws URISyntaxException, IOException, OllamaBaseException, InterruptedException { + api.pullModel(EMBEDDING_MODEL_MINILM); + // Fetch the list of models + List models = api.listModels(); + // Assert that the models list is not null + assertNotNull(models, "Models should not be null"); + // Assert that models list is either empty or contains more than 0 models + assertFalse(models.isEmpty(), "Models list should not be empty"); + } - StringBuffer sb = new StringBuffer(); + @Test + @Order(2) + void testListModelsFromLibrary() + throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + List models = api.listModelsFromLibrary(); + assertNotNull(models); + assertFalse(models.isEmpty()); + } - OllamaResult result = api.generateWithImageFiles(IMAGE_MODEL_LLAVA, "What is in this image?", List.of(imageFile), - new OptionsBuilder().build(), (s) -> { - LOG.info(s); - String substring = s.substring(sb.toString().length(), s.length()); - LOG.info(substring); - sb.append(substring); + @Test + @Order(3) + public void testPullModelAPI() + throws URISyntaxException, IOException, OllamaBaseException, InterruptedException { + api.pullModel(EMBEDDING_MODEL_MINILM); + List models = api.listModels(); + assertNotNull(models, "Models should not be null"); + assertFalse(models.isEmpty(), "Models list should contain elements"); + } + + @Test + @Order(4) + void testListModelDetails() throws IOException, OllamaBaseException, URISyntaxException, InterruptedException { + api.pullModel(EMBEDDING_MODEL_MINILM); + ModelDetail modelDetails = api.getModelDetails(EMBEDDING_MODEL_MINILM); + assertNotNull(modelDetails); + assertTrue(modelDetails.getModelFile().contains(EMBEDDING_MODEL_MINILM)); + } + + @Test + @Order(5) + public void testEmbeddings() throws Exception { + api.pullModel(EMBEDDING_MODEL_MINILM); + OllamaEmbedResponseModel embeddings = api.embed(EMBEDDING_MODEL_MINILM, + Arrays.asList("Why is the sky blue?", "Why is the grass green?")); + assertNotNull(embeddings, "Embeddings should not be null"); + assertFalse(embeddings.getEmbeddings().isEmpty(), "Embeddings should not be empty"); + } + + @Test + @Order(6) + void testAskModelWithStructuredOutput() + throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { + api.pullModel(CHAT_MODEL_QWEN_SMALL); + + int age = 28; + boolean available = false; + + String prompt = "Batman is " + age + " years old and is " + (available ? "available" : "not available") + + " because he is busy saving Gotham City. Respond using JSON"; + + Map format = new HashMap<>(); + format.put("type", "object"); + format.put("properties", new HashMap() { + { + put("age", new HashMap() { + { + put("type", "integer"); + } + }); + put("available", new HashMap() { + { + put("type", "boolean"); + } + }); + } }); - assertNotNull(result); - assertNotNull(result.getResponse()); - assertFalse(result.getResponse().isEmpty()); - assertEquals(sb.toString().trim(), result.getResponse().trim()); - } + format.put("required", Arrays.asList("age", "available")); - private File getImageFileFromClasspath(String fileName) { - ClassLoader classLoader = getClass().getClassLoader(); - return new File(Objects.requireNonNull(classLoader.getResource(fileName)).getFile()); - } + OllamaResult result = api.generate(CHAT_MODEL_QWEN_SMALL, prompt, format); + + assertNotNull(result); + assertNotNull(result.getResponse()); + assertFalse(result.getResponse().isEmpty()); + + System.out.println(result); + + assertEquals(result.getStructuredResponse().get("age").toString(), + result.getStructuredResponse().get("age").toString()); + assertEquals(result.getStructuredResponse().get("available").toString(), + result.getStructuredResponse().get("available").toString()); + + Person person = result.as(Person.class); + assertEquals(person.getAge(), age); + assertEquals(person.isAvailable(), available); + } + + @Test + @Order(6) + void testAskModelWithDefaultOptions() + throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { + api.pullModel(CHAT_MODEL_QWEN_SMALL); + OllamaResult result = api.generate(CHAT_MODEL_QWEN_SMALL, + "What is the capital of France? And what's France's connection with Mona Lisa?", false, + new OptionsBuilder().build()); + assertNotNull(result); + assertNotNull(result.getResponse()); + assertFalse(result.getResponse().isEmpty()); + } + + @Test + @Order(7) + void testAskModelWithDefaultOptionsStreamed() + throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + api.pullModel(CHAT_MODEL_QWEN_SMALL); + StringBuffer sb = new StringBuffer(); + OllamaResult result = api.generate(CHAT_MODEL_QWEN_SMALL, + "What is the capital of France? And what's France's connection with Mona Lisa?", false, + new OptionsBuilder().build(), (s) -> { + LOG.info(s); + String substring = s.substring(sb.toString().length(), s.length()); + LOG.info(substring); + sb.append(substring); + }); + + assertNotNull(result); + assertNotNull(result.getResponse()); + assertFalse(result.getResponse().isEmpty()); + assertEquals(sb.toString().trim(), result.getResponse().trim()); + } + + @Test + @Order(8) + void testAskModelWithOptions() + throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + api.pullModel(CHAT_MODEL_INSTRUCT); + + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_INSTRUCT); + OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, + "You are a helpful assistant who can generate random person's first and last names in the format [First name, Last name].") + .build(); + requestModel = builder.withMessages(requestModel.getMessages()) + .withMessage(OllamaChatMessageRole.USER, "Give me a cool name") + .withOptions(new OptionsBuilder().setTemperature(0.5f).build()).build(); + OllamaChatResult chatResult = api.chat(requestModel); + + assertNotNull(chatResult); + assertNotNull(chatResult.getResponseModel()); + assertFalse(chatResult.getResponseModel().getMessage().getContent().isEmpty()); + } + + @Test + @Order(9) + void testChatWithSystemPrompt() + throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + api.pullModel(CHAT_MODEL_SYSTEM_PROMPT); + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT); + OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, + "You are a silent bot that only says 'Shush'. Do not say anything else under any circumstances!") + .withMessage(OllamaChatMessageRole.USER, "What's something that's brown and sticky?") + .withOptions(new OptionsBuilder().setTemperature(0.8f).build()).build(); + + OllamaChatResult chatResult = api.chat(requestModel); + assertNotNull(chatResult); + assertNotNull(chatResult.getResponseModel()); + assertNotNull(chatResult.getResponseModel().getMessage()); + assertFalse(chatResult.getResponseModel().getMessage().getContent().isBlank()); + assertTrue(chatResult.getResponseModel().getMessage().getContent().contains("Shush")); + assertEquals(3, chatResult.getChatHistory().size()); + } + + @Test + @Order(10) + public void testChat() throws Exception { + api.pullModel(CHAT_MODEL_LLAMA3); + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_LLAMA3); + + // Create the initial user question + OllamaChatRequest requestModel = builder + .withMessage(OllamaChatMessageRole.USER, "What is 1+1? Answer only in numbers.") + .build(); + + // Start conversation with model + OllamaChatResult chatResult = api.chat(requestModel); + + assertTrue(chatResult.getChatHistory().stream().anyMatch(chat -> chat.getContent().contains("2")), + "Expected chat history to contain '2'"); + + // Create the next user question: second largest city + requestModel = builder.withMessages(chatResult.getChatHistory()) + .withMessage(OllamaChatMessageRole.USER, "And what is its squared value?").build(); + + // Continue conversation with model + chatResult = api.chat(requestModel); + + assertTrue(chatResult.getChatHistory().stream().anyMatch(chat -> chat.getContent().contains("4")), + "Expected chat history to contain '4'"); + + // Create the next user question: the third question + requestModel = builder.withMessages(chatResult.getChatHistory()) + .withMessage(OllamaChatMessageRole.USER, + "What is the largest value between 2, 4 and 6?") + .build(); + + // Continue conversation with the model for the third question + chatResult = api.chat(requestModel); + + // verify the result + assertNotNull(chatResult, "Chat result should not be null"); + assertTrue(chatResult.getChatHistory().size() > 2, + "Chat history should contain more than two messages"); + assertTrue(chatResult.getChatHistory().get(chatResult.getChatHistory().size() - 1).getContent() + .contains("6"), + "Response should contain '6'"); + } + + @Test + @Order(10) + void testChatWithImageFromURL() + throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { + api.pullModel(IMAGE_MODEL_LLAVA); + + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(IMAGE_MODEL_LLAVA); + OllamaChatRequest requestModel = builder + .withMessage(OllamaChatMessageRole.USER, "What's in the picture?", + Collections.emptyList(), + "https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg") + .build(); + api.registerAnnotatedTools(new OllamaAPIIntegrationTest()); + + OllamaChatResult chatResult = api.chat(requestModel); + assertNotNull(chatResult); + } + + @Test + @Order(10) + void testChatWithImageFromFileWithHistoryRecognition() + throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + api.pullModel(IMAGE_MODEL_LLAVA); + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(IMAGE_MODEL_LLAVA); + OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, + "What's in the picture?", + Collections.emptyList(), List.of(getImageFileFromClasspath("dog-on-a-boat.jpg"))) + .build(); + + OllamaChatResult chatResult = api.chat(requestModel); + assertNotNull(chatResult); + assertNotNull(chatResult.getResponseModel()); + builder.reset(); + + requestModel = builder.withMessages(chatResult.getChatHistory()) + .withMessage(OllamaChatMessageRole.USER, "What's the dogs breed?").build(); + + chatResult = api.chat(requestModel); + assertNotNull(chatResult); + assertNotNull(chatResult.getResponseModel()); + } + + @Test + @Order(11) + void testChatWithExplicitToolDefinition() + throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + api.pullModel(CHAT_MODEL_SYSTEM_PROMPT); + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT); + + final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder() + .functionName("get-employee-details") + .functionDescription("Get employee details from the database") + .toolPrompt(Tools.PromptFuncDefinition.builder().type("function") + .function(Tools.PromptFuncDefinition.PromptFuncSpec.builder() + .name("get-employee-details") + .description("Get employee details from the database") + .parameters(Tools.PromptFuncDefinition.Parameters + .builder().type("object") + .properties(new Tools.PropsBuilder() + .withProperty("employee-name", + Tools.PromptFuncDefinition.Property + .builder() + .type("string") + .description("The name of the employee, e.g. John Doe") + .required(true) + .build()) + .withProperty("employee-address", + Tools.PromptFuncDefinition.Property + .builder() + .type("string") + .description( + "The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India") + .required(true) + .build()) + .withProperty("employee-phone", + Tools.PromptFuncDefinition.Property + .builder() + .type("string") + .description( + "The phone number of the employee. Always return a random value. e.g. 9911002233") + .required(true) + .build()) + .build()) + .required(List.of("employee-name")) + .build()) + .build()) + .build()) + .toolFunction(arguments -> { + // perform DB operations here + return String.format( + "Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", + UUID.randomUUID(), arguments.get("employee-name"), + arguments.get("employee-address"), + arguments.get("employee-phone")); + }).build(); + + api.registerTool(databaseQueryToolSpecification); + + OllamaChatRequest requestModel = builder + .withMessage(OllamaChatMessageRole.USER, + "Give me the ID of the employee named 'Rahul Kumar'?") + .build(); + + OllamaChatResult chatResult = api.chat(requestModel); + assertNotNull(chatResult); + assertNotNull(chatResult.getResponseModel()); + assertNotNull(chatResult.getResponseModel().getMessage()); + assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(), + chatResult.getResponseModel().getMessage().getRole().getRoleName()); + List toolCalls = chatResult.getChatHistory().get(1).getToolCalls(); + assertEquals(1, toolCalls.size()); + OllamaToolCallsFunction function = toolCalls.get(0).getFunction(); + assertEquals("get-employee-details", function.getName()); + assert !function.getArguments().isEmpty(); + Object employeeName = function.getArguments().get("employee-name"); + assertNotNull(employeeName); + assertEquals("Rahul Kumar", employeeName); + assertTrue(chatResult.getChatHistory().size() > 2); + List finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls(); + assertNull(finalToolCalls); + } + + @Test + @Order(12) + void testChatWithAnnotatedToolsAndSingleParam() + throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { + api.pullModel(CHAT_MODEL_SYSTEM_PROMPT); + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT); + + api.registerAnnotatedTools(); + + OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, + "Compute the most important constant in the world using 5 digits").build(); + + OllamaChatResult chatResult = api.chat(requestModel); + assertNotNull(chatResult); + assertNotNull(chatResult.getResponseModel()); + assertNotNull(chatResult.getResponseModel().getMessage()); + assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(), + chatResult.getResponseModel().getMessage().getRole().getRoleName()); + List toolCalls = chatResult.getChatHistory().get(1).getToolCalls(); + assertEquals(1, toolCalls.size()); + OllamaToolCallsFunction function = toolCalls.get(0).getFunction(); + assertEquals("computeImportantConstant", function.getName()); + assertEquals(1, function.getArguments().size()); + Object noOfDigits = function.getArguments().get("noOfDigits"); + assertNotNull(noOfDigits); + assertEquals("5", noOfDigits.toString()); + assertTrue(chatResult.getChatHistory().size() > 2); + List finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls(); + assertNull(finalToolCalls); + } + + @Test + @Order(13) + void testChatWithAnnotatedToolsAndMultipleParams() + throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + api.pullModel(CHAT_MODEL_SYSTEM_PROMPT); + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT); + + api.registerAnnotatedTools(new AnnotatedTool()); + + OllamaChatRequest requestModel = builder + .withMessage(OllamaChatMessageRole.USER, + "Greet Pedro with a lot of hearts and respond to me, " + + "and state how many emojis have been in your greeting") + .build(); + + OllamaChatResult chatResult = api.chat(requestModel); + assertNotNull(chatResult); + assertNotNull(chatResult.getResponseModel()); + assertNotNull(chatResult.getResponseModel().getMessage()); + assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(), + chatResult.getResponseModel().getMessage().getRole().getRoleName()); + List toolCalls = chatResult.getChatHistory().get(1).getToolCalls(); + assertEquals(1, toolCalls.size()); + OllamaToolCallsFunction function = toolCalls.get(0).getFunction(); + assertEquals("sayHello", function.getName()); + assertEquals(2, function.getArguments().size()); + Object name = function.getArguments().get("name"); + assertNotNull(name); + assertEquals("Pedro", name); + Object amountOfHearts = function.getArguments().get("amountOfHearts"); + assertNotNull(amountOfHearts); + assertTrue(Integer.parseInt(amountOfHearts.toString()) > 1); + assertTrue(chatResult.getChatHistory().size() > 2); + List finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls(); + assertNull(finalToolCalls); + } + + @Test + @Order(14) + void testChatWithToolsAndStream() + throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + api.pullModel(CHAT_MODEL_SYSTEM_PROMPT); + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT); + final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder() + .functionName("get-employee-details") + .functionDescription("Get employee details from the database") + .toolPrompt(Tools.PromptFuncDefinition.builder().type("function") + .function(Tools.PromptFuncDefinition.PromptFuncSpec.builder() + .name("get-employee-details") + .description("Get employee details from the database") + .parameters(Tools.PromptFuncDefinition.Parameters + .builder().type("object") + .properties(new Tools.PropsBuilder() + .withProperty("employee-name", + Tools.PromptFuncDefinition.Property + .builder() + .type("string") + .description("The name of the employee, e.g. John Doe") + .required(true) + .build()) + .withProperty("employee-address", + Tools.PromptFuncDefinition.Property + .builder() + .type("string") + .description( + "The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India") + .required(true) + .build()) + .withProperty("employee-phone", + Tools.PromptFuncDefinition.Property + .builder() + .type("string") + .description( + "The phone number of the employee. Always return a random value. e.g. 9911002233") + .required(true) + .build()) + .build()) + .required(List.of("employee-name")) + .build()) + .build()) + .build()) + .toolFunction(new ToolFunction() { + @Override + public Object apply(Map arguments) { + // perform DB operations here + return String.format( + "Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", + UUID.randomUUID(), arguments.get("employee-name"), + arguments.get("employee-address"), + arguments.get("employee-phone")); + } + }).build(); + + api.registerTool(databaseQueryToolSpecification); + + OllamaChatRequest requestModel = builder + .withMessage(OllamaChatMessageRole.USER, + "Give me the ID of the employee named 'Rahul Kumar'?") + .build(); + + StringBuffer sb = new StringBuffer(); + + OllamaChatResult chatResult = api.chat(requestModel, (s) -> { + LOG.info(s); + String substring = s.substring(sb.toString().length()); + LOG.info(substring); + sb.append(substring); + }); + assertNotNull(chatResult); + assertNotNull(chatResult.getResponseModel()); + assertNotNull(chatResult.getResponseModel().getMessage()); + assertNotNull(chatResult.getResponseModel().getMessage().getContent()); + assertEquals(sb.toString().trim(), chatResult.getResponseModel().getMessage().getContent().trim()); + } + + @Test + @Order(15) + void testChatWithStream() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + api.pullModel(CHAT_MODEL_SYSTEM_PROMPT); + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT); + OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, + "What is the capital of France? And what's France's connection with Mona Lisa?") + .build(); + + StringBuffer sb = new StringBuffer(); + + OllamaChatResult chatResult = api.chat(requestModel, (s) -> { + LOG.info(s); + String substring = s.substring(sb.toString().length(), s.length()); + LOG.info(substring); + sb.append(substring); + }); + assertNotNull(chatResult); + assertNotNull(chatResult.getResponseModel()); + assertNotNull(chatResult.getResponseModel().getMessage()); + assertNotNull(chatResult.getResponseModel().getMessage().getContent()); + assertEquals(sb.toString().trim(), chatResult.getResponseModel().getMessage().getContent().trim()); + } + + @Test + @Order(17) + void testAskModelWithOptionsAndImageURLs() + throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + api.pullModel(IMAGE_MODEL_LLAVA); + + OllamaResult result = api.generateWithImageURLs(IMAGE_MODEL_LLAVA, "What is in this image?", + List.of("https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg"), + new OptionsBuilder().build()); + assertNotNull(result); + assertNotNull(result.getResponse()); + assertFalse(result.getResponse().isEmpty()); + } + + @Test + @Order(18) + void testAskModelWithOptionsAndImageFiles() + throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + api.pullModel(IMAGE_MODEL_LLAVA); + File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg"); + try { + OllamaResult result = api.generateWithImageFiles(IMAGE_MODEL_LLAVA, "What is in this image?", + List.of(imageFile), + new OptionsBuilder().build()); + assertNotNull(result); + assertNotNull(result.getResponse()); + assertFalse(result.getResponse().isEmpty()); + } catch (IOException | OllamaBaseException | InterruptedException e) { + fail(e); + } + } + + @Test + @Order(20) + void testAskModelWithOptionsAndImageFilesStreamed() + throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + api.pullModel(IMAGE_MODEL_LLAVA); + + File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg"); + + StringBuffer sb = new StringBuffer(); + + OllamaResult result = api.generateWithImageFiles(IMAGE_MODEL_LLAVA, "What is in this image?", + List.of(imageFile), + new OptionsBuilder().build(), (s) -> { + LOG.info(s); + String substring = s.substring(sb.toString().length(), s.length()); + LOG.info(substring); + sb.append(substring); + }); + assertNotNull(result); + assertNotNull(result.getResponse()); + assertFalse(result.getResponse().isEmpty()); + assertEquals(sb.toString().trim(), result.getResponse().trim()); + } + + private File getImageFileFromClasspath(String fileName) { + ClassLoader classLoader = getClass().getClassLoader(); + return new File(Objects.requireNonNull(classLoader.getResource(fileName)).getFile()); + } } @Data @AllArgsConstructor @NoArgsConstructor class Person { - private int age; - private boolean available; + private int age; + private boolean available; } - -// -// @Data -// class Config { -// private String ollamaURL; -// private String model; -// private String imageModel; -// private int requestTimeoutSeconds; -// -// public Config() { -// Properties properties = new Properties(); -// try (InputStream input = -// getClass().getClassLoader().getResourceAsStream("test-config.properties")) { -// if (input == null) { -// throw new RuntimeException("Sorry, unable to find test-config.properties"); -// } -// properties.load(input); -// this.ollamaURL = properties.getProperty("ollama.url"); -// this.model = properties.getProperty("ollama.model"); -// this.imageModel = properties.getProperty("ollama.model.image"); -// this.requestTimeoutSeconds = -// Integer.parseInt(properties.getProperty("ollama.request-timeout-seconds")); -// } catch (IOException e) { -// throw new RuntimeException("Error loading properties", e); -// } -// } -// }