Added support for tools/function calling - specifically for Mistral's latest model.

This commit is contained in:
Amith Koujalgi
2024-07-12 17:06:41 +05:30
parent 8ef6fac28e
commit 91ee6cb4c1
15 changed files with 1006 additions and 490 deletions

View File

@@ -1,7 +1,5 @@
package io.github.amithkoujalgi.ollama4j.integrationtests;
import static org.junit.jupiter.api.Assertions.*;
import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.ModelDetail;
@@ -10,9 +8,16 @@ import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessageRole;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResult;
import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestBuilder;
import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestModel;
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
import lombok.Data;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Order;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
@@ -22,372 +27,369 @@ import java.net.http.HttpConnectTimeoutException;
import java.util.List;
import java.util.Objects;
import java.util.Properties;
import lombok.Data;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Order;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import static org.junit.jupiter.api.Assertions.*;
class TestRealAPIs {
private static final Logger LOG = LoggerFactory.getLogger(TestRealAPIs.class);
private static final Logger LOG = LoggerFactory.getLogger(TestRealAPIs.class);
OllamaAPI ollamaAPI;
Config config;
OllamaAPI ollamaAPI;
Config config;
private File getImageFileFromClasspath(String fileName) {
ClassLoader classLoader = getClass().getClassLoader();
return new File(Objects.requireNonNull(classLoader.getResource(fileName)).getFile());
}
@BeforeEach
void setUp() {
config = new Config();
ollamaAPI = new OllamaAPI(config.getOllamaURL());
ollamaAPI.setRequestTimeoutSeconds(config.getRequestTimeoutSeconds());
}
@Test
@Order(1)
void testWrongEndpoint() {
OllamaAPI ollamaAPI = new OllamaAPI("http://wrong-host:11434");
assertThrows(ConnectException.class, ollamaAPI::listModels);
}
@Test
@Order(1)
void testEndpointReachability() {
try {
assertNotNull(ollamaAPI.listModels());
} catch (HttpConnectTimeoutException e) {
fail(e.getMessage());
} catch (Exception e) {
fail(e);
private File getImageFileFromClasspath(String fileName) {
ClassLoader classLoader = getClass().getClassLoader();
return new File(Objects.requireNonNull(classLoader.getResource(fileName)).getFile());
}
}
@Test
@Order(2)
void testListModels() {
testEndpointReachability();
try {
assertNotNull(ollamaAPI.listModels());
ollamaAPI.listModels().forEach(System.out::println);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
fail(e);
@BeforeEach
void setUp() {
config = new Config();
ollamaAPI = new OllamaAPI(config.getOllamaURL());
ollamaAPI.setRequestTimeoutSeconds(config.getRequestTimeoutSeconds());
}
}
@Test
@Order(2)
void testPullModel() {
testEndpointReachability();
try {
ollamaAPI.pullModel(config.getModel());
boolean found =
ollamaAPI.listModels().stream()
.anyMatch(model -> model.getModel().equalsIgnoreCase(config.getModel()));
assertTrue(found);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
fail(e);
@Test
@Order(1)
void testWrongEndpoint() {
OllamaAPI ollamaAPI = new OllamaAPI("http://wrong-host:11434");
assertThrows(ConnectException.class, ollamaAPI::listModels);
}
}
@Test
@Order(3)
void testListDtails() {
testEndpointReachability();
try {
ModelDetail modelDetails = ollamaAPI.getModelDetails(config.getModel());
assertNotNull(modelDetails);
System.out.println(modelDetails);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
fail(e);
@Test
@Order(1)
void testEndpointReachability() {
try {
assertNotNull(ollamaAPI.listModels());
} catch (HttpConnectTimeoutException e) {
fail(e.getMessage());
} catch (Exception e) {
fail(e);
}
}
}
@Test
@Order(3)
void testAskModelWithDefaultOptions() {
testEndpointReachability();
try {
OllamaResult result =
ollamaAPI.generate(
config.getModel(),
"What is the capital of France? And what's France's connection with Mona Lisa?",
new OptionsBuilder().build());
assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
@Test
@Order(2)
void testListModels() {
testEndpointReachability();
try {
assertNotNull(ollamaAPI.listModels());
ollamaAPI.listModels().forEach(System.out::println);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
fail(e);
}
}
}
@Test
@Order(3)
void testAskModelWithDefaultOptionsStreamed() {
testEndpointReachability();
try {
StringBuffer sb = new StringBuffer("");
OllamaResult result = ollamaAPI.generate(config.getModel(),
"What is the capital of France? And what's France's connection with Mona Lisa?",
new OptionsBuilder().build(), (s) -> {
LOG.info(s);
String substring = s.substring(sb.toString().length(), s.length());
LOG.info(substring);
sb.append(substring);
});
assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
assertEquals(sb.toString().trim(), result.getResponse().trim());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
@Test
@Order(2)
void testPullModel() {
testEndpointReachability();
try {
ollamaAPI.pullModel(config.getModel());
boolean found =
ollamaAPI.listModels().stream()
.anyMatch(model -> model.getModel().equalsIgnoreCase(config.getModel()));
assertTrue(found);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
fail(e);
}
}
}
@Test
@Order(3)
void testAskModelWithOptions() {
testEndpointReachability();
try {
OllamaResult result =
ollamaAPI.generate(
config.getModel(),
"What is the capital of France? And what's France's connection with Mona Lisa?",
new OptionsBuilder().setTemperature(0.9f).build());
assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
@Test
@Order(3)
void testListDtails() {
testEndpointReachability();
try {
ModelDetail modelDetails = ollamaAPI.getModelDetails(config.getModel());
assertNotNull(modelDetails);
System.out.println(modelDetails);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
fail(e);
}
}
}
@Test
@Order(3)
void testChat() {
testEndpointReachability();
try {
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What is the capital of France?")
.withMessage(OllamaChatMessageRole.ASSISTANT, "Should be Paris!")
.withMessage(OllamaChatMessageRole.USER,"And what is the second larges city?")
.build();
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
assertNotNull(chatResult);
assertFalse(chatResult.getResponse().isBlank());
assertEquals(4,chatResult.getChatHistory().size());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
@Test
@Order(3)
void testAskModelWithDefaultOptions() {
testEndpointReachability();
try {
OllamaResult result =
ollamaAPI.generate(
config.getModel(),
"What is the capital of France? And what's France's connection with Mona Lisa?",
false,
new OptionsBuilder().build());
assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
}
}
}
@Test
@Order(3)
void testChatWithSystemPrompt() {
testEndpointReachability();
try {
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM,
"You are a silent bot that only says 'NI'. Do not say anything else under any circumstances!")
.withMessage(OllamaChatMessageRole.USER,
"What is the capital of France? And what's France's connection with Mona Lisa?")
.build();
@Test
@Order(3)
void testAskModelWithDefaultOptionsStreamed() {
testEndpointReachability();
try {
StringBuffer sb = new StringBuffer("");
OllamaResult result = ollamaAPI.generate(config.getModel(),
"What is the capital of France? And what's France's connection with Mona Lisa?",
false,
new OptionsBuilder().build(), (s) -> {
LOG.info(s);
String substring = s.substring(sb.toString().length(), s.length());
LOG.info(substring);
sb.append(substring);
});
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
assertNotNull(chatResult);
assertFalse(chatResult.getResponse().isBlank());
assertTrue(chatResult.getResponse().startsWith("NI"));
assertEquals(3, chatResult.getChatHistory().size());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
assertEquals(sb.toString().trim(), result.getResponse().trim());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
}
}
}
@Test
@Order(3)
void testChatWithStream() {
testEndpointReachability();
try {
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER,
"What is the capital of France? And what's France's connection with Mona Lisa?")
.build();
StringBuffer sb = new StringBuffer("");
OllamaChatResult chatResult = ollamaAPI.chat(requestModel,(s) -> {
LOG.info(s);
String substring = s.substring(sb.toString().length(), s.length());
LOG.info(substring);
sb.append(substring);
});
assertNotNull(chatResult);
assertEquals(sb.toString().trim(), chatResult.getResponse().trim());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
@Test
@Order(3)
void testAskModelWithOptions() {
testEndpointReachability();
try {
OllamaResult result =
ollamaAPI.generate(
config.getModel(),
"What is the capital of France? And what's France's connection with Mona Lisa?",
true,
new OptionsBuilder().setTemperature(0.9f).build());
assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
}
}
}
@Test
@Order(3)
void testChatWithImageFromFileWithHistoryRecognition() {
testEndpointReachability();
try {
OllamaChatRequestBuilder builder =
OllamaChatRequestBuilder.getInstance(config.getImageModel());
OllamaChatRequestModel requestModel =
builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
List.of(getImageFileFromClasspath("dog-on-a-boat.jpg"))).build();
@Test
@Order(3)
void testChat() {
testEndpointReachability();
try {
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What is the capital of France?")
.withMessage(OllamaChatMessageRole.ASSISTANT, "Should be Paris!")
.withMessage(OllamaChatMessageRole.USER, "And what is the second larges city?")
.build();
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
assertNotNull(chatResult);
assertNotNull(chatResult.getResponse());
builder.reset();
requestModel =
builder.withMessages(chatResult.getChatHistory())
.withMessage(OllamaChatMessageRole.USER, "What's the dogs breed?").build();
chatResult = ollamaAPI.chat(requestModel);
assertNotNull(chatResult);
assertNotNull(chatResult.getResponse());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
assertNotNull(chatResult);
assertFalse(chatResult.getResponse().isBlank());
assertEquals(4, chatResult.getChatHistory().size());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
}
}
}
@Test
@Order(3)
void testChatWithImageFromURL() {
testEndpointReachability();
try {
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getImageModel());
OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
"https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg")
.build();
@Test
@Order(3)
void testChatWithSystemPrompt() {
testEndpointReachability();
try {
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM,
"You are a silent bot that only says 'NI'. Do not say anything else under any circumstances!")
.withMessage(OllamaChatMessageRole.USER,
"What is the capital of France? And what's France's connection with Mona Lisa?")
.build();
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
assertNotNull(chatResult);
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
assertNotNull(chatResult);
assertFalse(chatResult.getResponse().isBlank());
assertTrue(chatResult.getResponse().startsWith("NI"));
assertEquals(3, chatResult.getChatHistory().size());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
}
}
}
@Test
@Order(3)
void testAskModelWithOptionsAndImageFiles() {
testEndpointReachability();
File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg");
try {
OllamaResult result =
ollamaAPI.generateWithImageFiles(
config.getImageModel(),
"What is in this image?",
List.of(imageFile),
new OptionsBuilder().build());
assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
@Test
@Order(3)
void testChatWithStream() {
testEndpointReachability();
try {
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER,
"What is the capital of France? And what's France's connection with Mona Lisa?")
.build();
StringBuffer sb = new StringBuffer("");
OllamaChatResult chatResult = ollamaAPI.chat(requestModel, (s) -> {
LOG.info(s);
String substring = s.substring(sb.toString().length(), s.length());
LOG.info(substring);
sb.append(substring);
});
assertNotNull(chatResult);
assertEquals(sb.toString().trim(), chatResult.getResponse().trim());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
}
}
}
@Test
@Order(3)
void testAskModelWithOptionsAndImageFilesStreamed() {
testEndpointReachability();
File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg");
try {
StringBuffer sb = new StringBuffer("");
@Test
@Order(3)
void testChatWithImageFromFileWithHistoryRecognition() {
testEndpointReachability();
try {
OllamaChatRequestBuilder builder =
OllamaChatRequestBuilder.getInstance(config.getImageModel());
OllamaChatRequestModel requestModel =
builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
List.of(getImageFileFromClasspath("dog-on-a-boat.jpg"))).build();
OllamaResult result = ollamaAPI.generateWithImageFiles(config.getImageModel(),
"What is in this image?", List.of(imageFile), new OptionsBuilder().build(), (s) -> {
LOG.info(s);
String substring = s.substring(sb.toString().length(), s.length());
LOG.info(substring);
sb.append(substring);
});
assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
assertEquals(sb.toString().trim(), result.getResponse().trim());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
assertNotNull(chatResult);
assertNotNull(chatResult.getResponse());
builder.reset();
requestModel =
builder.withMessages(chatResult.getChatHistory())
.withMessage(OllamaChatMessageRole.USER, "What's the dogs breed?").build();
chatResult = ollamaAPI.chat(requestModel);
assertNotNull(chatResult);
assertNotNull(chatResult.getResponse());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
}
}
}
@Test
@Order(3)
void testAskModelWithOptionsAndImageURLs() {
testEndpointReachability();
try {
OllamaResult result =
ollamaAPI.generateWithImageURLs(
config.getImageModel(),
"What is in this image?",
List.of(
"https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg"),
new OptionsBuilder().build());
assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
fail(e);
@Test
@Order(3)
void testChatWithImageFromURL() {
testEndpointReachability();
try {
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getImageModel());
OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
"https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg")
.build();
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
assertNotNull(chatResult);
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
}
}
}
@Test
@Order(3)
public void testEmbedding() {
testEndpointReachability();
try {
OllamaEmbeddingsRequestModel request = OllamaEmbeddingsRequestBuilder
.getInstance(config.getModel(), "What is the capital of France?").build();
List<Double> embeddings = ollamaAPI.generateEmbeddings(request);
assertNotNull(embeddings);
assertFalse(embeddings.isEmpty());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
@Test
@Order(3)
void testAskModelWithOptionsAndImageFiles() {
testEndpointReachability();
File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg");
try {
OllamaResult result =
ollamaAPI.generateWithImageFiles(
config.getImageModel(),
"What is in this image?",
List.of(imageFile),
new OptionsBuilder().build());
assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
}
}
@Test
@Order(3)
void testAskModelWithOptionsAndImageFilesStreamed() {
testEndpointReachability();
File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg");
try {
StringBuffer sb = new StringBuffer("");
OllamaResult result = ollamaAPI.generateWithImageFiles(config.getImageModel(),
"What is in this image?", List.of(imageFile), new OptionsBuilder().build(), (s) -> {
LOG.info(s);
String substring = s.substring(sb.toString().length(), s.length());
LOG.info(substring);
sb.append(substring);
});
assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
assertEquals(sb.toString().trim(), result.getResponse().trim());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
}
}
@Test
@Order(3)
void testAskModelWithOptionsAndImageURLs() {
testEndpointReachability();
try {
OllamaResult result =
ollamaAPI.generateWithImageURLs(
config.getImageModel(),
"What is in this image?",
List.of(
"https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg"),
new OptionsBuilder().build());
assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
fail(e);
}
}
@Test
@Order(3)
public void testEmbedding() {
testEndpointReachability();
try {
OllamaEmbeddingsRequestModel request = OllamaEmbeddingsRequestBuilder
.getInstance(config.getModel(), "What is the capital of France?").build();
List<Double> embeddings = ollamaAPI.generateEmbeddings(request);
assertNotNull(embeddings);
assertFalse(embeddings.isEmpty());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
}
}
}
}
@Data
class Config {
private String ollamaURL;
private String model;
private String imageModel;
private int requestTimeoutSeconds;
private String ollamaURL;
private String model;
private String imageModel;
private int requestTimeoutSeconds;
public Config() {
Properties properties = new Properties();
try (InputStream input =
getClass().getClassLoader().getResourceAsStream("test-config.properties")) {
if (input == null) {
throw new RuntimeException("Sorry, unable to find test-config.properties");
}
properties.load(input);
this.ollamaURL = properties.getProperty("ollama.url");
this.model = properties.getProperty("ollama.model");
this.imageModel = properties.getProperty("ollama.model.image");
this.requestTimeoutSeconds =
Integer.parseInt(properties.getProperty("ollama.request-timeout-seconds"));
} catch (IOException e) {
throw new RuntimeException("Error loading properties", e);
public Config() {
Properties properties = new Properties();
try (InputStream input =
getClass().getClassLoader().getResourceAsStream("test-config.properties")) {
if (input == null) {
throw new RuntimeException("Sorry, unable to find test-config.properties");
}
properties.load(input);
this.ollamaURL = properties.getProperty("ollama.url");
this.model = properties.getProperty("ollama.model");
this.imageModel = properties.getProperty("ollama.model.image");
this.requestTimeoutSeconds =
Integer.parseInt(properties.getProperty("ollama.request-timeout-seconds"));
} catch (IOException e) {
throw new RuntimeException("Error loading properties", e);
}
}
}
}

View File

@@ -1,7 +1,5 @@
package io.github.amithkoujalgi.ollama4j.unittests;
import static org.mockito.Mockito.*;
import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.ModelDetail;
@@ -9,155 +7,158 @@ import io.github.amithkoujalgi.ollama4j.core.models.OllamaAsyncResultCallback;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
import io.github.amithkoujalgi.ollama4j.core.types.OllamaModelType;
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import java.io.IOException;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.Collections;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import static org.mockito.Mockito.*;
class TestMockedAPIs {
@Test
void testPullModel() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
try {
doNothing().when(ollamaAPI).pullModel(model);
ollamaAPI.pullModel(model);
verify(ollamaAPI, times(1)).pullModel(model);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e);
@Test
void testPullModel() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
try {
doNothing().when(ollamaAPI).pullModel(model);
ollamaAPI.pullModel(model);
verify(ollamaAPI, times(1)).pullModel(model);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e);
}
}
}
@Test
void testListModels() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
try {
when(ollamaAPI.listModels()).thenReturn(new ArrayList<>());
ollamaAPI.listModels();
verify(ollamaAPI, times(1)).listModels();
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e);
@Test
void testListModels() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
try {
when(ollamaAPI.listModels()).thenReturn(new ArrayList<>());
ollamaAPI.listModels();
verify(ollamaAPI, times(1)).listModels();
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e);
}
}
}
@Test
void testCreateModel() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
String modelFilePath = "FROM llama2\nSYSTEM You are mario from Super Mario Bros.";
try {
doNothing().when(ollamaAPI).createModelWithModelFileContents(model, modelFilePath);
ollamaAPI.createModelWithModelFileContents(model, modelFilePath);
verify(ollamaAPI, times(1)).createModelWithModelFileContents(model, modelFilePath);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e);
@Test
void testCreateModel() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
String modelFilePath = "FROM llama2\nSYSTEM You are mario from Super Mario Bros.";
try {
doNothing().when(ollamaAPI).createModelWithModelFileContents(model, modelFilePath);
ollamaAPI.createModelWithModelFileContents(model, modelFilePath);
verify(ollamaAPI, times(1)).createModelWithModelFileContents(model, modelFilePath);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e);
}
}
}
@Test
void testDeleteModel() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
try {
doNothing().when(ollamaAPI).deleteModel(model, true);
ollamaAPI.deleteModel(model, true);
verify(ollamaAPI, times(1)).deleteModel(model, true);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e);
@Test
void testDeleteModel() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
try {
doNothing().when(ollamaAPI).deleteModel(model, true);
ollamaAPI.deleteModel(model, true);
verify(ollamaAPI, times(1)).deleteModel(model, true);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e);
}
}
}
@Test
void testGetModelDetails() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
try {
when(ollamaAPI.getModelDetails(model)).thenReturn(new ModelDetail());
ollamaAPI.getModelDetails(model);
verify(ollamaAPI, times(1)).getModelDetails(model);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e);
@Test
void testGetModelDetails() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
try {
when(ollamaAPI.getModelDetails(model)).thenReturn(new ModelDetail());
ollamaAPI.getModelDetails(model);
verify(ollamaAPI, times(1)).getModelDetails(model);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e);
}
}
}
@Test
void testGenerateEmbeddings() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text";
try {
when(ollamaAPI.generateEmbeddings(model, prompt)).thenReturn(new ArrayList<>());
ollamaAPI.generateEmbeddings(model, prompt);
verify(ollamaAPI, times(1)).generateEmbeddings(model, prompt);
} catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e);
@Test
void testGenerateEmbeddings() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text";
try {
when(ollamaAPI.generateEmbeddings(model, prompt)).thenReturn(new ArrayList<>());
ollamaAPI.generateEmbeddings(model, prompt);
verify(ollamaAPI, times(1)).generateEmbeddings(model, prompt);
} catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e);
}
}
}
@Test
void testAsk() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text";
OptionsBuilder optionsBuilder = new OptionsBuilder();
try {
when(ollamaAPI.generate(model, prompt, optionsBuilder.build()))
.thenReturn(new OllamaResult("", 0, 200));
ollamaAPI.generate(model, prompt, optionsBuilder.build());
verify(ollamaAPI, times(1)).generate(model, prompt, optionsBuilder.build());
} catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e);
@Test
void testAsk() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text";
OptionsBuilder optionsBuilder = new OptionsBuilder();
try {
when(ollamaAPI.generate(model, prompt, false, optionsBuilder.build()))
.thenReturn(new OllamaResult("", 0, 200));
ollamaAPI.generate(model, prompt, false, optionsBuilder.build());
verify(ollamaAPI, times(1)).generate(model, prompt, false, optionsBuilder.build());
} catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e);
}
}
}
@Test
void testAskWithImageFiles() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text";
try {
when(ollamaAPI.generateWithImageFiles(
model, prompt, Collections.emptyList(), new OptionsBuilder().build()))
.thenReturn(new OllamaResult("", 0, 200));
ollamaAPI.generateWithImageFiles(
model, prompt, Collections.emptyList(), new OptionsBuilder().build());
verify(ollamaAPI, times(1))
.generateWithImageFiles(
model, prompt, Collections.emptyList(), new OptionsBuilder().build());
} catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e);
@Test
void testAskWithImageFiles() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text";
try {
when(ollamaAPI.generateWithImageFiles(
model, prompt, Collections.emptyList(), new OptionsBuilder().build()))
.thenReturn(new OllamaResult("", 0, 200));
ollamaAPI.generateWithImageFiles(
model, prompt, Collections.emptyList(), new OptionsBuilder().build());
verify(ollamaAPI, times(1))
.generateWithImageFiles(
model, prompt, Collections.emptyList(), new OptionsBuilder().build());
} catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e);
}
}
}
@Test
void testAskWithImageURLs() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text";
try {
when(ollamaAPI.generateWithImageURLs(
model, prompt, Collections.emptyList(), new OptionsBuilder().build()))
.thenReturn(new OllamaResult("", 0, 200));
ollamaAPI.generateWithImageURLs(
model, prompt, Collections.emptyList(), new OptionsBuilder().build());
verify(ollamaAPI, times(1))
.generateWithImageURLs(
model, prompt, Collections.emptyList(), new OptionsBuilder().build());
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e);
@Test
void testAskWithImageURLs() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text";
try {
when(ollamaAPI.generateWithImageURLs(
model, prompt, Collections.emptyList(), new OptionsBuilder().build()))
.thenReturn(new OllamaResult("", 0, 200));
ollamaAPI.generateWithImageURLs(
model, prompt, Collections.emptyList(), new OptionsBuilder().build());
verify(ollamaAPI, times(1))
.generateWithImageURLs(
model, prompt, Collections.emptyList(), new OptionsBuilder().build());
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e);
}
}
}
@Test
void testAskAsync() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text";
when(ollamaAPI.generateAsync(model, prompt))
.thenReturn(new OllamaAsyncResultCallback(null, null, 3));
ollamaAPI.generateAsync(model, prompt);
verify(ollamaAPI, times(1)).generateAsync(model, prompt);
}
@Test
void testAskAsync() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text";
when(ollamaAPI.generateAsync(model, prompt, false))
.thenReturn(new OllamaAsyncResultCallback(null, null, 3));
ollamaAPI.generateAsync(model, prompt, false);
verify(ollamaAPI, times(1)).generateAsync(model, prompt, false);
}
}