mirror of
https://github.com/amithkoujalgi/ollama4j.git
synced 2025-05-15 03:47:13 +02:00
Enhance OllamaAPI and OllamaResult for improved model pulling and structured responses
- Added a retry mechanism in OllamaAPI for model pulling, allowing configurable retries. - Introduced new methods in OllamaResult for structured response handling, including parsing JSON responses into a Map or specific class types. - Updated integration tests to validate the new functionality and ensure robust testing of model interactions. - Improved code formatting and consistency across the OllamaAPI and integration test classes.
This commit is contained in:
parent
1bda78e35b
commit
bc2a931586
@ -51,7 +51,7 @@ import java.util.stream.Collectors;
|
||||
/**
|
||||
* The base Ollama API class.
|
||||
*/
|
||||
@SuppressWarnings({"DuplicatedCode", "resource"})
|
||||
@SuppressWarnings({ "DuplicatedCode", "resource" })
|
||||
public class OllamaAPI {
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(OllamaAPI.class);
|
||||
@ -74,6 +74,12 @@ public class OllamaAPI {
|
||||
|
||||
private Auth auth;
|
||||
|
||||
private int numberOfRetriesForModelPull = 0;
|
||||
|
||||
public void setNumberOfRetriesForModelPull(int numberOfRetriesForModelPull) {
|
||||
this.numberOfRetriesForModelPull = numberOfRetriesForModelPull;
|
||||
}
|
||||
|
||||
private final ToolRegistry toolRegistry = new ToolRegistry();
|
||||
|
||||
/**
|
||||
@ -376,6 +382,26 @@ public class OllamaAPI {
|
||||
*/
|
||||
public void pullModel(String modelName)
|
||||
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
|
||||
if (numberOfRetriesForModelPull == 0) {
|
||||
this.doPullModel(modelName);
|
||||
} else {
|
||||
int numberOfRetries = 0;
|
||||
while (numberOfRetries < numberOfRetriesForModelPull) {
|
||||
try {
|
||||
this.doPullModel(modelName);
|
||||
return;
|
||||
} catch (OllamaBaseException e) {
|
||||
logger.error("Failed to pull model " + modelName + ", retrying...");
|
||||
numberOfRetries++;
|
||||
}
|
||||
}
|
||||
throw new OllamaBaseException(
|
||||
"Failed to pull model " + modelName + " after " + numberOfRetriesForModelPull + " retries");
|
||||
}
|
||||
}
|
||||
|
||||
private void doPullModel(String modelName)
|
||||
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
|
||||
String url = this.host + "/api/pull";
|
||||
String jsonData = new ModelRequest(modelName).toString();
|
||||
HttpRequest request = getRequestBuilderDefault(new URI(url))
|
||||
@ -742,8 +768,10 @@ public class OllamaAPI {
|
||||
* @param model The name or identifier of the AI model to use for generating
|
||||
* the response.
|
||||
* @param prompt The input text or prompt to provide to the AI model.
|
||||
* @param format A map containing the format specification for the structured output.
|
||||
* @return An instance of {@link OllamaResult} containing the structured response.
|
||||
* @param format A map containing the format specification for the structured
|
||||
* output.
|
||||
* @return An instance of {@link OllamaResult} containing the structured
|
||||
* response.
|
||||
* @throws OllamaBaseException if the response indicates an error status.
|
||||
* @throws IOException if an I/O error occurs during the HTTP request.
|
||||
* @throws InterruptedException if the operation is interrupted.
|
||||
@ -771,7 +799,11 @@ public class OllamaAPI {
|
||||
String responseBody = response.body();
|
||||
|
||||
if (statusCode == 200) {
|
||||
return Utils.getObjectMapper().readValue(responseBody, OllamaResult.class);
|
||||
OllamaStructuredResult structuredResult = Utils.getObjectMapper().readValue(responseBody,
|
||||
OllamaStructuredResult.class);
|
||||
OllamaResult ollamaResult = new OllamaResult(structuredResult.getResponse(),
|
||||
structuredResult.getResponseTime(), statusCode);
|
||||
return ollamaResult;
|
||||
} else {
|
||||
throw new OllamaBaseException(statusCode + " - " + responseBody);
|
||||
}
|
||||
|
@ -1,15 +1,22 @@
|
||||
package io.github.ollama4j.models.response;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.core.type.TypeReference;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.Getter;
|
||||
|
||||
import static io.github.ollama4j.utils.Utils.getObjectMapper;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
/** The type Ollama result. */
|
||||
@Getter
|
||||
@SuppressWarnings("unused")
|
||||
@Data
|
||||
@JsonIgnoreProperties(ignoreUnknown = true)
|
||||
public class OllamaResult {
|
||||
/**
|
||||
* -- GETTER --
|
||||
@ -44,9 +51,68 @@ public class OllamaResult {
|
||||
@Override
|
||||
public String toString() {
|
||||
try {
|
||||
return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
|
||||
Map<String, Object> responseMap = new HashMap<>();
|
||||
responseMap.put("response", this.response);
|
||||
responseMap.put("httpStatusCode", this.httpStatusCode);
|
||||
responseMap.put("responseTime", this.responseTime);
|
||||
return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(responseMap);
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the structured response if the response is a JSON object.
|
||||
*
|
||||
* @return Map - structured response
|
||||
* @throws IllegalArgumentException if the response is not a valid JSON object
|
||||
*/
|
||||
public Map<String, Object> getStructuredResponse() {
|
||||
String responseStr = this.getResponse();
|
||||
if (responseStr == null || responseStr.trim().isEmpty()) {
|
||||
throw new IllegalArgumentException("Response is empty or null");
|
||||
}
|
||||
|
||||
try {
|
||||
// Check if the response is a valid JSON
|
||||
if ((!responseStr.trim().startsWith("{") && !responseStr.trim().startsWith("[")) ||
|
||||
(!responseStr.trim().endsWith("}") && !responseStr.trim().endsWith("]"))) {
|
||||
throw new IllegalArgumentException("Response is not a valid JSON object");
|
||||
}
|
||||
|
||||
Map<String, Object> response = getObjectMapper().readValue(responseStr,
|
||||
new TypeReference<Map<String, Object>>() {
|
||||
});
|
||||
return response;
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new IllegalArgumentException("Failed to parse response as JSON: " + e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the structured response mapped to a specific class type.
|
||||
*
|
||||
* @param <T> The type of class to map the response to
|
||||
* @param clazz The class to map the response to
|
||||
* @return An instance of the specified class with the response data
|
||||
* @throws IllegalArgumentException if the response is not a valid JSON or is empty
|
||||
* @throws RuntimeException if there is an error mapping the response
|
||||
*/
|
||||
public <T> T as(Class<T> clazz) {
|
||||
String responseStr = this.getResponse();
|
||||
if (responseStr == null || responseStr.trim().isEmpty()) {
|
||||
throw new IllegalArgumentException("Response is empty or null");
|
||||
}
|
||||
|
||||
try {
|
||||
// Check if the response is a valid JSON
|
||||
if ((!responseStr.trim().startsWith("{") && !responseStr.trim().startsWith("[")) ||
|
||||
(!responseStr.trim().endsWith("}") && !responseStr.trim().endsWith("]"))) {
|
||||
throw new IllegalArgumentException("Response is not a valid JSON object");
|
||||
}
|
||||
return getObjectMapper().readValue(responseStr, clazz);
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new IllegalArgumentException("Failed to parse response as JSON: " + e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,77 @@
|
||||
package io.github.ollama4j.models.response;
|
||||
|
||||
import static io.github.ollama4j.utils.Utils.getObjectMapper;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonCreator;
|
||||
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.core.type.TypeReference;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.Getter;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
@Getter
|
||||
@SuppressWarnings("unused")
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@JsonIgnoreProperties(ignoreUnknown = true)
|
||||
public class OllamaStructuredResult {
|
||||
private String response;
|
||||
|
||||
private int httpStatusCode;
|
||||
|
||||
private long responseTime = 0;
|
||||
|
||||
private String model;
|
||||
|
||||
public OllamaStructuredResult(String response, long responseTime, int httpStatusCode) {
|
||||
this.response = response;
|
||||
this.responseTime = responseTime;
|
||||
this.httpStatusCode = httpStatusCode;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
try {
|
||||
return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the structured response if the response is a JSON object.
|
||||
*
|
||||
* @return Map - structured response
|
||||
*/
|
||||
public Map<String, Object> getStructuredResponse() {
|
||||
try {
|
||||
Map<String, Object> response = getObjectMapper().readValue(this.getResponse(),
|
||||
new TypeReference<Map<String, Object>>() {
|
||||
});
|
||||
return response;
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the structured response mapped to a specific class type.
|
||||
*
|
||||
* @param <T> The type of class to map the response to
|
||||
* @param clazz The class to map the response to
|
||||
* @return An instance of the specified class with the response data
|
||||
* @throws RuntimeException if there is an error mapping the response
|
||||
*/
|
||||
public <T> T getStructuredResponse(Class<T> clazz) {
|
||||
try {
|
||||
return getObjectMapper().readValue(this.getResponse(), clazz);
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
@ -35,10 +35,10 @@ import java.util.*;
|
||||
import static io.github.ollama4j.utils.Utils.getObjectMapper;
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
@OllamaToolService(providers = {AnnotatedTool.class})
|
||||
@OllamaToolService(providers = { AnnotatedTool.class })
|
||||
@TestMethodOrder(OrderAnnotation.class)
|
||||
|
||||
@SuppressWarnings({"HttpUrlsUsage", "SpellCheckingInspection"})
|
||||
@SuppressWarnings({ "HttpUrlsUsage", "SpellCheckingInspection" })
|
||||
public class OllamaAPIIntegrationTest {
|
||||
private static final Logger LOG = LoggerFactory.getLogger(OllamaAPIIntegrationTest.class);
|
||||
|
||||
@ -46,12 +46,11 @@ public class OllamaAPIIntegrationTest {
|
||||
private static OllamaAPI api;
|
||||
|
||||
private static final String EMBEDDING_MODEL_MINILM = "all-minilm";
|
||||
private static final String CHAT_MODEL_DEFAULT = "qwen2.5:0.5b";
|
||||
private static final String CHAT_MODEL_QWEN_SMALL = "qwen2.5:0.5b";
|
||||
private static final String CHAT_MODEL_INSTRUCT = "qwen2.5:0.5b-instruct";
|
||||
private static final String CHAT_MODEL_SYSTEM_PROMPT = "llama3.2:1b";
|
||||
private static final String CHAT_MODEL_LLAMA3 = "llama3";
|
||||
private static final String IMAGE_MODEL_LLAVA = "llava";
|
||||
private static final String IMAGE_MODEL_MOONDREAM = "moondream";
|
||||
|
||||
@BeforeAll
|
||||
public static void setUp() {
|
||||
@ -61,7 +60,8 @@ public class OllamaAPIIntegrationTest {
|
||||
if (useExternalOllamaHost) {
|
||||
api = new OllamaAPI(ollamaHost);
|
||||
} else {
|
||||
throw new RuntimeException("USE_EXTERNAL_OLLAMA_HOST is not set so, we will be using Testcontainers Ollama host for the tests now. If you would like to use an external host, please set the env var to USE_EXTERNAL_OLLAMA_HOST=true and set the env var OLLAMA_HOST=http://localhost:11435 or a different host/port.");
|
||||
throw new RuntimeException(
|
||||
"USE_EXTERNAL_OLLAMA_HOST is not set so, we will be using Testcontainers Ollama host for the tests now. If you would like to use an external host, please set the env var to USE_EXTERNAL_OLLAMA_HOST=true and set the env var OLLAMA_HOST=http://localhost:11435 or a different host/port.");
|
||||
}
|
||||
} catch (Exception e) {
|
||||
String ollamaVersion = "0.6.1";
|
||||
@ -77,6 +77,7 @@ public class OllamaAPIIntegrationTest {
|
||||
}
|
||||
api.setRequestTimeoutSeconds(120);
|
||||
api.setVerbose(true);
|
||||
api.setNumberOfRetriesForModelPull(3);
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -92,12 +93,14 @@ public class OllamaAPIIntegrationTest {
|
||||
// String expectedVersion = ollama.getDockerImageName().split(":")[1];
|
||||
String actualVersion = api.getVersion();
|
||||
assertNotNull(actualVersion);
|
||||
// assertEquals(expectedVersion, actualVersion, "Version should match the Docker image version");
|
||||
// assertEquals(expectedVersion, actualVersion, "Version should match the Docker
|
||||
// image version");
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(2)
|
||||
public void testListModelsAPI() throws URISyntaxException, IOException, OllamaBaseException, InterruptedException {
|
||||
public void testListModelsAPI()
|
||||
throws URISyntaxException, IOException, OllamaBaseException, InterruptedException {
|
||||
api.pullModel(EMBEDDING_MODEL_MINILM);
|
||||
// Fetch the list of models
|
||||
List<Model> models = api.listModels();
|
||||
@ -109,7 +112,8 @@ public class OllamaAPIIntegrationTest {
|
||||
|
||||
@Test
|
||||
@Order(2)
|
||||
void testListModelsFromLibrary() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
|
||||
void testListModelsFromLibrary()
|
||||
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
|
||||
List<LibraryModel> models = api.listModelsFromLibrary();
|
||||
assertNotNull(models);
|
||||
assertFalse(models.isEmpty());
|
||||
@ -117,7 +121,8 @@ public class OllamaAPIIntegrationTest {
|
||||
|
||||
@Test
|
||||
@Order(3)
|
||||
public void testPullModelAPI() throws URISyntaxException, IOException, OllamaBaseException, InterruptedException {
|
||||
public void testPullModelAPI()
|
||||
throws URISyntaxException, IOException, OllamaBaseException, InterruptedException {
|
||||
api.pullModel(EMBEDDING_MODEL_MINILM);
|
||||
List<Model> models = api.listModels();
|
||||
assertNotNull(models, "Models should not be null");
|
||||
@ -143,26 +148,18 @@ public class OllamaAPIIntegrationTest {
|
||||
assertFalse(embeddings.getEmbeddings().isEmpty(), "Embeddings should not be empty");
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(6)
|
||||
void testAskModelWithDefaultOptions()
|
||||
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
|
||||
api.pullModel(CHAT_MODEL_DEFAULT);
|
||||
OllamaResult result = api.generate(CHAT_MODEL_DEFAULT,
|
||||
"What is the capital of France? And what's France's connection with Mona Lisa?", false,
|
||||
new OptionsBuilder().build());
|
||||
assertNotNull(result);
|
||||
assertNotNull(result.getResponse());
|
||||
assertFalse(result.getResponse().isEmpty());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(6)
|
||||
void testAskModelWithStructuredOutput()
|
||||
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
|
||||
api.pullModel(CHAT_MODEL_DEFAULT);
|
||||
api.pullModel(CHAT_MODEL_QWEN_SMALL);
|
||||
|
||||
int age = 28;
|
||||
boolean available = false;
|
||||
|
||||
String prompt = "Batman is " + age + " years old and is " + (available ? "available" : "not available")
|
||||
+ " because he is busy saving Gotham City. Respond using JSON";
|
||||
|
||||
String prompt = "Ollama is 22 years old and is busy saving the world. Respond using JSON";
|
||||
Map<String, Object> format = new HashMap<>();
|
||||
format.put("type", "object");
|
||||
format.put("properties", new HashMap<String, Object>() {
|
||||
@ -181,42 +178,44 @@ public class OllamaAPIIntegrationTest {
|
||||
});
|
||||
format.put("required", Arrays.asList("age", "available"));
|
||||
|
||||
OllamaResult result = api.generate(CHAT_MODEL_DEFAULT, prompt, format);
|
||||
OllamaResult result = api.generate(CHAT_MODEL_QWEN_SMALL, prompt, format);
|
||||
|
||||
assertNotNull(result);
|
||||
assertNotNull(result.getResponse());
|
||||
assertFalse(result.getResponse().isEmpty());
|
||||
|
||||
Map<String, Object> actualResponse = getObjectMapper().readValue(result.getResponse(), new TypeReference<>() {
|
||||
});
|
||||
|
||||
int age = 22;
|
||||
boolean available = true;
|
||||
String expectedResponseJson = "{\n \"age\": " + age + ",\n \"available\": " + available + "\n}";
|
||||
|
||||
Map<String, Object> expectedResponse = getObjectMapper().readValue(expectedResponseJson,
|
||||
new TypeReference<Map<String, Object>>() {
|
||||
});
|
||||
assertEquals(actualResponse.get("age").toString(), expectedResponse.get("age").toString());
|
||||
assertEquals(actualResponse.get("available").toString(), expectedResponse.get("available").toString());
|
||||
System.out.println(result);
|
||||
|
||||
assertEquals(result.getStructuredResponse().get("age").toString(),
|
||||
result.getStructuredResponse().get("age").toString());
|
||||
assertEquals(result.getStructuredResponse().get("available").toString(),
|
||||
result.getStructuredResponse().get("available").toString());
|
||||
|
||||
Person person = result.getStructuredResponse(Person.class);
|
||||
Person person = result.as(Person.class);
|
||||
assertEquals(person.getAge(), age);
|
||||
assertEquals(person.isAvailable(), available);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(6)
|
||||
void testAskModelWithDefaultOptions()
|
||||
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
|
||||
api.pullModel(CHAT_MODEL_QWEN_SMALL);
|
||||
OllamaResult result = api.generate(CHAT_MODEL_QWEN_SMALL,
|
||||
"What is the capital of France? And what's France's connection with Mona Lisa?", false,
|
||||
new OptionsBuilder().build());
|
||||
assertNotNull(result);
|
||||
assertNotNull(result.getResponse());
|
||||
assertFalse(result.getResponse().isEmpty());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(7)
|
||||
void testAskModelWithDefaultOptionsStreamed()
|
||||
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
|
||||
api.pullModel(CHAT_MODEL_DEFAULT);
|
||||
api.pullModel(CHAT_MODEL_QWEN_SMALL);
|
||||
StringBuffer sb = new StringBuffer();
|
||||
OllamaResult result = api.generate(CHAT_MODEL_DEFAULT,
|
||||
OllamaResult result = api.generate(CHAT_MODEL_QWEN_SMALL,
|
||||
"What is the capital of France? And what's France's connection with Mona Lisa?", false,
|
||||
new OptionsBuilder().build(), (s) -> {
|
||||
LOG.info(s);
|
||||
@ -233,7 +232,8 @@ public class OllamaAPIIntegrationTest {
|
||||
|
||||
@Test
|
||||
@Order(8)
|
||||
void testAskModelWithOptions() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
|
||||
void testAskModelWithOptions()
|
||||
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
|
||||
api.pullModel(CHAT_MODEL_INSTRUCT);
|
||||
|
||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_INSTRUCT);
|
||||
@ -252,7 +252,8 @@ public class OllamaAPIIntegrationTest {
|
||||
|
||||
@Test
|
||||
@Order(9)
|
||||
void testChatWithSystemPrompt() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
|
||||
void testChatWithSystemPrompt()
|
||||
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
|
||||
api.pullModel(CHAT_MODEL_SYSTEM_PROMPT);
|
||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT);
|
||||
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM,
|
||||
@ -277,7 +278,8 @@ public class OllamaAPIIntegrationTest {
|
||||
|
||||
// Create the initial user question
|
||||
OllamaChatRequest requestModel = builder
|
||||
.withMessage(OllamaChatMessageRole.USER, "What is 1+1? Answer only in numbers.").build();
|
||||
.withMessage(OllamaChatMessageRole.USER, "What is 1+1? Answer only in numbers.")
|
||||
.build();
|
||||
|
||||
// Start conversation with model
|
||||
OllamaChatResult chatResult = api.chat(requestModel);
|
||||
@ -297,26 +299,32 @@ public class OllamaAPIIntegrationTest {
|
||||
|
||||
// Create the next user question: the third question
|
||||
requestModel = builder.withMessages(chatResult.getChatHistory())
|
||||
.withMessage(OllamaChatMessageRole.USER, "What is the largest value between 2, 4 and 6?").build();
|
||||
.withMessage(OllamaChatMessageRole.USER,
|
||||
"What is the largest value between 2, 4 and 6?")
|
||||
.build();
|
||||
|
||||
// Continue conversation with the model for the third question
|
||||
chatResult = api.chat(requestModel);
|
||||
|
||||
// verify the result
|
||||
assertNotNull(chatResult, "Chat result should not be null");
|
||||
assertTrue(chatResult.getChatHistory().size() > 2, "Chat history should contain more than two messages");
|
||||
assertTrue(chatResult.getChatHistory().get(chatResult.getChatHistory().size() - 1).getContent().contains("6"),
|
||||
assertTrue(chatResult.getChatHistory().size() > 2,
|
||||
"Chat history should contain more than two messages");
|
||||
assertTrue(chatResult.getChatHistory().get(chatResult.getChatHistory().size() - 1).getContent()
|
||||
.contains("6"),
|
||||
"Response should contain '6'");
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(10)
|
||||
void testChatWithImageFromURL() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
|
||||
void testChatWithImageFromURL()
|
||||
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
|
||||
api.pullModel(IMAGE_MODEL_LLAVA);
|
||||
|
||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(IMAGE_MODEL_LLAVA);
|
||||
OllamaChatRequest requestModel = builder
|
||||
.withMessage(OllamaChatMessageRole.USER, "What's in the picture?", Collections.emptyList(),
|
||||
.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
|
||||
Collections.emptyList(),
|
||||
"https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg")
|
||||
.build();
|
||||
api.registerAnnotatedTools(new OllamaAPIIntegrationTest());
|
||||
@ -329,10 +337,12 @@ public class OllamaAPIIntegrationTest {
|
||||
@Order(10)
|
||||
void testChatWithImageFromFileWithHistoryRecognition()
|
||||
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
|
||||
api.pullModel(IMAGE_MODEL_MOONDREAM);
|
||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(IMAGE_MODEL_MOONDREAM);
|
||||
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
|
||||
Collections.emptyList(), List.of(getImageFileFromClasspath("dog-on-a-boat.jpg"))).build();
|
||||
api.pullModel(IMAGE_MODEL_LLAVA);
|
||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(IMAGE_MODEL_LLAVA);
|
||||
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER,
|
||||
"What's in the picture?",
|
||||
Collections.emptyList(), List.of(getImageFileFromClasspath("dog-on-a-boat.jpg")))
|
||||
.build();
|
||||
|
||||
OllamaChatResult chatResult = api.chat(requestModel);
|
||||
assertNotNull(chatResult);
|
||||
@ -355,41 +365,58 @@ public class OllamaAPIIntegrationTest {
|
||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT);
|
||||
|
||||
final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder()
|
||||
.functionName("get-employee-details").functionDescription("Get employee details from the database")
|
||||
.functionName("get-employee-details")
|
||||
.functionDescription("Get employee details from the database")
|
||||
.toolPrompt(Tools.PromptFuncDefinition.builder().type("function")
|
||||
.function(Tools.PromptFuncDefinition.PromptFuncSpec.builder().name("get-employee-details")
|
||||
.function(Tools.PromptFuncDefinition.PromptFuncSpec.builder()
|
||||
.name("get-employee-details")
|
||||
.description("Get employee details from the database")
|
||||
.parameters(Tools.PromptFuncDefinition.Parameters.builder().type("object")
|
||||
.parameters(Tools.PromptFuncDefinition.Parameters
|
||||
.builder().type("object")
|
||||
.properties(new Tools.PropsBuilder()
|
||||
.withProperty("employee-name",
|
||||
Tools.PromptFuncDefinition.Property.builder().type("string")
|
||||
Tools.PromptFuncDefinition.Property
|
||||
.builder()
|
||||
.type("string")
|
||||
.description("The name of the employee, e.g. John Doe")
|
||||
.required(true).build())
|
||||
.withProperty("employee-address", Tools.PromptFuncDefinition.Property
|
||||
.builder().type("string")
|
||||
.required(true)
|
||||
.build())
|
||||
.withProperty("employee-address",
|
||||
Tools.PromptFuncDefinition.Property
|
||||
.builder()
|
||||
.type("string")
|
||||
.description(
|
||||
"The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India")
|
||||
.required(true).build())
|
||||
.withProperty("employee-phone", Tools.PromptFuncDefinition.Property
|
||||
.builder().type("string")
|
||||
.required(true)
|
||||
.build())
|
||||
.withProperty("employee-phone",
|
||||
Tools.PromptFuncDefinition.Property
|
||||
.builder()
|
||||
.type("string")
|
||||
.description(
|
||||
"The phone number of the employee. Always return a random value. e.g. 9911002233")
|
||||
.required(true).build())
|
||||
.required(true)
|
||||
.build())
|
||||
.build())
|
||||
.required(List.of("employee-name"))
|
||||
.build())
|
||||
.required(List.of("employee-name")).build())
|
||||
.build())
|
||||
.build())
|
||||
.toolFunction(arguments -> {
|
||||
// perform DB operations here
|
||||
return String.format("Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}",
|
||||
UUID.randomUUID(), arguments.get("employee-name"), arguments.get("employee-address"),
|
||||
return String.format(
|
||||
"Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}",
|
||||
UUID.randomUUID(), arguments.get("employee-name"),
|
||||
arguments.get("employee-address"),
|
||||
arguments.get("employee-phone"));
|
||||
}).build();
|
||||
|
||||
api.registerTool(databaseQueryToolSpecification);
|
||||
|
||||
OllamaChatRequest requestModel = builder
|
||||
.withMessage(OllamaChatMessageRole.USER, "Give me the ID of the employee named 'Rahul Kumar'?").build();
|
||||
.withMessage(OllamaChatMessageRole.USER,
|
||||
"Give me the ID of the employee named 'Rahul Kumar'?")
|
||||
.build();
|
||||
|
||||
OllamaChatResult chatResult = api.chat(requestModel);
|
||||
assertNotNull(chatResult);
|
||||
@ -451,7 +478,8 @@ public class OllamaAPIIntegrationTest {
|
||||
api.registerAnnotatedTools(new AnnotatedTool());
|
||||
|
||||
OllamaChatRequest requestModel = builder
|
||||
.withMessage(OllamaChatMessageRole.USER, "Greet Pedro with a lot of hearts and respond to me, "
|
||||
.withMessage(OllamaChatMessageRole.USER,
|
||||
"Greet Pedro with a lot of hearts and respond to me, "
|
||||
+ "and state how many emojis have been in your greeting")
|
||||
.build();
|
||||
|
||||
@ -484,36 +512,51 @@ public class OllamaAPIIntegrationTest {
|
||||
api.pullModel(CHAT_MODEL_SYSTEM_PROMPT);
|
||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT);
|
||||
final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder()
|
||||
.functionName("get-employee-details").functionDescription("Get employee details from the database")
|
||||
.functionName("get-employee-details")
|
||||
.functionDescription("Get employee details from the database")
|
||||
.toolPrompt(Tools.PromptFuncDefinition.builder().type("function")
|
||||
.function(Tools.PromptFuncDefinition.PromptFuncSpec.builder().name("get-employee-details")
|
||||
.function(Tools.PromptFuncDefinition.PromptFuncSpec.builder()
|
||||
.name("get-employee-details")
|
||||
.description("Get employee details from the database")
|
||||
.parameters(Tools.PromptFuncDefinition.Parameters.builder().type("object")
|
||||
.parameters(Tools.PromptFuncDefinition.Parameters
|
||||
.builder().type("object")
|
||||
.properties(new Tools.PropsBuilder()
|
||||
.withProperty("employee-name",
|
||||
Tools.PromptFuncDefinition.Property.builder().type("string")
|
||||
Tools.PromptFuncDefinition.Property
|
||||
.builder()
|
||||
.type("string")
|
||||
.description("The name of the employee, e.g. John Doe")
|
||||
.required(true).build())
|
||||
.withProperty("employee-address", Tools.PromptFuncDefinition.Property
|
||||
.builder().type("string")
|
||||
.required(true)
|
||||
.build())
|
||||
.withProperty("employee-address",
|
||||
Tools.PromptFuncDefinition.Property
|
||||
.builder()
|
||||
.type("string")
|
||||
.description(
|
||||
"The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India")
|
||||
.required(true).build())
|
||||
.withProperty("employee-phone", Tools.PromptFuncDefinition.Property
|
||||
.builder().type("string")
|
||||
.required(true)
|
||||
.build())
|
||||
.withProperty("employee-phone",
|
||||
Tools.PromptFuncDefinition.Property
|
||||
.builder()
|
||||
.type("string")
|
||||
.description(
|
||||
"The phone number of the employee. Always return a random value. e.g. 9911002233")
|
||||
.required(true).build())
|
||||
.required(true)
|
||||
.build())
|
||||
.build())
|
||||
.required(List.of("employee-name"))
|
||||
.build())
|
||||
.required(List.of("employee-name")).build())
|
||||
.build())
|
||||
.build())
|
||||
.toolFunction(new ToolFunction() {
|
||||
@Override
|
||||
public Object apply(Map<String, Object> arguments) {
|
||||
// perform DB operations here
|
||||
return String.format("Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}",
|
||||
UUID.randomUUID(), arguments.get("employee-name"), arguments.get("employee-address"),
|
||||
return String.format(
|
||||
"Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}",
|
||||
UUID.randomUUID(), arguments.get("employee-name"),
|
||||
arguments.get("employee-address"),
|
||||
arguments.get("employee-phone"));
|
||||
}
|
||||
}).build();
|
||||
@ -521,7 +564,9 @@ public class OllamaAPIIntegrationTest {
|
||||
api.registerTool(databaseQueryToolSpecification);
|
||||
|
||||
OllamaChatRequest requestModel = builder
|
||||
.withMessage(OllamaChatMessageRole.USER, "Give me the ID of the employee named 'Rahul Kumar'?").build();
|
||||
.withMessage(OllamaChatMessageRole.USER,
|
||||
"Give me the ID of the employee named 'Rahul Kumar'?")
|
||||
.build();
|
||||
|
||||
StringBuffer sb = new StringBuffer();
|
||||
|
||||
@ -544,7 +589,8 @@ public class OllamaAPIIntegrationTest {
|
||||
api.pullModel(CHAT_MODEL_SYSTEM_PROMPT);
|
||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_SYSTEM_PROMPT);
|
||||
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER,
|
||||
"What is the capital of France? And what's France's connection with Mona Lisa?").build();
|
||||
"What is the capital of France? And what's France's connection with Mona Lisa?")
|
||||
.build();
|
||||
|
||||
StringBuffer sb = new StringBuffer();
|
||||
|
||||
@ -582,7 +628,8 @@ public class OllamaAPIIntegrationTest {
|
||||
api.pullModel(IMAGE_MODEL_LLAVA);
|
||||
File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg");
|
||||
try {
|
||||
OllamaResult result = api.generateWithImageFiles(IMAGE_MODEL_LLAVA, "What is in this image?", List.of(imageFile),
|
||||
OllamaResult result = api.generateWithImageFiles(IMAGE_MODEL_LLAVA, "What is in this image?",
|
||||
List.of(imageFile),
|
||||
new OptionsBuilder().build());
|
||||
assertNotNull(result);
|
||||
assertNotNull(result.getResponse());
|
||||
@ -602,7 +649,8 @@ public class OllamaAPIIntegrationTest {
|
||||
|
||||
StringBuffer sb = new StringBuffer();
|
||||
|
||||
OllamaResult result = api.generateWithImageFiles(IMAGE_MODEL_LLAVA, "What is in this image?", List.of(imageFile),
|
||||
OllamaResult result = api.generateWithImageFiles(IMAGE_MODEL_LLAVA, "What is in this image?",
|
||||
List.of(imageFile),
|
||||
new OptionsBuilder().build(), (s) -> {
|
||||
LOG.info(s);
|
||||
String substring = s.substring(sb.toString().length(), s.length());
|
||||
@ -628,30 +676,3 @@ class Person {
|
||||
private int age;
|
||||
private boolean available;
|
||||
}
|
||||
|
||||
//
|
||||
// @Data
|
||||
// class Config {
|
||||
// private String ollamaURL;
|
||||
// private String model;
|
||||
// private String imageModel;
|
||||
// private int requestTimeoutSeconds;
|
||||
//
|
||||
// public Config() {
|
||||
// Properties properties = new Properties();
|
||||
// try (InputStream input =
|
||||
// getClass().getClassLoader().getResourceAsStream("test-config.properties")) {
|
||||
// if (input == null) {
|
||||
// throw new RuntimeException("Sorry, unable to find test-config.properties");
|
||||
// }
|
||||
// properties.load(input);
|
||||
// this.ollamaURL = properties.getProperty("ollama.url");
|
||||
// this.model = properties.getProperty("ollama.model");
|
||||
// this.imageModel = properties.getProperty("ollama.model.image");
|
||||
// this.requestTimeoutSeconds =
|
||||
// Integer.parseInt(properties.getProperty("ollama.request-timeout-seconds"));
|
||||
// } catch (IOException e) {
|
||||
// throw new RuntimeException("Error loading properties", e);
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
Loading…
x
Reference in New Issue
Block a user