Enhance OllamaAPI and documentation for structured responses

- Updated OllamaAPI to return an instance of OllamaResult instead of OllamaStructuredResult for structured responses.
- Removed the obsolete OllamaStructuredResult class.
- Added new methods in OllamaResult for retrieving structured responses as a Map or mapped to a specific class type.
- Updated integration tests to validate the new structured response functionality.
- Improved Makefile with a new full-build target for building the project.
This commit is contained in:
Amith Koujalgi 2025-03-24 15:30:00 +05:30
parent 407b7eb280
commit 2d7902167b
No known key found for this signature in database
GPG Key ID: 3F065E7150B71F9D
6 changed files with 417 additions and 129 deletions

View File

@ -7,6 +7,9 @@ dev:
pre-commit install --install-hooks pre-commit install --install-hooks
build: build:
mvn -B clean install -Dgpg.skip=true
full-build:
mvn -B clean install mvn -B clean install
unit-tests: unit-tests:

View File

@ -13,7 +13,7 @@ with [extra parameters](https://github.com/jmorganca/ollama/blob/main/docs/model
Refer Refer
to [this](/apis-extras/options-builder). to [this](/apis-extras/options-builder).
## Try asking a question about the model. ## Try asking a question about the model
```java ```java
import io.github.ollama4j.OllamaAPI; import io.github.ollama4j.OllamaAPI;
@ -87,7 +87,7 @@ You will get a response similar to:
> The capital of France is Paris. > The capital of France is Paris.
> Full response: The capital of France is Paris. > Full response: The capital of France is Paris.
## Try asking a question from general topics. ## Try asking a question from general topics
```java ```java
import io.github.ollama4j.OllamaAPI; import io.github.ollama4j.OllamaAPI;
@ -135,7 +135,7 @@ You'd then get a response from the model:
> semi-finals. The tournament was > semi-finals. The tournament was
> won by the England cricket team, who defeated New Zealand in the final. > won by the England cricket team, who defeated New Zealand in the final.
## Try asking for a Database query for your data schema. ## Try asking for a Database query for your data schema
```java ```java
import io.github.ollama4j.OllamaAPI; import io.github.ollama4j.OllamaAPI;
@ -161,6 +161,7 @@ public class Main {
``` ```
_Note: Here I've used _Note: Here I've used
a [sample prompt](https://github.com/ollama4j/ollama4j/blob/main/src/main/resources/sample-db-prompt-template.txt) a [sample prompt](https://github.com/ollama4j/ollama4j/blob/main/src/main/resources/sample-db-prompt-template.txt)
containing a database schema from within this library for demonstration purposes._ containing a database schema from within this library for demonstration purposes._
@ -172,4 +173,123 @@ SELECT customers.name
FROM sales FROM sales
JOIN customers ON sales.customer_id = customers.customer_id JOIN customers ON sales.customer_id = customers.customer_id
GROUP BY customers.name; GROUP BY customers.name;
```
## Generate structured output
### With response as a `Map`
```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import io.github.ollama4j.OllamaAPI;
import io.github.ollama4j.utils.Utilities;
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
import io.github.ollama4j.models.chat.OllamaChatRequest;
import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
import io.github.ollama4j.models.chat.OllamaChatResult;
import io.github.ollama4j.models.response.OllamaResult;
import io.github.ollama4j.types.OllamaModelType;
public class StructuredOutput {
public static void main(String[] args) throws Exception {
String host = "http://localhost:11434/";
OllamaAPI api = new OllamaAPI(host);
String chatModel = "qwen2.5:0.5b";
api.pullModel(chatModel);
String prompt = "Ollama is 22 years old and is busy saving the world. Respond using JSON";
Map<String, Object> format = new HashMap<>();
format.put("type", "object");
format.put("properties", new HashMap<String, Object>() {
{
put("age", new HashMap<String, Object>() {
{
put("type", "integer");
}
});
put("available", new HashMap<String, Object>() {
{
put("type", "boolean");
}
});
}
});
format.put("required", Arrays.asList("age", "available"));
OllamaResult result = api.generate(chatModel, prompt, format);
System.out.println(result);
}
}
```
### With response mapped to specified class type
```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import io.github.ollama4j.OllamaAPI;
import io.github.ollama4j.utils.Utilities;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
import io.github.ollama4j.models.chat.OllamaChatRequest;
import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
import io.github.ollama4j.models.chat.OllamaChatResult;
import io.github.ollama4j.models.response.OllamaResult;
import io.github.ollama4j.types.OllamaModelType;
public class StructuredOutput {
public static void main(String[] args) throws Exception {
String host = Utilities.getFromConfig("host");
OllamaAPI api = new OllamaAPI(host);
String chatModel = "llama3.1:8b";
chatModel = "qwen2.5:0.5b";
api.pullModel(chatModel);
String prompt = "Ollama is 22 years old and is busy saving the world. Respond using JSON";
Map<String, Object> format = new HashMap<>();
format.put("type", "object");
format.put("properties", new HashMap<String, Object>() {
{
put("age", new HashMap<String, Object>() {
{
put("type", "integer");
}
});
put("available", new HashMap<String, Object>() {
{
put("type", "boolean");
}
});
}
});
format.put("required", Arrays.asList("age", "available"));
OllamaResult result = api.generate(chatModel, prompt, format);
Person person = result.getStructuredResponse(Person.class);
System.out.println(person);
}
}
@Data
@AllArgsConstructor
@NoArgsConstructor
class Person {
private int age;
private boolean available;
}
``` ```

View File

@ -743,12 +743,12 @@ public class OllamaAPI {
* the response. * the response.
* @param prompt The input text or prompt to provide to the AI model. * @param prompt The input text or prompt to provide to the AI model.
* @param format A map containing the format specification for the structured output. * @param format A map containing the format specification for the structured output.
* @return An instance of {@link OllamaStructuredResult} containing the structured response. * @return An instance of {@link OllamaResult} containing the structured response.
* @throws OllamaBaseException if the response indicates an error status. * @throws OllamaBaseException if the response indicates an error status.
* @throws IOException if an I/O error occurs during the HTTP request. * @throws IOException if an I/O error occurs during the HTTP request.
* @throws InterruptedException if the operation is interrupted. * @throws InterruptedException if the operation is interrupted.
*/ */
public OllamaStructuredResult generate(String model, String prompt, Map<String, Object> format) public OllamaResult generate(String model, String prompt, Map<String, Object> format)
throws OllamaBaseException, IOException, InterruptedException { throws OllamaBaseException, IOException, InterruptedException {
URI uri = URI.create(this.host + "/api/generate"); URI uri = URI.create(this.host + "/api/generate");
@ -771,7 +771,7 @@ public class OllamaAPI {
String responseBody = response.body(); String responseBody = response.body();
if (statusCode == 200) { if (statusCode == 200) {
return Utils.getObjectMapper().readValue(responseBody, OllamaStructuredResult.class); return Utils.getObjectMapper().readValue(responseBody, OllamaResult.class);
} else { } else {
throw new OllamaBaseException(statusCode + " - " + responseBody); throw new OllamaBaseException(statusCode + " - " + responseBody);
} }

View File

@ -2,26 +2,36 @@ package io.github.ollama4j.models.response;
import static io.github.ollama4j.utils.Utils.getObjectMapper; import static io.github.ollama4j.utils.Utils.getObjectMapper;
import java.util.Map;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import lombok.Data; import lombok.Data;
import lombok.Getter; import lombok.Getter;
import lombok.NoArgsConstructor;
/** The type Ollama result. */ /** The type Ollama result. */
@Getter @Getter
@SuppressWarnings("unused") @SuppressWarnings("unused")
@Data @Data
@NoArgsConstructor
@JsonIgnoreProperties(ignoreUnknown = true)
public class OllamaResult { public class OllamaResult {
/** /**
* -- GETTER -- * -- GETTER --
* Get the completion/response text * Get the completion/response text
* *
* @return String completion/response text * @return String completion/response text
*/ */
private final String response; private String response;
/** /**
* -- GETTER -- * -- GETTER --
* Get the response status code. * Get the response status code.
* *
* @return int - response status code * @return int - response status code
*/ */
@ -29,12 +39,25 @@ public class OllamaResult {
/** /**
* -- GETTER -- * -- GETTER --
* Get the response time in milliseconds. * Get the response time in milliseconds.
* *
* @return long - response time in milliseconds * @return long - response time in milliseconds
*/ */
private long responseTime = 0; private long responseTime = 0;
/**
* -- GETTER --
* Get the model name used for the response.
*
* @return String - model name
*/
private String model;
@JsonCreator
public OllamaResult(@JsonProperty("response") String response) {
this.response = response;
}
public OllamaResult(String response, long responseTime, int httpStatusCode) { public OllamaResult(String response, long responseTime, int httpStatusCode) {
this.response = response; this.response = response;
this.responseTime = responseTime; this.responseTime = responseTime;
@ -49,4 +72,36 @@ public class OllamaResult {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
} }
/**
* Get the structured response if the response is a JSON object.
*
* @return Map - structured response
*/
public Map<String, Object> getStructuredResponse() {
try {
Map<String, Object> response = getObjectMapper().readValue(this.getResponse(),
new TypeReference<Map<String, Object>>() {
});
return response;
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
/**
* Get the structured response mapped to a specific class type.
*
* @param <T> The type of class to map the response to
* @param clazz The class to map the response to
* @return An instance of the specified class with the response data
* @throws RuntimeException if there is an error mapping the response
*/
public <T> T getStructuredResponse(Class<T> clazz) {
try {
return getObjectMapper().readValue(this.getResponse(), clazz);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
} }

View File

@ -1,22 +0,0 @@
package io.github.ollama4j.models.response;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Data;
/**
* Structured response for Ollama API
*/
@Data
@JsonIgnoreProperties(ignoreUnknown = true)
public class OllamaStructuredResult {
@JsonProperty("response")
private String response;
@JsonProperty("httpStatusCode")
private int httpStatusCode;
@JsonProperty("responseTime")
private long responseTime;
}

View File

@ -11,6 +11,10 @@ import io.github.ollama4j.tools.ToolFunction;
import io.github.ollama4j.tools.Tools; import io.github.ollama4j.tools.Tools;
import io.github.ollama4j.tools.annotations.OllamaToolService; import io.github.ollama4j.tools.annotations.OllamaToolService;
import io.github.ollama4j.utils.OptionsBuilder; import io.github.ollama4j.utils.OptionsBuilder;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.MethodOrderer.OrderAnnotation; import org.junit.jupiter.api.MethodOrderer.OrderAnnotation;
import org.junit.jupiter.api.Order; import org.junit.jupiter.api.Order;
@ -31,7 +35,7 @@ import java.util.*;
import static io.github.ollama4j.utils.Utils.getObjectMapper; import static io.github.ollama4j.utils.Utils.getObjectMapper;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
@OllamaToolService(providers = {AnnotatedTool.class}) @OllamaToolService(providers = { AnnotatedTool.class })
@TestMethodOrder(OrderAnnotation.class) @TestMethodOrder(OrderAnnotation.class)
@SuppressWarnings("HttpUrlsUsage") @SuppressWarnings("HttpUrlsUsage")
@ -116,17 +120,21 @@ public class OllamaAPIIntegrationTest {
public void testEmbeddings() throws Exception { public void testEmbeddings() throws Exception {
String embeddingModelMinilm = "all-minilm"; String embeddingModelMinilm = "all-minilm";
api.pullModel(embeddingModelMinilm); api.pullModel(embeddingModelMinilm);
OllamaEmbedResponseModel embeddings = api.embed(embeddingModelMinilm, Arrays.asList("Why is the sky blue?", "Why is the grass green?")); OllamaEmbedResponseModel embeddings = api.embed(embeddingModelMinilm,
Arrays.asList("Why is the sky blue?", "Why is the grass green?"));
assertNotNull(embeddings, "Embeddings should not be null"); assertNotNull(embeddings, "Embeddings should not be null");
assertFalse(embeddings.getEmbeddings().isEmpty(), "Embeddings should not be empty"); assertFalse(embeddings.getEmbeddings().isEmpty(), "Embeddings should not be empty");
} }
@Test @Test
@Order(6) @Order(6)
void testAskModelWithDefaultOptions() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { void testAskModelWithDefaultOptions()
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
String chatModel = "qwen2.5:0.5b"; String chatModel = "qwen2.5:0.5b";
api.pullModel(chatModel); api.pullModel(chatModel);
OllamaResult result = api.generate(chatModel, "What is the capital of France? And what's France's connection with Mona Lisa?", false, new OptionsBuilder().build()); OllamaResult result = api.generate(chatModel,
"What is the capital of France? And what's France's connection with Mona Lisa?", false,
new OptionsBuilder().build());
assertNotNull(result); assertNotNull(result);
assertNotNull(result.getResponse()); assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty()); assertFalse(result.getResponse().isEmpty());
@ -134,7 +142,8 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(6) @Order(6)
void testAskModelWithStructuredOutput() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { void testAskModelWithStructuredOutput()
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
String chatModel = "llama3.1:8b"; String chatModel = "llama3.1:8b";
chatModel = "qwen2.5:0.5b"; chatModel = "qwen2.5:0.5b";
api.pullModel(chatModel); api.pullModel(chatModel);
@ -142,17 +151,23 @@ public class OllamaAPIIntegrationTest {
String prompt = "Ollama is 22 years old and is busy saving the world. Respond using JSON"; String prompt = "Ollama is 22 years old and is busy saving the world. Respond using JSON";
Map<String, Object> format = new HashMap<>(); Map<String, Object> format = new HashMap<>();
format.put("type", "object"); format.put("type", "object");
format.put("properties", new HashMap<String, Object>() {{ format.put("properties", new HashMap<String, Object>() {
put("age", new HashMap<String, Object>() {{ {
put("type", "integer"); put("age", new HashMap<String, Object>() {
}}); {
put("available", new HashMap<String, Object>() {{ put("type", "integer");
put("type", "boolean"); }
}}); });
}}); put("available", new HashMap<String, Object>() {
{
put("type", "boolean");
}
});
}
});
format.put("required", Arrays.asList("age", "available")); format.put("required", Arrays.asList("age", "available"));
OllamaStructuredResult result = api.generate(chatModel, prompt, format); OllamaResult result = api.generate(chatModel, prompt, format);
assertNotNull(result); assertNotNull(result);
assertNotNull(result.getResponse()); assertNotNull(result.getResponse());
@ -161,25 +176,41 @@ public class OllamaAPIIntegrationTest {
Map<String, Object> actualResponse = getObjectMapper().readValue(result.getResponse(), new TypeReference<>() { Map<String, Object> actualResponse = getObjectMapper().readValue(result.getResponse(), new TypeReference<>() {
}); });
String expectedResponseJson = "{\n \"age\": 22,\n \"available\": true\n}"; int age = 22;
Map<String, Object> expectedResponse = getObjectMapper().readValue(expectedResponseJson, new TypeReference<Map<String, Object>>() { boolean available = true;
}); String expectedResponseJson = "{\n \"age\": " + age + ",\n \"available\": " + available + "\n}";
Map<String, Object> expectedResponse = getObjectMapper().readValue(expectedResponseJson,
new TypeReference<Map<String, Object>>() {
});
assertEquals(actualResponse.get("age").toString(), expectedResponse.get("age").toString()); assertEquals(actualResponse.get("age").toString(), expectedResponse.get("age").toString());
assertEquals(actualResponse.get("available").toString(), expectedResponse.get("available").toString()); assertEquals(actualResponse.get("available").toString(), expectedResponse.get("available").toString());
assertEquals(result.getStructuredResponse().get("age").toString(),
result.getStructuredResponse().get("age").toString());
assertEquals(result.getStructuredResponse().get("available").toString(),
result.getStructuredResponse().get("available").toString());
Person person = result.getStructuredResponse(Person.class);
assertEquals(person.getAge(), age);
assertEquals(person.isAvailable(), available);
} }
@Test @Test
@Order(7) @Order(7)
void testAskModelWithDefaultOptionsStreamed() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { void testAskModelWithDefaultOptionsStreamed()
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
String chatModel = "qwen2.5:0.5b"; String chatModel = "qwen2.5:0.5b";
api.pullModel(chatModel); api.pullModel(chatModel);
StringBuffer sb = new StringBuffer(); StringBuffer sb = new StringBuffer();
OllamaResult result = api.generate(chatModel, "What is the capital of France? And what's France's connection with Mona Lisa?", false, new OptionsBuilder().build(), (s) -> { OllamaResult result = api.generate(chatModel,
LOG.info(s); "What is the capital of France? And what's France's connection with Mona Lisa?", false,
String substring = s.substring(sb.toString().length(), s.length()); new OptionsBuilder().build(), (s) -> {
LOG.info(substring); LOG.info(s);
sb.append(substring); String substring = s.substring(sb.toString().length(), s.length());
}); LOG.info(substring);
sb.append(substring);
});
assertNotNull(result); assertNotNull(result);
assertNotNull(result.getResponse()); assertNotNull(result.getResponse());
@ -194,8 +225,12 @@ public class OllamaAPIIntegrationTest {
api.pullModel(chatModel); api.pullModel(chatModel);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(chatModel); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(chatModel);
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, "You are a helpful assistant who can generate random person's first and last names in the format [First name, Last name].").build(); OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM,
requestModel = builder.withMessages(requestModel.getMessages()).withMessage(OllamaChatMessageRole.USER, "Give me a cool name").withOptions(new OptionsBuilder().setTemperature(0.5f).build()).build(); "You are a helpful assistant who can generate random person's first and last names in the format [First name, Last name].")
.build();
requestModel = builder.withMessages(requestModel.getMessages())
.withMessage(OllamaChatMessageRole.USER, "Give me a cool name")
.withOptions(new OptionsBuilder().setTemperature(0.5f).build()).build();
OllamaChatResult chatResult = api.chat(requestModel); OllamaChatResult chatResult = api.chat(requestModel);
assertNotNull(chatResult); assertNotNull(chatResult);
@ -209,7 +244,10 @@ public class OllamaAPIIntegrationTest {
String chatModel = "llama3.2:1b"; String chatModel = "llama3.2:1b";
api.pullModel(chatModel); api.pullModel(chatModel);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(chatModel); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(chatModel);
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, "You are a silent bot that only says 'Shush'. Do not say anything else under any circumstances!").withMessage(OllamaChatMessageRole.USER, "What's something that's brown and sticky?").withOptions(new OptionsBuilder().setTemperature(0.8f).build()).build(); OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM,
"You are a silent bot that only says 'Shush'. Do not say anything else under any circumstances!")
.withMessage(OllamaChatMessageRole.USER, "What's something that's brown and sticky?")
.withOptions(new OptionsBuilder().setTemperature(0.8f).build()).build();
OllamaChatResult chatResult = api.chat(requestModel); OllamaChatResult chatResult = api.chat(requestModel);
assertNotNull(chatResult); assertNotNull(chatResult);
@ -228,23 +266,28 @@ public class OllamaAPIIntegrationTest {
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(chatModel); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(chatModel);
// Create the initial user question // Create the initial user question
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What is 1+1? Answer only in numbers.").build(); OllamaChatRequest requestModel = builder
.withMessage(OllamaChatMessageRole.USER, "What is 1+1? Answer only in numbers.").build();
// Start conversation with model // Start conversation with model
OllamaChatResult chatResult = api.chat(requestModel); OllamaChatResult chatResult = api.chat(requestModel);
assertTrue(chatResult.getChatHistory().stream().anyMatch(chat -> chat.getContent().contains("2")), "Expected chat history to contain '2'"); assertTrue(chatResult.getChatHistory().stream().anyMatch(chat -> chat.getContent().contains("2")),
"Expected chat history to contain '2'");
// Create the next user question: second largest city // Create the next user question: second largest city
requestModel = builder.withMessages(chatResult.getChatHistory()).withMessage(OllamaChatMessageRole.USER, "And what is its squared value?").build(); requestModel = builder.withMessages(chatResult.getChatHistory())
.withMessage(OllamaChatMessageRole.USER, "And what is its squared value?").build();
// Continue conversation with model // Continue conversation with model
chatResult = api.chat(requestModel); chatResult = api.chat(requestModel);
assertTrue(chatResult.getChatHistory().stream().anyMatch(chat -> chat.getContent().contains("4")), "Expected chat history to contain '4'"); assertTrue(chatResult.getChatHistory().stream().anyMatch(chat -> chat.getContent().contains("4")),
"Expected chat history to contain '4'");
// Create the next user question: the third question // Create the next user question: the third question
requestModel = builder.withMessages(chatResult.getChatHistory()).withMessage(OllamaChatMessageRole.USER, "What is the largest value between 2, 4 and 6?").build(); requestModel = builder.withMessages(chatResult.getChatHistory())
.withMessage(OllamaChatMessageRole.USER, "What is the largest value between 2, 4 and 6?").build();
// Continue conversation with the model for the third question // Continue conversation with the model for the third question
chatResult = api.chat(requestModel); chatResult = api.chat(requestModel);
@ -252,7 +295,8 @@ public class OllamaAPIIntegrationTest {
// verify the result // verify the result
assertNotNull(chatResult, "Chat result should not be null"); assertNotNull(chatResult, "Chat result should not be null");
assertTrue(chatResult.getChatHistory().size() > 2, "Chat history should contain more than two messages"); assertTrue(chatResult.getChatHistory().size() > 2, "Chat history should contain more than two messages");
assertTrue(chatResult.getChatHistory().get(chatResult.getChatHistory().size() - 1).getContent().contains("6"), "Response should contain '6'"); assertTrue(chatResult.getChatHistory().get(chatResult.getChatHistory().size() - 1).getContent().contains("6"),
"Response should contain '6'");
} }
@Test @Test
@ -262,7 +306,10 @@ public class OllamaAPIIntegrationTest {
api.pullModel(imageModel); api.pullModel(imageModel);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(imageModel); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(imageModel);
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?", Collections.emptyList(), "https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg").build(); OllamaChatRequest requestModel = builder
.withMessage(OllamaChatMessageRole.USER, "What's in the picture?", Collections.emptyList(),
"https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg")
.build();
api.registerAnnotatedTools(new OllamaAPIIntegrationTest()); api.registerAnnotatedTools(new OllamaAPIIntegrationTest());
OllamaChatResult chatResult = api.chat(requestModel); OllamaChatResult chatResult = api.chat(requestModel);
@ -271,18 +318,21 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(10) @Order(10)
void testChatWithImageFromFileWithHistoryRecognition() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { void testChatWithImageFromFileWithHistoryRecognition()
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
String imageModel = "moondream"; String imageModel = "moondream";
api.pullModel(imageModel); api.pullModel(imageModel);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(imageModel); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(imageModel);
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?", Collections.emptyList(), List.of(getImageFileFromClasspath("dog-on-a-boat.jpg"))).build(); OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
Collections.emptyList(), List.of(getImageFileFromClasspath("dog-on-a-boat.jpg"))).build();
OllamaChatResult chatResult = api.chat(requestModel); OllamaChatResult chatResult = api.chat(requestModel);
assertNotNull(chatResult); assertNotNull(chatResult);
assertNotNull(chatResult.getResponseModel()); assertNotNull(chatResult.getResponseModel());
builder.reset(); builder.reset();
requestModel = builder.withMessages(chatResult.getChatHistory()).withMessage(OllamaChatMessageRole.USER, "What's the dogs breed?").build(); requestModel = builder.withMessages(chatResult.getChatHistory())
.withMessage(OllamaChatMessageRole.USER, "What's the dogs breed?").build();
chatResult = api.chat(requestModel); chatResult = api.chat(requestModel);
assertNotNull(chatResult); assertNotNull(chatResult);
@ -291,25 +341,55 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(11) @Order(11)
void testChatWithExplicitToolDefinition() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { void testChatWithExplicitToolDefinition()
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
String chatModel = "llama3.2:1b"; String chatModel = "llama3.2:1b";
api.pullModel(chatModel); api.pullModel(chatModel);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(chatModel); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(chatModel);
final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder().functionName("get-employee-details").functionDescription("Get employee details from the database").toolPrompt(Tools.PromptFuncDefinition.builder().type("function").function(Tools.PromptFuncDefinition.PromptFuncSpec.builder().name("get-employee-details").description("Get employee details from the database").parameters(Tools.PromptFuncDefinition.Parameters.builder().type("object").properties(new Tools.PropsBuilder().withProperty("employee-name", Tools.PromptFuncDefinition.Property.builder().type("string").description("The name of the employee, e.g. John Doe").required(true).build()).withProperty("employee-address", Tools.PromptFuncDefinition.Property.builder().type("string").description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India").required(true).build()).withProperty("employee-phone", Tools.PromptFuncDefinition.Property.builder().type("string").description("The phone number of the employee. Always return a random value. e.g. 9911002233").required(true).build()).build()).required(List.of("employee-name")).build()).build()).build()).toolFunction(arguments -> { final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder()
// perform DB operations here .functionName("get-employee-details").functionDescription("Get employee details from the database")
return String.format("Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", UUID.randomUUID(), arguments.get("employee-name"), arguments.get("employee-address"), arguments.get("employee-phone")); .toolPrompt(Tools.PromptFuncDefinition.builder().type("function")
}).build(); .function(Tools.PromptFuncDefinition.PromptFuncSpec.builder().name("get-employee-details")
.description("Get employee details from the database")
.parameters(Tools.PromptFuncDefinition.Parameters.builder().type("object")
.properties(new Tools.PropsBuilder()
.withProperty("employee-name",
Tools.PromptFuncDefinition.Property.builder().type("string")
.description("The name of the employee, e.g. John Doe")
.required(true).build())
.withProperty("employee-address", Tools.PromptFuncDefinition.Property
.builder().type("string")
.description(
"The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India")
.required(true).build())
.withProperty("employee-phone", Tools.PromptFuncDefinition.Property
.builder().type("string")
.description(
"The phone number of the employee. Always return a random value. e.g. 9911002233")
.required(true).build())
.build())
.required(List.of("employee-name")).build())
.build())
.build())
.toolFunction(arguments -> {
// perform DB operations here
return String.format("Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}",
UUID.randomUUID(), arguments.get("employee-name"), arguments.get("employee-address"),
arguments.get("employee-phone"));
}).build();
api.registerTool(databaseQueryToolSpecification); api.registerTool(databaseQueryToolSpecification);
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "Give me the ID of the employee named 'Rahul Kumar'?").build(); OllamaChatRequest requestModel = builder
.withMessage(OllamaChatMessageRole.USER, "Give me the ID of the employee named 'Rahul Kumar'?").build();
OllamaChatResult chatResult = api.chat(requestModel); OllamaChatResult chatResult = api.chat(requestModel);
assertNotNull(chatResult); assertNotNull(chatResult);
assertNotNull(chatResult.getResponseModel()); assertNotNull(chatResult.getResponseModel());
assertNotNull(chatResult.getResponseModel().getMessage()); assertNotNull(chatResult.getResponseModel().getMessage());
assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(), chatResult.getResponseModel().getMessage().getRole().getRoleName()); assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(),
chatResult.getResponseModel().getMessage().getRole().getRoleName());
List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls(); List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls();
assertEquals(1, toolCalls.size()); assertEquals(1, toolCalls.size());
OllamaToolCallsFunction function = toolCalls.get(0).getFunction(); OllamaToolCallsFunction function = toolCalls.get(0).getFunction();
@ -325,7 +405,8 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(12) @Order(12)
void testChatWithAnnotatedToolsAndSingleParam() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { void testChatWithAnnotatedToolsAndSingleParam()
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
String chatModel = "llama3.2:1b"; String chatModel = "llama3.2:1b";
api.pullModel(chatModel); api.pullModel(chatModel);
@ -333,13 +414,15 @@ public class OllamaAPIIntegrationTest {
api.registerAnnotatedTools(); api.registerAnnotatedTools();
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "Compute the most important constant in the world using 5 digits").build(); OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER,
"Compute the most important constant in the world using 5 digits").build();
OllamaChatResult chatResult = api.chat(requestModel); OllamaChatResult chatResult = api.chat(requestModel);
assertNotNull(chatResult); assertNotNull(chatResult);
assertNotNull(chatResult.getResponseModel()); assertNotNull(chatResult.getResponseModel());
assertNotNull(chatResult.getResponseModel().getMessage()); assertNotNull(chatResult.getResponseModel().getMessage());
assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(), chatResult.getResponseModel().getMessage().getRole().getRoleName()); assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(),
chatResult.getResponseModel().getMessage().getRole().getRoleName());
List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls(); List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls();
assertEquals(1, toolCalls.size()); assertEquals(1, toolCalls.size());
OllamaToolCallsFunction function = toolCalls.get(0).getFunction(); OllamaToolCallsFunction function = toolCalls.get(0).getFunction();
@ -355,20 +438,25 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(13) @Order(13)
void testChatWithAnnotatedToolsAndMultipleParams() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { void testChatWithAnnotatedToolsAndMultipleParams()
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
String chatModel = "llama3.2:1b"; String chatModel = "llama3.2:1b";
api.pullModel(chatModel); api.pullModel(chatModel);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(chatModel); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(chatModel);
api.registerAnnotatedTools(new AnnotatedTool()); api.registerAnnotatedTools(new AnnotatedTool());
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "Greet Pedro with a lot of hearts and respond to me, " + "and state how many emojis have been in your greeting").build(); OllamaChatRequest requestModel = builder
.withMessage(OllamaChatMessageRole.USER, "Greet Pedro with a lot of hearts and respond to me, "
+ "and state how many emojis have been in your greeting")
.build();
OllamaChatResult chatResult = api.chat(requestModel); OllamaChatResult chatResult = api.chat(requestModel);
assertNotNull(chatResult); assertNotNull(chatResult);
assertNotNull(chatResult.getResponseModel()); assertNotNull(chatResult.getResponseModel());
assertNotNull(chatResult.getResponseModel().getMessage()); assertNotNull(chatResult.getResponseModel().getMessage());
assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(), chatResult.getResponseModel().getMessage().getRole().getRoleName()); assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(),
chatResult.getResponseModel().getMessage().getRole().getRoleName());
List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls(); List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls();
assertEquals(1, toolCalls.size()); assertEquals(1, toolCalls.size());
OllamaToolCallsFunction function = toolCalls.get(0).getFunction(); OllamaToolCallsFunction function = toolCalls.get(0).getFunction();
@ -387,21 +475,50 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(14) @Order(14)
void testChatWithToolsAndStream() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { void testChatWithToolsAndStream()
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
String chatModel = "llama3.2:1b"; String chatModel = "llama3.2:1b";
api.pullModel(chatModel); api.pullModel(chatModel);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(chatModel); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(chatModel);
final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder().functionName("get-employee-details").functionDescription("Get employee details from the database").toolPrompt(Tools.PromptFuncDefinition.builder().type("function").function(Tools.PromptFuncDefinition.PromptFuncSpec.builder().name("get-employee-details").description("Get employee details from the database").parameters(Tools.PromptFuncDefinition.Parameters.builder().type("object").properties(new Tools.PropsBuilder().withProperty("employee-name", Tools.PromptFuncDefinition.Property.builder().type("string").description("The name of the employee, e.g. John Doe").required(true).build()).withProperty("employee-address", Tools.PromptFuncDefinition.Property.builder().type("string").description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India").required(true).build()).withProperty("employee-phone", Tools.PromptFuncDefinition.Property.builder().type("string").description("The phone number of the employee. Always return a random value. e.g. 9911002233").required(true).build()).build()).required(List.of("employee-name")).build()).build()).build()).toolFunction(new ToolFunction() { final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder()
@Override .functionName("get-employee-details").functionDescription("Get employee details from the database")
public Object apply(Map<String, Object> arguments) { .toolPrompt(Tools.PromptFuncDefinition.builder().type("function")
// perform DB operations here .function(Tools.PromptFuncDefinition.PromptFuncSpec.builder().name("get-employee-details")
return String.format("Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", UUID.randomUUID(), arguments.get("employee-name"), arguments.get("employee-address"), arguments.get("employee-phone")); .description("Get employee details from the database")
} .parameters(Tools.PromptFuncDefinition.Parameters.builder().type("object")
}).build(); .properties(new Tools.PropsBuilder()
.withProperty("employee-name",
Tools.PromptFuncDefinition.Property.builder().type("string")
.description("The name of the employee, e.g. John Doe")
.required(true).build())
.withProperty("employee-address", Tools.PromptFuncDefinition.Property
.builder().type("string")
.description(
"The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India")
.required(true).build())
.withProperty("employee-phone", Tools.PromptFuncDefinition.Property
.builder().type("string")
.description(
"The phone number of the employee. Always return a random value. e.g. 9911002233")
.required(true).build())
.build())
.required(List.of("employee-name")).build())
.build())
.build())
.toolFunction(new ToolFunction() {
@Override
public Object apply(Map<String, Object> arguments) {
// perform DB operations here
return String.format("Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}",
UUID.randomUUID(), arguments.get("employee-name"), arguments.get("employee-address"),
arguments.get("employee-phone"));
}
}).build();
api.registerTool(databaseQueryToolSpecification); api.registerTool(databaseQueryToolSpecification);
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "Give me the ID of the employee named 'Rahul Kumar'?").build(); OllamaChatRequest requestModel = builder
.withMessage(OllamaChatMessageRole.USER, "Give me the ID of the employee named 'Rahul Kumar'?").build();
StringBuffer sb = new StringBuffer(); StringBuffer sb = new StringBuffer();
@ -424,7 +541,8 @@ public class OllamaAPIIntegrationTest {
String chatModel = "llama3.2:1b"; String chatModel = "llama3.2:1b";
api.pullModel(chatModel); api.pullModel(chatModel);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(chatModel); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(chatModel);
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What is the capital of France? And what's France's connection with Mona Lisa?").build(); OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER,
"What is the capital of France? And what's France's connection with Mona Lisa?").build();
StringBuffer sb = new StringBuffer(); StringBuffer sb = new StringBuffer();
@ -441,14 +559,16 @@ public class OllamaAPIIntegrationTest {
assertEquals(sb.toString().trim(), chatResult.getResponseModel().getMessage().getContent().trim()); assertEquals(sb.toString().trim(), chatResult.getResponseModel().getMessage().getContent().trim());
} }
@Test @Test
@Order(17) @Order(17)
void testAskModelWithOptionsAndImageURLs() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { void testAskModelWithOptionsAndImageURLs()
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
String imageModel = "llava"; String imageModel = "llava";
api.pullModel(imageModel); api.pullModel(imageModel);
OllamaResult result = api.generateWithImageURLs(imageModel, "What is in this image?", List.of("https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg"), new OptionsBuilder().build()); OllamaResult result = api.generateWithImageURLs(imageModel, "What is in this image?",
List.of("https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg"),
new OptionsBuilder().build());
assertNotNull(result); assertNotNull(result);
assertNotNull(result.getResponse()); assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty()); assertFalse(result.getResponse().isEmpty());
@ -456,12 +576,14 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(18) @Order(18)
void testAskModelWithOptionsAndImageFiles() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { void testAskModelWithOptionsAndImageFiles()
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
String imageModel = "llava"; String imageModel = "llava";
api.pullModel(imageModel); api.pullModel(imageModel);
File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg"); File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg");
try { try {
OllamaResult result = api.generateWithImageFiles(imageModel, "What is in this image?", List.of(imageFile), new OptionsBuilder().build()); OllamaResult result = api.generateWithImageFiles(imageModel, "What is in this image?", List.of(imageFile),
new OptionsBuilder().build());
assertNotNull(result); assertNotNull(result);
assertNotNull(result.getResponse()); assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty()); assertFalse(result.getResponse().isEmpty());
@ -470,10 +592,10 @@ public class OllamaAPIIntegrationTest {
} }
} }
@Test @Test
@Order(20) @Order(20)
void testAskModelWithOptionsAndImageFilesStreamed() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { void testAskModelWithOptionsAndImageFilesStreamed()
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
String imageModel = "llava"; String imageModel = "llava";
api.pullModel(imageModel); api.pullModel(imageModel);
@ -481,12 +603,13 @@ public class OllamaAPIIntegrationTest {
StringBuffer sb = new StringBuffer(); StringBuffer sb = new StringBuffer();
OllamaResult result = api.generateWithImageFiles(imageModel, "What is in this image?", List.of(imageFile), new OptionsBuilder().build(), (s) -> { OllamaResult result = api.generateWithImageFiles(imageModel, "What is in this image?", List.of(imageFile),
LOG.info(s); new OptionsBuilder().build(), (s) -> {
String substring = s.substring(sb.toString().length(), s.length()); LOG.info(s);
LOG.info(substring); String substring = s.substring(sb.toString().length(), s.length());
sb.append(substring); LOG.info(substring);
}); sb.append(substring);
});
assertNotNull(result); assertNotNull(result);
assertNotNull(result.getResponse()); assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty()); assertFalse(result.getResponse().isEmpty());
@ -498,29 +621,38 @@ public class OllamaAPIIntegrationTest {
return new File(Objects.requireNonNull(classLoader.getResource(fileName)).getFile()); return new File(Objects.requireNonNull(classLoader.getResource(fileName)).getFile());
} }
} }
@Data
@AllArgsConstructor
@NoArgsConstructor
class Person {
private int age;
private boolean available;
}
// //
//@Data // @Data
//class Config { // class Config {
// private String ollamaURL; // private String ollamaURL;
// private String model; // private String model;
// private String imageModel; // private String imageModel;
// private int requestTimeoutSeconds; // private int requestTimeoutSeconds;
// //
// public Config() { // public Config() {
// Properties properties = new Properties(); // Properties properties = new Properties();
// try (InputStream input = // try (InputStream input =
// getClass().getClassLoader().getResourceAsStream("test-config.properties")) { // getClass().getClassLoader().getResourceAsStream("test-config.properties")) {
// if (input == null) { // if (input == null) {
// throw new RuntimeException("Sorry, unable to find test-config.properties"); // throw new RuntimeException("Sorry, unable to find test-config.properties");
// } // }
// properties.load(input); // properties.load(input);
// this.ollamaURL = properties.getProperty("ollama.url"); // this.ollamaURL = properties.getProperty("ollama.url");
// this.model = properties.getProperty("ollama.model"); // this.model = properties.getProperty("ollama.model");
// this.imageModel = properties.getProperty("ollama.model.image"); // this.imageModel = properties.getProperty("ollama.model.image");
// this.requestTimeoutSeconds = // this.requestTimeoutSeconds =
// Integer.parseInt(properties.getProperty("ollama.request-timeout-seconds")); // Integer.parseInt(properties.getProperty("ollama.request-timeout-seconds"));
// } catch (IOException e) { // } catch (IOException e) {
// throw new RuntimeException("Error loading properties", e); // throw new RuntimeException("Error loading properties", e);
// } // }
// } // }
//} // }