Refactor OllamaAPI to Ollama class and update documentation

- Replaced instances of `OllamaAPI` with `Ollama` across the codebase for consistency.
- Updated example code snippets in documentation to reflect the new class name.
- Enhanced metrics collection setup in the documentation.
- Added integration tests for the new `Ollama` class to ensure functionality remains intact.
This commit is contained in:
amithkoujalgi
2025-09-28 23:30:02 +05:30
parent 6fce6ec777
commit 35bf3de62a
17 changed files with 326 additions and 317 deletions

View File

@@ -53,9 +53,9 @@ import org.slf4j.LoggerFactory;
* <p>This class provides methods for model management, chat, embeddings, tool registration, and more.
*/
@SuppressWarnings({"DuplicatedCode", "resource", "SpellCheckingInspection"})
public class OllamaAPI {
public class Ollama {
private static final Logger LOG = LoggerFactory.getLogger(OllamaAPI.class);
private static final Logger LOG = LoggerFactory.getLogger(Ollama.class);
private final String host;
private Auth auth;
@@ -107,7 +107,7 @@ public class OllamaAPI {
/**
* Instantiates the Ollama API with the default Ollama host: {@code http://localhost:11434}
*/
public OllamaAPI() {
public Ollama() {
this.host = "http://localhost:11434";
}
@@ -116,7 +116,7 @@ public class OllamaAPI {
*
* @param host the host address of the Ollama server
*/
public OllamaAPI(String host) {
public Ollama(String host) {
if (host.endsWith("/")) {
this.host = host.substring(0, host.length() - 1);
} else {

View File

@@ -8,7 +8,7 @@
*/
package io.github.ollama4j.tools.annotations;
import io.github.ollama4j.OllamaAPI;
import io.github.ollama4j.Ollama;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
@@ -18,7 +18,7 @@ import java.lang.annotation.Target;
* Annotation to mark a class as an Ollama tool service.
* <p>
* When a class is annotated with {@code @OllamaToolService}, the method
* {@link OllamaAPI#registerAnnotatedTools()} can be used to automatically register all tool provider
* {@link Ollama#registerAnnotatedTools()} can be used to automatically register all tool provider
* classes specified in the {@link #providers()} array. All methods in those provider classes that are
* annotated with {@link ToolSpec} will be registered as tools.
* </p>

View File

@@ -8,7 +8,7 @@
*/
package io.github.ollama4j.tools.annotations;
import io.github.ollama4j.OllamaAPI;
import io.github.ollama4j.Ollama;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
@@ -16,7 +16,7 @@ import java.lang.annotation.Target;
/**
* Annotation to mark a method as a tool that can be registered automatically by
* {@link OllamaAPI#registerAnnotatedTools()}.
* {@link Ollama#registerAnnotatedTools()}.
* <p>
* Methods annotated with {@code @ToolSpec} will be discovered and registered as tools
* when the containing class is specified as a provider in {@link OllamaToolService}.

View File

@@ -10,7 +10,7 @@ package io.github.ollama4j.integrationtests;
import static org.junit.jupiter.api.Assertions.*;
import io.github.ollama4j.OllamaAPI;
import io.github.ollama4j.Ollama;
import io.github.ollama4j.exceptions.OllamaException;
import io.github.ollama4j.impl.ConsoleOutputChatTokenHandler;
import io.github.ollama4j.impl.ConsoleOutputGenerateTokenHandler;
@@ -44,11 +44,11 @@ import org.testcontainers.ollama.OllamaContainer;
@OllamaToolService(providers = {AnnotatedTool.class})
@TestMethodOrder(OrderAnnotation.class)
@SuppressWarnings({"HttpUrlsUsage", "SpellCheckingInspection", "FieldCanBeLocal", "ConstantValue"})
class OllamaAPIIntegrationTest {
private static final Logger LOG = LoggerFactory.getLogger(OllamaAPIIntegrationTest.class);
class OllamaIntegrationTest {
private static final Logger LOG = LoggerFactory.getLogger(OllamaIntegrationTest.class);
private static OllamaContainer ollama;
private static OllamaAPI api;
private static Ollama api;
private static final String EMBEDDING_MODEL = "all-minilm";
private static final String VISION_MODEL = "moondream:1.8b";
@@ -81,7 +81,7 @@ class OllamaAPIIntegrationTest {
Properties props = new Properties();
try {
props.load(
OllamaAPIIntegrationTest.class
OllamaIntegrationTest.class
.getClassLoader()
.getResourceAsStream("test-config.properties"));
} catch (Exception e) {
@@ -103,7 +103,7 @@ class OllamaAPIIntegrationTest {
if (useExternalOllamaHost) {
LOG.info("Using external Ollama host: {}", ollamaHost);
api = new OllamaAPI(ollamaHost);
api = new Ollama(ollamaHost);
} else {
throw new RuntimeException(
"USE_EXTERNAL_OLLAMA_HOST is not set so, we will be using Testcontainers"
@@ -124,7 +124,7 @@ class OllamaAPIIntegrationTest {
ollama.start();
LOG.info("Using Testcontainer Ollama host...");
api =
new OllamaAPI(
new Ollama(
"http://"
+ ollama.getHost()
+ ":"
@@ -143,8 +143,8 @@ class OllamaAPIIntegrationTest {
@Test
@Order(1)
void shouldThrowConnectExceptionForWrongEndpoint() {
OllamaAPI ollamaAPI = new OllamaAPI("http://wrong-host:11434");
assertThrows(OllamaException.class, ollamaAPI::listModels);
Ollama ollama = new Ollama("http://wrong-host:11434");
assertThrows(OllamaException.class, ollama::listModels);
}
/**
@@ -778,7 +778,7 @@ class OllamaAPIIntegrationTest {
Collections.emptyList(),
"https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg")
.build();
api.registerAnnotatedTools(new OllamaAPIIntegrationTest());
api.registerAnnotatedTools(new OllamaIntegrationTest());
OllamaChatResult chatResult = api.chat(requestModel, null);
assertNotNull(chatResult);

View File

@@ -10,7 +10,7 @@ package io.github.ollama4j.integrationtests;
import static org.junit.jupiter.api.Assertions.*;
import io.github.ollama4j.OllamaAPI;
import io.github.ollama4j.Ollama;
import io.github.ollama4j.exceptions.OllamaException;
import io.github.ollama4j.models.generate.OllamaGenerateRequest;
import io.github.ollama4j.models.generate.OllamaGenerateRequestBuilder;
@@ -62,7 +62,7 @@ public class WithAuth {
private static OllamaContainer ollama;
private static GenericContainer<?> nginx;
private static OllamaAPI api;
private static Ollama api;
@BeforeAll
static void setUp() {
@@ -74,7 +74,7 @@ public class WithAuth {
LOG.info("Using Testcontainer Ollama host...");
api = new OllamaAPI("http://" + nginx.getHost() + ":" + nginx.getMappedPort(NGINX_PORT));
api = new Ollama("http://" + nginx.getHost() + ":" + nginx.getMappedPort(NGINX_PORT));
api.setRequestTimeoutSeconds(120);
api.setNumberOfRetriesForModelPull(3);

View File

@@ -12,7 +12,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
import static org.mockito.Mockito.*;
import io.github.ollama4j.OllamaAPI;
import io.github.ollama4j.Ollama;
import io.github.ollama4j.exceptions.OllamaException;
import io.github.ollama4j.exceptions.RoleNotFoundException;
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
@@ -36,12 +36,12 @@ import org.mockito.Mockito;
class TestMockedAPIs {
@Test
void testPullModel() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
Ollama ollama = Mockito.mock(Ollama.class);
String model = "llama2";
try {
doNothing().when(ollamaAPI).pullModel(model);
ollamaAPI.pullModel(model);
verify(ollamaAPI, times(1)).pullModel(model);
doNothing().when(ollama).pullModel(model);
ollama.pullModel(model);
verify(ollama, times(1)).pullModel(model);
} catch (OllamaException e) {
throw new RuntimeException(e);
}
@@ -49,11 +49,11 @@ class TestMockedAPIs {
@Test
void testListModels() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
Ollama ollama = Mockito.mock(Ollama.class);
try {
when(ollamaAPI.listModels()).thenReturn(new ArrayList<>());
ollamaAPI.listModels();
verify(ollamaAPI, times(1)).listModels();
when(ollama.listModels()).thenReturn(new ArrayList<>());
ollama.listModels();
verify(ollama, times(1)).listModels();
} catch (OllamaException e) {
throw new RuntimeException(e);
}
@@ -61,7 +61,7 @@ class TestMockedAPIs {
@Test
void testCreateModel() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
Ollama ollama = Mockito.mock(Ollama.class);
CustomModelRequest customModelRequest =
CustomModelRequest.builder()
.model("mario")
@@ -69,9 +69,9 @@ class TestMockedAPIs {
.system("You are Mario from Super Mario Bros.")
.build();
try {
doNothing().when(ollamaAPI).createModel(customModelRequest);
ollamaAPI.createModel(customModelRequest);
verify(ollamaAPI, times(1)).createModel(customModelRequest);
doNothing().when(ollama).createModel(customModelRequest);
ollama.createModel(customModelRequest);
verify(ollama, times(1)).createModel(customModelRequest);
} catch (OllamaException e) {
throw new RuntimeException(e);
}
@@ -79,12 +79,12 @@ class TestMockedAPIs {
@Test
void testDeleteModel() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
Ollama ollama = Mockito.mock(Ollama.class);
String model = "llama2";
try {
doNothing().when(ollamaAPI).deleteModel(model, true);
ollamaAPI.deleteModel(model, true);
verify(ollamaAPI, times(1)).deleteModel(model, true);
doNothing().when(ollama).deleteModel(model, true);
ollama.deleteModel(model, true);
verify(ollama, times(1)).deleteModel(model, true);
} catch (OllamaException e) {
throw new RuntimeException(e);
}
@@ -92,12 +92,12 @@ class TestMockedAPIs {
@Test
void testGetModelDetails() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
Ollama ollama = Mockito.mock(Ollama.class);
String model = "llama2";
try {
when(ollamaAPI.getModelDetails(model)).thenReturn(new ModelDetail());
ollamaAPI.getModelDetails(model);
verify(ollamaAPI, times(1)).getModelDetails(model);
when(ollama.getModelDetails(model)).thenReturn(new ModelDetail());
ollama.getModelDetails(model);
verify(ollama, times(1)).getModelDetails(model);
} catch (OllamaException e) {
throw new RuntimeException(e);
}
@@ -105,16 +105,16 @@ class TestMockedAPIs {
@Test
void testGenerateEmbeddings() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
Ollama ollama = Mockito.mock(Ollama.class);
String model = "llama2";
String prompt = "some prompt text";
try {
OllamaEmbedRequest m = new OllamaEmbedRequest();
m.setModel(model);
m.setInput(List.of(prompt));
when(ollamaAPI.embed(m)).thenReturn(new OllamaEmbedResult());
ollamaAPI.embed(m);
verify(ollamaAPI, times(1)).embed(m);
when(ollama.embed(m)).thenReturn(new OllamaEmbedResult());
ollama.embed(m);
verify(ollama, times(1)).embed(m);
} catch (OllamaException e) {
throw new RuntimeException(e);
}
@@ -122,14 +122,14 @@ class TestMockedAPIs {
@Test
void testEmbed() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
Ollama ollama = Mockito.mock(Ollama.class);
String model = "llama2";
List<String> inputs = List.of("some prompt text");
try {
OllamaEmbedRequest m = new OllamaEmbedRequest(model, inputs);
when(ollamaAPI.embed(m)).thenReturn(new OllamaEmbedResult());
ollamaAPI.embed(m);
verify(ollamaAPI, times(1)).embed(m);
when(ollama.embed(m)).thenReturn(new OllamaEmbedResult());
ollama.embed(m);
verify(ollama, times(1)).embed(m);
} catch (OllamaException e) {
throw new RuntimeException(e);
}
@@ -137,14 +137,14 @@ class TestMockedAPIs {
@Test
void testEmbedWithEmbedRequestModel() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
Ollama ollama = Mockito.mock(Ollama.class);
String model = "llama2";
List<String> inputs = List.of("some prompt text");
try {
when(ollamaAPI.embed(new OllamaEmbedRequest(model, inputs)))
when(ollama.embed(new OllamaEmbedRequest(model, inputs)))
.thenReturn(new OllamaEmbedResult());
ollamaAPI.embed(new OllamaEmbedRequest(model, inputs));
verify(ollamaAPI, times(1)).embed(new OllamaEmbedRequest(model, inputs));
ollama.embed(new OllamaEmbedRequest(model, inputs));
verify(ollama, times(1)).embed(new OllamaEmbedRequest(model, inputs));
} catch (OllamaException e) {
throw new RuntimeException(e);
}
@@ -152,7 +152,7 @@ class TestMockedAPIs {
@Test
void testAsk() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
Ollama ollama = Mockito.mock(Ollama.class);
String model = "llama2";
String prompt = "some prompt text";
OllamaGenerateStreamObserver observer = new OllamaGenerateStreamObserver(null, null);
@@ -165,10 +165,9 @@ class TestMockedAPIs {
.withThink(false)
.withStreaming(false)
.build();
when(ollamaAPI.generate(request, observer))
.thenReturn(new OllamaResult("", "", 0, 200));
ollamaAPI.generate(request, observer);
verify(ollamaAPI, times(1)).generate(request, observer);
when(ollama.generate(request, observer)).thenReturn(new OllamaResult("", "", 0, 200));
ollama.generate(request, observer);
verify(ollama, times(1)).generate(request, observer);
} catch (OllamaException e) {
throw new RuntimeException(e);
}
@@ -176,7 +175,7 @@ class TestMockedAPIs {
@Test
void testAskWithImageFiles() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
Ollama ollama = Mockito.mock(Ollama.class);
String model = "llama2";
String prompt = "some prompt text";
try {
@@ -192,9 +191,9 @@ class TestMockedAPIs {
.withFormat(null)
.build();
OllamaGenerateStreamObserver handler = null;
when(ollamaAPI.generate(request, handler)).thenReturn(new OllamaResult("", "", 0, 200));
ollamaAPI.generate(request, handler);
verify(ollamaAPI, times(1)).generate(request, handler);
when(ollama.generate(request, handler)).thenReturn(new OllamaResult("", "", 0, 200));
ollama.generate(request, handler);
verify(ollama, times(1)).generate(request, handler);
} catch (Exception e) {
throw new RuntimeException(e);
}
@@ -202,7 +201,7 @@ class TestMockedAPIs {
@Test
void testAskWithImageURLs() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
Ollama ollama = Mockito.mock(Ollama.class);
String model = "llama2";
String prompt = "some prompt text";
try {
@@ -218,9 +217,9 @@ class TestMockedAPIs {
.withFormat(null)
.build();
OllamaGenerateStreamObserver handler = null;
when(ollamaAPI.generate(request, handler)).thenReturn(new OllamaResult("", "", 0, 200));
ollamaAPI.generate(request, handler);
verify(ollamaAPI, times(1)).generate(request, handler);
when(ollama.generate(request, handler)).thenReturn(new OllamaResult("", "", 0, 200));
ollama.generate(request, handler);
verify(ollama, times(1)).generate(request, handler);
} catch (OllamaException e) {
throw new RuntimeException(e);
} catch (IOException e) {
@@ -230,56 +229,55 @@ class TestMockedAPIs {
@Test
void testAskAsync() throws OllamaException {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
Ollama ollama = Mockito.mock(Ollama.class);
String model = "llama2";
String prompt = "some prompt text";
when(ollamaAPI.generateAsync(model, prompt, false, false))
when(ollama.generateAsync(model, prompt, false, false))
.thenReturn(new OllamaAsyncResultStreamer(null, null, 3));
ollamaAPI.generateAsync(model, prompt, false, false);
verify(ollamaAPI, times(1)).generateAsync(model, prompt, false, false);
ollama.generateAsync(model, prompt, false, false);
verify(ollama, times(1)).generateAsync(model, prompt, false, false);
}
@Test
void testAddCustomRole() {
OllamaAPI ollamaAPI = mock(OllamaAPI.class);
Ollama ollama = mock(Ollama.class);
String roleName = "custom-role";
OllamaChatMessageRole expectedRole = OllamaChatMessageRole.newCustomRole(roleName);
when(ollamaAPI.addCustomRole(roleName)).thenReturn(expectedRole);
OllamaChatMessageRole customRole = ollamaAPI.addCustomRole(roleName);
when(ollama.addCustomRole(roleName)).thenReturn(expectedRole);
OllamaChatMessageRole customRole = ollama.addCustomRole(roleName);
assertEquals(expectedRole, customRole);
verify(ollamaAPI, times(1)).addCustomRole(roleName);
verify(ollama, times(1)).addCustomRole(roleName);
}
@Test
void testListRoles() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
Ollama ollama = Mockito.mock(Ollama.class);
OllamaChatMessageRole role1 = OllamaChatMessageRole.newCustomRole("role1");
OllamaChatMessageRole role2 = OllamaChatMessageRole.newCustomRole("role2");
List<OllamaChatMessageRole> expectedRoles = List.of(role1, role2);
when(ollamaAPI.listRoles()).thenReturn(expectedRoles);
List<OllamaChatMessageRole> actualRoles = ollamaAPI.listRoles();
when(ollama.listRoles()).thenReturn(expectedRoles);
List<OllamaChatMessageRole> actualRoles = ollama.listRoles();
assertEquals(expectedRoles, actualRoles);
verify(ollamaAPI, times(1)).listRoles();
verify(ollama, times(1)).listRoles();
}
@Test
void testGetRoleNotFound() {
OllamaAPI ollamaAPI = mock(OllamaAPI.class);
Ollama ollama = mock(Ollama.class);
String roleName = "non-existing-role";
try {
when(ollamaAPI.getRole(roleName))
.thenThrow(new RoleNotFoundException("Role not found"));
when(ollama.getRole(roleName)).thenThrow(new RoleNotFoundException("Role not found"));
} catch (RoleNotFoundException exception) {
throw new RuntimeException("Failed to run test: testGetRoleNotFound");
}
try {
ollamaAPI.getRole(roleName);
ollama.getRole(roleName);
fail("Expected RoleNotFoundException not thrown");
} catch (RoleNotFoundException exception) {
assertEquals("Role not found", exception.getMessage());
}
try {
verify(ollamaAPI, times(1)).getRole(roleName);
verify(ollama, times(1)).getRole(roleName);
} catch (RoleNotFoundException exception) {
throw new RuntimeException("Failed to run test: testGetRoleNotFound");
}
@@ -287,18 +285,18 @@ class TestMockedAPIs {
@Test
void testGetRoleFound() {
OllamaAPI ollamaAPI = mock(OllamaAPI.class);
Ollama ollama = mock(Ollama.class);
String roleName = "existing-role";
OllamaChatMessageRole expectedRole = OllamaChatMessageRole.newCustomRole(roleName);
try {
when(ollamaAPI.getRole(roleName)).thenReturn(expectedRole);
when(ollama.getRole(roleName)).thenReturn(expectedRole);
} catch (RoleNotFoundException exception) {
throw new RuntimeException("Failed to run test: testGetRoleFound");
}
try {
OllamaChatMessageRole actualRole = ollamaAPI.getRole(roleName);
OllamaChatMessageRole actualRole = ollama.getRole(roleName);
assertEquals(expectedRole, actualRole);
verify(ollamaAPI, times(1)).getRole(roleName);
verify(ollama, times(1)).getRole(roleName);
} catch (RoleNotFoundException exception) {
throw new RuntimeException("Failed to run test: testGetRoleFound");
}