diff --git a/Makefile b/Makefile index 34972dd..1350698 100644 --- a/Makefile +++ b/Makefile @@ -7,6 +7,9 @@ dev: pre-commit install --install-hooks build: + mvn -B clean install -Dgpg.skip=true + +full-build: mvn -B clean install unit-tests: diff --git a/docs/docs/apis-generate/generate.md b/docs/docs/apis-generate/generate.md index 1cd6a47..c3a2d05 100644 --- a/docs/docs/apis-generate/generate.md +++ b/docs/docs/apis-generate/generate.md @@ -13,7 +13,7 @@ with [extra parameters](https://github.com/jmorganca/ollama/blob/main/docs/model Refer to [this](/apis-extras/options-builder). -## Try asking a question about the model. +## Try asking a question about the model ```java import io.github.ollama4j.OllamaAPI; @@ -87,7 +87,7 @@ You will get a response similar to: > The capital of France is Paris. > Full response: The capital of France is Paris. -## Try asking a question from general topics. +## Try asking a question from general topics ```java import io.github.ollama4j.OllamaAPI; @@ -135,7 +135,7 @@ You'd then get a response from the model: > semi-finals. The tournament was > won by the England cricket team, who defeated New Zealand in the final. -## Try asking for a Database query for your data schema. +## Try asking for a Database query for your data schema ```java import io.github.ollama4j.OllamaAPI; @@ -161,6 +161,7 @@ public class Main { ``` + _Note: Here I've used a [sample prompt](https://github.com/ollama4j/ollama4j/blob/main/src/main/resources/sample-db-prompt-template.txt) containing a database schema from within this library for demonstration purposes._ @@ -172,4 +173,123 @@ SELECT customers.name FROM sales JOIN customers ON sales.customer_id = customers.customer_id GROUP BY customers.name; +``` + + +## Generate structured output + +### With response as a `Map` + +```java +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + +import io.github.ollama4j.OllamaAPI; +import io.github.ollama4j.utils.Utilities; +import io.github.ollama4j.models.chat.OllamaChatMessageRole; +import io.github.ollama4j.models.chat.OllamaChatRequest; +import io.github.ollama4j.models.chat.OllamaChatRequestBuilder; +import io.github.ollama4j.models.chat.OllamaChatResult; +import io.github.ollama4j.models.response.OllamaResult; +import io.github.ollama4j.types.OllamaModelType; + +public class StructuredOutput { + + public static void main(String[] args) throws Exception { + String host = "http://localhost:11434/"; + + OllamaAPI api = new OllamaAPI(host); + + String chatModel = "qwen2.5:0.5b"; + api.pullModel(chatModel); + + String prompt = "Ollama is 22 years old and is busy saving the world. Respond using JSON"; + Map format = new HashMap<>(); + format.put("type", "object"); + format.put("properties", new HashMap() { + { + put("age", new HashMap() { + { + put("type", "integer"); + } + }); + put("available", new HashMap() { + { + put("type", "boolean"); + } + }); + } + }); + format.put("required", Arrays.asList("age", "available")); + + OllamaResult result = api.generate(chatModel, prompt, format); + System.out.println(result); + } +} +``` + +### With response mapped to specified class type + +```java +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + +import io.github.ollama4j.OllamaAPI; +import io.github.ollama4j.utils.Utilities; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; +import io.github.ollama4j.models.chat.OllamaChatMessageRole; +import io.github.ollama4j.models.chat.OllamaChatRequest; +import io.github.ollama4j.models.chat.OllamaChatRequestBuilder; +import io.github.ollama4j.models.chat.OllamaChatResult; +import io.github.ollama4j.models.response.OllamaResult; +import io.github.ollama4j.types.OllamaModelType; + +public class StructuredOutput { + + public static void main(String[] args) throws Exception { + String host = Utilities.getFromConfig("host"); + + OllamaAPI api = new OllamaAPI(host); + + String chatModel = "llama3.1:8b"; + chatModel = "qwen2.5:0.5b"; + api.pullModel(chatModel); + + String prompt = "Ollama is 22 years old and is busy saving the world. Respond using JSON"; + Map format = new HashMap<>(); + format.put("type", "object"); + format.put("properties", new HashMap() { + { + put("age", new HashMap() { + { + put("type", "integer"); + } + }); + put("available", new HashMap() { + { + put("type", "boolean"); + } + }); + } + }); + format.put("required", Arrays.asList("age", "available")); + + OllamaResult result = api.generate(chatModel, prompt, format); + Person person = result.getStructuredResponse(Person.class); + System.out.println(person); + } + +} + +@Data +@AllArgsConstructor +@NoArgsConstructor +class Person { + private int age; + private boolean available; +} ``` \ No newline at end of file diff --git a/src/main/java/io/github/ollama4j/OllamaAPI.java b/src/main/java/io/github/ollama4j/OllamaAPI.java index 5229fe9..955d4dd 100644 --- a/src/main/java/io/github/ollama4j/OllamaAPI.java +++ b/src/main/java/io/github/ollama4j/OllamaAPI.java @@ -743,12 +743,12 @@ public class OllamaAPI { * the response. * @param prompt The input text or prompt to provide to the AI model. * @param format A map containing the format specification for the structured output. - * @return An instance of {@link OllamaStructuredResult} containing the structured response. + * @return An instance of {@link OllamaResult} containing the structured response. * @throws OllamaBaseException if the response indicates an error status. * @throws IOException if an I/O error occurs during the HTTP request. * @throws InterruptedException if the operation is interrupted. */ - public OllamaStructuredResult generate(String model, String prompt, Map format) + public OllamaResult generate(String model, String prompt, Map format) throws OllamaBaseException, IOException, InterruptedException { URI uri = URI.create(this.host + "/api/generate"); @@ -771,7 +771,7 @@ public class OllamaAPI { String responseBody = response.body(); if (statusCode == 200) { - return Utils.getObjectMapper().readValue(responseBody, OllamaStructuredResult.class); + return Utils.getObjectMapper().readValue(responseBody, OllamaResult.class); } else { throw new OllamaBaseException(statusCode + " - " + responseBody); } diff --git a/src/main/java/io/github/ollama4j/models/response/OllamaResult.java b/src/main/java/io/github/ollama4j/models/response/OllamaResult.java index beb01ec..d64ba32 100644 --- a/src/main/java/io/github/ollama4j/models/response/OllamaResult.java +++ b/src/main/java/io/github/ollama4j/models/response/OllamaResult.java @@ -2,26 +2,36 @@ package io.github.ollama4j.models.response; import static io.github.ollama4j.utils.Utils.getObjectMapper; +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; + import lombok.Data; import lombok.Getter; +import lombok.NoArgsConstructor; /** The type Ollama result. */ @Getter @SuppressWarnings("unused") @Data +@NoArgsConstructor +@JsonIgnoreProperties(ignoreUnknown = true) public class OllamaResult { /** * -- GETTER -- - * Get the completion/response text + * Get the completion/response text * * @return String completion/response text */ - private final String response; + private String response; /** * -- GETTER -- - * Get the response status code. + * Get the response status code. * * @return int - response status code */ @@ -29,12 +39,25 @@ public class OllamaResult { /** * -- GETTER -- - * Get the response time in milliseconds. + * Get the response time in milliseconds. * * @return long - response time in milliseconds */ private long responseTime = 0; + /** + * -- GETTER -- + * Get the model name used for the response. + * + * @return String - model name + */ + private String model; + + @JsonCreator + public OllamaResult(@JsonProperty("response") String response) { + this.response = response; + } + public OllamaResult(String response, long responseTime, int httpStatusCode) { this.response = response; this.responseTime = responseTime; @@ -49,4 +72,36 @@ public class OllamaResult { throw new RuntimeException(e); } } + + /** + * Get the structured response if the response is a JSON object. + * + * @return Map - structured response + */ + public Map getStructuredResponse() { + try { + Map response = getObjectMapper().readValue(this.getResponse(), + new TypeReference>() { + }); + return response; + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + /** + * Get the structured response mapped to a specific class type. + * + * @param The type of class to map the response to + * @param clazz The class to map the response to + * @return An instance of the specified class with the response data + * @throws RuntimeException if there is an error mapping the response + */ + public T getStructuredResponse(Class clazz) { + try { + return getObjectMapper().readValue(this.getResponse(), clazz); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } } diff --git a/src/main/java/io/github/ollama4j/models/response/OllamaStructuredResult.java b/src/main/java/io/github/ollama4j/models/response/OllamaStructuredResult.java deleted file mode 100644 index 42b8e4e..0000000 --- a/src/main/java/io/github/ollama4j/models/response/OllamaStructuredResult.java +++ /dev/null @@ -1,22 +0,0 @@ -package io.github.ollama4j.models.response; - -import com.fasterxml.jackson.annotation.JsonIgnoreProperties; -import com.fasterxml.jackson.annotation.JsonProperty; -import lombok.Data; - -/** - * Structured response for Ollama API - */ -@Data -@JsonIgnoreProperties(ignoreUnknown = true) -public class OllamaStructuredResult { - - @JsonProperty("response") - private String response; - - @JsonProperty("httpStatusCode") - private int httpStatusCode; - - @JsonProperty("responseTime") - private long responseTime; -} diff --git a/src/test/java/io/github/ollama4j/integrationtests/OllamaAPIIntegrationTest.java b/src/test/java/io/github/ollama4j/integrationtests/OllamaAPIIntegrationTest.java index 1d4c864..1d2078b 100644 --- a/src/test/java/io/github/ollama4j/integrationtests/OllamaAPIIntegrationTest.java +++ b/src/test/java/io/github/ollama4j/integrationtests/OllamaAPIIntegrationTest.java @@ -11,6 +11,10 @@ import io.github.ollama4j.tools.ToolFunction; import io.github.ollama4j.tools.Tools; import io.github.ollama4j.tools.annotations.OllamaToolService; import io.github.ollama4j.utils.OptionsBuilder; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.MethodOrderer.OrderAnnotation; import org.junit.jupiter.api.Order; @@ -31,7 +35,7 @@ import java.util.*; import static io.github.ollama4j.utils.Utils.getObjectMapper; import static org.junit.jupiter.api.Assertions.*; -@OllamaToolService(providers = {AnnotatedTool.class}) +@OllamaToolService(providers = { AnnotatedTool.class }) @TestMethodOrder(OrderAnnotation.class) @SuppressWarnings("HttpUrlsUsage") @@ -116,17 +120,21 @@ public class OllamaAPIIntegrationTest { public void testEmbeddings() throws Exception { String embeddingModelMinilm = "all-minilm"; api.pullModel(embeddingModelMinilm); - OllamaEmbedResponseModel embeddings = api.embed(embeddingModelMinilm, Arrays.asList("Why is the sky blue?", "Why is the grass green?")); + OllamaEmbedResponseModel embeddings = api.embed(embeddingModelMinilm, + Arrays.asList("Why is the sky blue?", "Why is the grass green?")); assertNotNull(embeddings, "Embeddings should not be null"); assertFalse(embeddings.getEmbeddings().isEmpty(), "Embeddings should not be empty"); } @Test @Order(6) - void testAskModelWithDefaultOptions() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { + void testAskModelWithDefaultOptions() + throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { String chatModel = "qwen2.5:0.5b"; api.pullModel(chatModel); - OllamaResult result = api.generate(chatModel, "What is the capital of France? And what's France's connection with Mona Lisa?", false, new OptionsBuilder().build()); + OllamaResult result = api.generate(chatModel, + "What is the capital of France? And what's France's connection with Mona Lisa?", false, + new OptionsBuilder().build()); assertNotNull(result); assertNotNull(result.getResponse()); assertFalse(result.getResponse().isEmpty()); @@ -134,7 +142,8 @@ public class OllamaAPIIntegrationTest { @Test @Order(6) - void testAskModelWithStructuredOutput() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { + void testAskModelWithStructuredOutput() + throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { String chatModel = "llama3.1:8b"; chatModel = "qwen2.5:0.5b"; api.pullModel(chatModel); @@ -142,17 +151,23 @@ public class OllamaAPIIntegrationTest { String prompt = "Ollama is 22 years old and is busy saving the world. Respond using JSON"; Map format = new HashMap<>(); format.put("type", "object"); - format.put("properties", new HashMap() {{ - put("age", new HashMap() {{ - put("type", "integer"); - }}); - put("available", new HashMap() {{ - put("type", "boolean"); - }}); - }}); + format.put("properties", new HashMap() { + { + put("age", new HashMap() { + { + put("type", "integer"); + } + }); + put("available", new HashMap() { + { + put("type", "boolean"); + } + }); + } + }); format.put("required", Arrays.asList("age", "available")); - OllamaStructuredResult result = api.generate(chatModel, prompt, format); + OllamaResult result = api.generate(chatModel, prompt, format); assertNotNull(result); assertNotNull(result.getResponse()); @@ -161,25 +176,41 @@ public class OllamaAPIIntegrationTest { Map actualResponse = getObjectMapper().readValue(result.getResponse(), new TypeReference<>() { }); - String expectedResponseJson = "{\n \"age\": 22,\n \"available\": true\n}"; - Map expectedResponse = getObjectMapper().readValue(expectedResponseJson, new TypeReference>() { - }); + int age = 22; + boolean available = true; + String expectedResponseJson = "{\n \"age\": " + age + ",\n \"available\": " + available + "\n}"; + + Map expectedResponse = getObjectMapper().readValue(expectedResponseJson, + new TypeReference>() { + }); assertEquals(actualResponse.get("age").toString(), expectedResponse.get("age").toString()); assertEquals(actualResponse.get("available").toString(), expectedResponse.get("available").toString()); + + assertEquals(result.getStructuredResponse().get("age").toString(), + result.getStructuredResponse().get("age").toString()); + assertEquals(result.getStructuredResponse().get("available").toString(), + result.getStructuredResponse().get("available").toString()); + + Person person = result.getStructuredResponse(Person.class); + assertEquals(person.getAge(), age); + assertEquals(person.isAvailable(), available); } @Test @Order(7) - void testAskModelWithDefaultOptionsStreamed() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + void testAskModelWithDefaultOptionsStreamed() + throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { String chatModel = "qwen2.5:0.5b"; api.pullModel(chatModel); StringBuffer sb = new StringBuffer(); - OllamaResult result = api.generate(chatModel, "What is the capital of France? And what's France's connection with Mona Lisa?", false, new OptionsBuilder().build(), (s) -> { - LOG.info(s); - String substring = s.substring(sb.toString().length(), s.length()); - LOG.info(substring); - sb.append(substring); - }); + OllamaResult result = api.generate(chatModel, + "What is the capital of France? And what's France's connection with Mona Lisa?", false, + new OptionsBuilder().build(), (s) -> { + LOG.info(s); + String substring = s.substring(sb.toString().length(), s.length()); + LOG.info(substring); + sb.append(substring); + }); assertNotNull(result); assertNotNull(result.getResponse()); @@ -194,8 +225,12 @@ public class OllamaAPIIntegrationTest { api.pullModel(chatModel); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(chatModel); - OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, "You are a helpful assistant who can generate random person's first and last names in the format [First name, Last name].").build(); - requestModel = builder.withMessages(requestModel.getMessages()).withMessage(OllamaChatMessageRole.USER, "Give me a cool name").withOptions(new OptionsBuilder().setTemperature(0.5f).build()).build(); + OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, + "You are a helpful assistant who can generate random person's first and last names in the format [First name, Last name].") + .build(); + requestModel = builder.withMessages(requestModel.getMessages()) + .withMessage(OllamaChatMessageRole.USER, "Give me a cool name") + .withOptions(new OptionsBuilder().setTemperature(0.5f).build()).build(); OllamaChatResult chatResult = api.chat(requestModel); assertNotNull(chatResult); @@ -209,7 +244,10 @@ public class OllamaAPIIntegrationTest { String chatModel = "llama3.2:1b"; api.pullModel(chatModel); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(chatModel); - OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, "You are a silent bot that only says 'Shush'. Do not say anything else under any circumstances!").withMessage(OllamaChatMessageRole.USER, "What's something that's brown and sticky?").withOptions(new OptionsBuilder().setTemperature(0.8f).build()).build(); + OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, + "You are a silent bot that only says 'Shush'. Do not say anything else under any circumstances!") + .withMessage(OllamaChatMessageRole.USER, "What's something that's brown and sticky?") + .withOptions(new OptionsBuilder().setTemperature(0.8f).build()).build(); OllamaChatResult chatResult = api.chat(requestModel); assertNotNull(chatResult); @@ -228,23 +266,28 @@ public class OllamaAPIIntegrationTest { OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(chatModel); // Create the initial user question - OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What is 1+1? Answer only in numbers.").build(); + OllamaChatRequest requestModel = builder + .withMessage(OllamaChatMessageRole.USER, "What is 1+1? Answer only in numbers.").build(); // Start conversation with model OllamaChatResult chatResult = api.chat(requestModel); - assertTrue(chatResult.getChatHistory().stream().anyMatch(chat -> chat.getContent().contains("2")), "Expected chat history to contain '2'"); + assertTrue(chatResult.getChatHistory().stream().anyMatch(chat -> chat.getContent().contains("2")), + "Expected chat history to contain '2'"); // Create the next user question: second largest city - requestModel = builder.withMessages(chatResult.getChatHistory()).withMessage(OllamaChatMessageRole.USER, "And what is its squared value?").build(); + requestModel = builder.withMessages(chatResult.getChatHistory()) + .withMessage(OllamaChatMessageRole.USER, "And what is its squared value?").build(); // Continue conversation with model chatResult = api.chat(requestModel); - assertTrue(chatResult.getChatHistory().stream().anyMatch(chat -> chat.getContent().contains("4")), "Expected chat history to contain '4'"); + assertTrue(chatResult.getChatHistory().stream().anyMatch(chat -> chat.getContent().contains("4")), + "Expected chat history to contain '4'"); // Create the next user question: the third question - requestModel = builder.withMessages(chatResult.getChatHistory()).withMessage(OllamaChatMessageRole.USER, "What is the largest value between 2, 4 and 6?").build(); + requestModel = builder.withMessages(chatResult.getChatHistory()) + .withMessage(OllamaChatMessageRole.USER, "What is the largest value between 2, 4 and 6?").build(); // Continue conversation with the model for the third question chatResult = api.chat(requestModel); @@ -252,7 +295,8 @@ public class OllamaAPIIntegrationTest { // verify the result assertNotNull(chatResult, "Chat result should not be null"); assertTrue(chatResult.getChatHistory().size() > 2, "Chat history should contain more than two messages"); - assertTrue(chatResult.getChatHistory().get(chatResult.getChatHistory().size() - 1).getContent().contains("6"), "Response should contain '6'"); + assertTrue(chatResult.getChatHistory().get(chatResult.getChatHistory().size() - 1).getContent().contains("6"), + "Response should contain '6'"); } @Test @@ -262,7 +306,10 @@ public class OllamaAPIIntegrationTest { api.pullModel(imageModel); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(imageModel); - OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?", Collections.emptyList(), "https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg").build(); + OllamaChatRequest requestModel = builder + .withMessage(OllamaChatMessageRole.USER, "What's in the picture?", Collections.emptyList(), + "https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg") + .build(); api.registerAnnotatedTools(new OllamaAPIIntegrationTest()); OllamaChatResult chatResult = api.chat(requestModel); @@ -271,18 +318,21 @@ public class OllamaAPIIntegrationTest { @Test @Order(10) - void testChatWithImageFromFileWithHistoryRecognition() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + void testChatWithImageFromFileWithHistoryRecognition() + throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { String imageModel = "moondream"; api.pullModel(imageModel); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(imageModel); - OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?", Collections.emptyList(), List.of(getImageFileFromClasspath("dog-on-a-boat.jpg"))).build(); + OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?", + Collections.emptyList(), List.of(getImageFileFromClasspath("dog-on-a-boat.jpg"))).build(); OllamaChatResult chatResult = api.chat(requestModel); assertNotNull(chatResult); assertNotNull(chatResult.getResponseModel()); builder.reset(); - requestModel = builder.withMessages(chatResult.getChatHistory()).withMessage(OllamaChatMessageRole.USER, "What's the dogs breed?").build(); + requestModel = builder.withMessages(chatResult.getChatHistory()) + .withMessage(OllamaChatMessageRole.USER, "What's the dogs breed?").build(); chatResult = api.chat(requestModel); assertNotNull(chatResult); @@ -291,25 +341,55 @@ public class OllamaAPIIntegrationTest { @Test @Order(11) - void testChatWithExplicitToolDefinition() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + void testChatWithExplicitToolDefinition() + throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { String chatModel = "llama3.2:1b"; api.pullModel(chatModel); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(chatModel); - final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder().functionName("get-employee-details").functionDescription("Get employee details from the database").toolPrompt(Tools.PromptFuncDefinition.builder().type("function").function(Tools.PromptFuncDefinition.PromptFuncSpec.builder().name("get-employee-details").description("Get employee details from the database").parameters(Tools.PromptFuncDefinition.Parameters.builder().type("object").properties(new Tools.PropsBuilder().withProperty("employee-name", Tools.PromptFuncDefinition.Property.builder().type("string").description("The name of the employee, e.g. John Doe").required(true).build()).withProperty("employee-address", Tools.PromptFuncDefinition.Property.builder().type("string").description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India").required(true).build()).withProperty("employee-phone", Tools.PromptFuncDefinition.Property.builder().type("string").description("The phone number of the employee. Always return a random value. e.g. 9911002233").required(true).build()).build()).required(List.of("employee-name")).build()).build()).build()).toolFunction(arguments -> { - // perform DB operations here - return String.format("Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", UUID.randomUUID(), arguments.get("employee-name"), arguments.get("employee-address"), arguments.get("employee-phone")); - }).build(); + final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder() + .functionName("get-employee-details").functionDescription("Get employee details from the database") + .toolPrompt(Tools.PromptFuncDefinition.builder().type("function") + .function(Tools.PromptFuncDefinition.PromptFuncSpec.builder().name("get-employee-details") + .description("Get employee details from the database") + .parameters(Tools.PromptFuncDefinition.Parameters.builder().type("object") + .properties(new Tools.PropsBuilder() + .withProperty("employee-name", + Tools.PromptFuncDefinition.Property.builder().type("string") + .description("The name of the employee, e.g. John Doe") + .required(true).build()) + .withProperty("employee-address", Tools.PromptFuncDefinition.Property + .builder().type("string") + .description( + "The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India") + .required(true).build()) + .withProperty("employee-phone", Tools.PromptFuncDefinition.Property + .builder().type("string") + .description( + "The phone number of the employee. Always return a random value. e.g. 9911002233") + .required(true).build()) + .build()) + .required(List.of("employee-name")).build()) + .build()) + .build()) + .toolFunction(arguments -> { + // perform DB operations here + return String.format("Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", + UUID.randomUUID(), arguments.get("employee-name"), arguments.get("employee-address"), + arguments.get("employee-phone")); + }).build(); api.registerTool(databaseQueryToolSpecification); - OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "Give me the ID of the employee named 'Rahul Kumar'?").build(); + OllamaChatRequest requestModel = builder + .withMessage(OllamaChatMessageRole.USER, "Give me the ID of the employee named 'Rahul Kumar'?").build(); OllamaChatResult chatResult = api.chat(requestModel); assertNotNull(chatResult); assertNotNull(chatResult.getResponseModel()); assertNotNull(chatResult.getResponseModel().getMessage()); - assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(), chatResult.getResponseModel().getMessage().getRole().getRoleName()); + assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(), + chatResult.getResponseModel().getMessage().getRole().getRoleName()); List toolCalls = chatResult.getChatHistory().get(1).getToolCalls(); assertEquals(1, toolCalls.size()); OllamaToolCallsFunction function = toolCalls.get(0).getFunction(); @@ -325,7 +405,8 @@ public class OllamaAPIIntegrationTest { @Test @Order(12) - void testChatWithAnnotatedToolsAndSingleParam() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { + void testChatWithAnnotatedToolsAndSingleParam() + throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { String chatModel = "llama3.2:1b"; api.pullModel(chatModel); @@ -333,13 +414,15 @@ public class OllamaAPIIntegrationTest { api.registerAnnotatedTools(); - OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "Compute the most important constant in the world using 5 digits").build(); + OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, + "Compute the most important constant in the world using 5 digits").build(); OllamaChatResult chatResult = api.chat(requestModel); assertNotNull(chatResult); assertNotNull(chatResult.getResponseModel()); assertNotNull(chatResult.getResponseModel().getMessage()); - assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(), chatResult.getResponseModel().getMessage().getRole().getRoleName()); + assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(), + chatResult.getResponseModel().getMessage().getRole().getRoleName()); List toolCalls = chatResult.getChatHistory().get(1).getToolCalls(); assertEquals(1, toolCalls.size()); OllamaToolCallsFunction function = toolCalls.get(0).getFunction(); @@ -355,20 +438,25 @@ public class OllamaAPIIntegrationTest { @Test @Order(13) - void testChatWithAnnotatedToolsAndMultipleParams() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + void testChatWithAnnotatedToolsAndMultipleParams() + throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { String chatModel = "llama3.2:1b"; api.pullModel(chatModel); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(chatModel); api.registerAnnotatedTools(new AnnotatedTool()); - OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "Greet Pedro with a lot of hearts and respond to me, " + "and state how many emojis have been in your greeting").build(); + OllamaChatRequest requestModel = builder + .withMessage(OllamaChatMessageRole.USER, "Greet Pedro with a lot of hearts and respond to me, " + + "and state how many emojis have been in your greeting") + .build(); OllamaChatResult chatResult = api.chat(requestModel); assertNotNull(chatResult); assertNotNull(chatResult.getResponseModel()); assertNotNull(chatResult.getResponseModel().getMessage()); - assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(), chatResult.getResponseModel().getMessage().getRole().getRoleName()); + assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(), + chatResult.getResponseModel().getMessage().getRole().getRoleName()); List toolCalls = chatResult.getChatHistory().get(1).getToolCalls(); assertEquals(1, toolCalls.size()); OllamaToolCallsFunction function = toolCalls.get(0).getFunction(); @@ -387,21 +475,50 @@ public class OllamaAPIIntegrationTest { @Test @Order(14) - void testChatWithToolsAndStream() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + void testChatWithToolsAndStream() + throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { String chatModel = "llama3.2:1b"; api.pullModel(chatModel); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(chatModel); - final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder().functionName("get-employee-details").functionDescription("Get employee details from the database").toolPrompt(Tools.PromptFuncDefinition.builder().type("function").function(Tools.PromptFuncDefinition.PromptFuncSpec.builder().name("get-employee-details").description("Get employee details from the database").parameters(Tools.PromptFuncDefinition.Parameters.builder().type("object").properties(new Tools.PropsBuilder().withProperty("employee-name", Tools.PromptFuncDefinition.Property.builder().type("string").description("The name of the employee, e.g. John Doe").required(true).build()).withProperty("employee-address", Tools.PromptFuncDefinition.Property.builder().type("string").description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India").required(true).build()).withProperty("employee-phone", Tools.PromptFuncDefinition.Property.builder().type("string").description("The phone number of the employee. Always return a random value. e.g. 9911002233").required(true).build()).build()).required(List.of("employee-name")).build()).build()).build()).toolFunction(new ToolFunction() { - @Override - public Object apply(Map arguments) { - // perform DB operations here - return String.format("Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", UUID.randomUUID(), arguments.get("employee-name"), arguments.get("employee-address"), arguments.get("employee-phone")); - } - }).build(); + final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder() + .functionName("get-employee-details").functionDescription("Get employee details from the database") + .toolPrompt(Tools.PromptFuncDefinition.builder().type("function") + .function(Tools.PromptFuncDefinition.PromptFuncSpec.builder().name("get-employee-details") + .description("Get employee details from the database") + .parameters(Tools.PromptFuncDefinition.Parameters.builder().type("object") + .properties(new Tools.PropsBuilder() + .withProperty("employee-name", + Tools.PromptFuncDefinition.Property.builder().type("string") + .description("The name of the employee, e.g. John Doe") + .required(true).build()) + .withProperty("employee-address", Tools.PromptFuncDefinition.Property + .builder().type("string") + .description( + "The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India") + .required(true).build()) + .withProperty("employee-phone", Tools.PromptFuncDefinition.Property + .builder().type("string") + .description( + "The phone number of the employee. Always return a random value. e.g. 9911002233") + .required(true).build()) + .build()) + .required(List.of("employee-name")).build()) + .build()) + .build()) + .toolFunction(new ToolFunction() { + @Override + public Object apply(Map arguments) { + // perform DB operations here + return String.format("Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", + UUID.randomUUID(), arguments.get("employee-name"), arguments.get("employee-address"), + arguments.get("employee-phone")); + } + }).build(); api.registerTool(databaseQueryToolSpecification); - OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "Give me the ID of the employee named 'Rahul Kumar'?").build(); + OllamaChatRequest requestModel = builder + .withMessage(OllamaChatMessageRole.USER, "Give me the ID of the employee named 'Rahul Kumar'?").build(); StringBuffer sb = new StringBuffer(); @@ -424,7 +541,8 @@ public class OllamaAPIIntegrationTest { String chatModel = "llama3.2:1b"; api.pullModel(chatModel); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(chatModel); - OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What is the capital of France? And what's France's connection with Mona Lisa?").build(); + OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, + "What is the capital of France? And what's France's connection with Mona Lisa?").build(); StringBuffer sb = new StringBuffer(); @@ -441,14 +559,16 @@ public class OllamaAPIIntegrationTest { assertEquals(sb.toString().trim(), chatResult.getResponseModel().getMessage().getContent().trim()); } - @Test @Order(17) - void testAskModelWithOptionsAndImageURLs() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + void testAskModelWithOptionsAndImageURLs() + throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { String imageModel = "llava"; api.pullModel(imageModel); - OllamaResult result = api.generateWithImageURLs(imageModel, "What is in this image?", List.of("https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg"), new OptionsBuilder().build()); + OllamaResult result = api.generateWithImageURLs(imageModel, "What is in this image?", + List.of("https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg"), + new OptionsBuilder().build()); assertNotNull(result); assertNotNull(result.getResponse()); assertFalse(result.getResponse().isEmpty()); @@ -456,12 +576,14 @@ public class OllamaAPIIntegrationTest { @Test @Order(18) - void testAskModelWithOptionsAndImageFiles() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + void testAskModelWithOptionsAndImageFiles() + throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { String imageModel = "llava"; api.pullModel(imageModel); File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg"); try { - OllamaResult result = api.generateWithImageFiles(imageModel, "What is in this image?", List.of(imageFile), new OptionsBuilder().build()); + OllamaResult result = api.generateWithImageFiles(imageModel, "What is in this image?", List.of(imageFile), + new OptionsBuilder().build()); assertNotNull(result); assertNotNull(result.getResponse()); assertFalse(result.getResponse().isEmpty()); @@ -470,10 +592,10 @@ public class OllamaAPIIntegrationTest { } } - @Test @Order(20) - void testAskModelWithOptionsAndImageFilesStreamed() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + void testAskModelWithOptionsAndImageFilesStreamed() + throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { String imageModel = "llava"; api.pullModel(imageModel); @@ -481,12 +603,13 @@ public class OllamaAPIIntegrationTest { StringBuffer sb = new StringBuffer(); - OllamaResult result = api.generateWithImageFiles(imageModel, "What is in this image?", List.of(imageFile), new OptionsBuilder().build(), (s) -> { - LOG.info(s); - String substring = s.substring(sb.toString().length(), s.length()); - LOG.info(substring); - sb.append(substring); - }); + OllamaResult result = api.generateWithImageFiles(imageModel, "What is in this image?", List.of(imageFile), + new OptionsBuilder().build(), (s) -> { + LOG.info(s); + String substring = s.substring(sb.toString().length(), s.length()); + LOG.info(substring); + sb.append(substring); + }); assertNotNull(result); assertNotNull(result.getResponse()); assertFalse(result.getResponse().isEmpty()); @@ -498,29 +621,38 @@ public class OllamaAPIIntegrationTest { return new File(Objects.requireNonNull(classLoader.getResource(fileName)).getFile()); } } + +@Data +@AllArgsConstructor +@NoArgsConstructor +class Person { + private int age; + private boolean available; +} + // -//@Data -//class Config { -// private String ollamaURL; -// private String model; -// private String imageModel; -// private int requestTimeoutSeconds; +// @Data +// class Config { +// private String ollamaURL; +// private String model; +// private String imageModel; +// private int requestTimeoutSeconds; // -// public Config() { -// Properties properties = new Properties(); -// try (InputStream input = -// getClass().getClassLoader().getResourceAsStream("test-config.properties")) { -// if (input == null) { -// throw new RuntimeException("Sorry, unable to find test-config.properties"); -// } -// properties.load(input); -// this.ollamaURL = properties.getProperty("ollama.url"); -// this.model = properties.getProperty("ollama.model"); -// this.imageModel = properties.getProperty("ollama.model.image"); -// this.requestTimeoutSeconds = -// Integer.parseInt(properties.getProperty("ollama.request-timeout-seconds")); -// } catch (IOException e) { -// throw new RuntimeException("Error loading properties", e); -// } -// } -//} +// public Config() { +// Properties properties = new Properties(); +// try (InputStream input = +// getClass().getClassLoader().getResourceAsStream("test-config.properties")) { +// if (input == null) { +// throw new RuntimeException("Sorry, unable to find test-config.properties"); +// } +// properties.load(input); +// this.ollamaURL = properties.getProperty("ollama.url"); +// this.model = properties.getProperty("ollama.model"); +// this.imageModel = properties.getProperty("ollama.model.image"); +// this.requestTimeoutSeconds = +// Integer.parseInt(properties.getProperty("ollama.request-timeout-seconds")); +// } catch (IOException e) { +// throw new RuntimeException("Error loading properties", e); +// } +// } +// }