diff --git a/src/main/java/io/github/ollama4j/OllamaAPI.java b/src/main/java/io/github/ollama4j/OllamaAPI.java
index e636471..3d3e05f 100644
--- a/src/main/java/io/github/ollama4j/OllamaAPI.java
+++ b/src/main/java/io/github/ollama4j/OllamaAPI.java
@@ -1,12 +1,10 @@
package io.github.ollama4j;
import io.github.ollama4j.exceptions.OllamaBaseException;
+import io.github.ollama4j.exceptions.RoleNotFoundException;
import io.github.ollama4j.exceptions.ToolInvocationException;
import io.github.ollama4j.exceptions.ToolNotFoundException;
-import io.github.ollama4j.models.chat.OllamaChatMessage;
-import io.github.ollama4j.models.chat.OllamaChatRequest;
-import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
-import io.github.ollama4j.models.chat.OllamaChatResult;
+import io.github.ollama4j.models.chat.*;
import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel;
import io.github.ollama4j.models.embeddings.OllamaEmbeddingResponseModel;
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
@@ -166,7 +164,6 @@ public class OllamaAPI {
}
}
-
/**
* Pull a model on the Ollama server from the list of available models.
@@ -471,7 +468,6 @@ public class OllamaAPI {
return toolResult;
}
-
/**
* Generate response for a question to a model running on Ollama server and get a callback handle
* that can be used to check for status and get the response from the model later. This would be
@@ -570,7 +566,6 @@ public class OllamaAPI {
return generateWithImageURLs(model, prompt, imageURLs, options, null);
}
-
/**
* Ask a question to a model based on a given message stack (i.e. a chat history). Creates a synchronous call to the api
* 'api/chat'.
@@ -639,6 +634,25 @@ public class OllamaAPI {
toolRegistry.addFunction(toolSpecification.getFunctionName(), toolSpecification.getToolDefinition());
}
+ /**
+ * @param roleName - Custom role to be added
+ * @return OllamaChatMessageRole
+ */
+ public OllamaChatMessageRole addCustomRole(String roleName) {
+ return OllamaChatMessageRole.newCustomRole(roleName);
+ }
+
+ /**
+ * @return - available roles
+ */
+ public List listRoles() {
+ return OllamaChatMessageRole.getRoles();
+ }
+
+ public OllamaChatMessageRole getRole(String roleName) throws RoleNotFoundException {
+ return OllamaChatMessageRole.getRole(roleName);
+ }
+
// technical private methods //
private static String encodeFileToBase64(File file) throws IOException {
@@ -694,7 +708,6 @@ public class OllamaAPI {
return basicAuth != null;
}
-
private Object invokeTool(ToolFunctionCallSpec toolFunctionCallSpec) throws ToolInvocationException {
try {
String methodName = toolFunctionCallSpec.getName();
diff --git a/src/main/java/io/github/ollama4j/exceptions/RoleNotFoundException.java b/src/main/java/io/github/ollama4j/exceptions/RoleNotFoundException.java
new file mode 100644
index 0000000..a7d1d18
--- /dev/null
+++ b/src/main/java/io/github/ollama4j/exceptions/RoleNotFoundException.java
@@ -0,0 +1,8 @@
+package io.github.ollama4j.exceptions;
+
+public class RoleNotFoundException extends Exception {
+
+ public RoleNotFoundException(String s) {
+ super(s);
+ }
+}
diff --git a/src/main/java/io/github/ollama4j/models/chat/OllamaChatMessageRole.java b/src/main/java/io/github/ollama4j/models/chat/OllamaChatMessageRole.java
index 1eb4f30..5432cf1 100644
--- a/src/main/java/io/github/ollama4j/models/chat/OllamaChatMessageRole.java
+++ b/src/main/java/io/github/ollama4j/models/chat/OllamaChatMessageRole.java
@@ -1,20 +1,53 @@
package io.github.ollama4j.models.chat;
import com.fasterxml.jackson.annotation.JsonValue;
+import io.github.ollama4j.exceptions.RoleNotFoundException;
+import lombok.Getter;
+
+import java.util.ArrayList;
+import java.util.List;
/**
* Defines the possible Chat Message roles.
*/
-public enum OllamaChatMessageRole {
- SYSTEM("system"),
- USER("user"),
- ASSISTANT("assistant"),
- TOOL("tool");
+@Getter
+public class OllamaChatMessageRole {
+ private static final List ROLES = new ArrayList<>();
+
+ public static final OllamaChatMessageRole SYSTEM = new OllamaChatMessageRole("system");
+ public static final OllamaChatMessageRole USER = new OllamaChatMessageRole("user");
+ public static final OllamaChatMessageRole ASSISTANT = new OllamaChatMessageRole("assistant");
+ public static final OllamaChatMessageRole TOOL = new OllamaChatMessageRole("tool");
@JsonValue
- private String roleName;
+ private final String roleName;
- private OllamaChatMessageRole(String roleName){
+ private OllamaChatMessageRole(String roleName) {
this.roleName = roleName;
+ ROLES.add(this);
+ }
+
+ public static OllamaChatMessageRole newCustomRole(String roleName) {
+ OllamaChatMessageRole customRole = new OllamaChatMessageRole(roleName);
+ ROLES.add(customRole);
+ return customRole;
+ }
+
+ public static List getRoles() {
+ return new ArrayList<>(ROLES);
+ }
+
+ public static OllamaChatMessageRole getRole(String roleName) throws RoleNotFoundException {
+ for (OllamaChatMessageRole role : ROLES) {
+ if (role.roleName.equals(roleName)) {
+ return role;
+ }
+ }
+ throw new RoleNotFoundException("Invalid role name: " + roleName);
+ }
+
+ @Override
+ public String toString() {
+ return roleName;
}
}
diff --git a/src/main/java/io/github/ollama4j/models/chat/OllamaChatResult.java b/src/main/java/io/github/ollama4j/models/chat/OllamaChatResult.java
index b105e81..b9616f3 100644
--- a/src/main/java/io/github/ollama4j/models/chat/OllamaChatResult.java
+++ b/src/main/java/io/github/ollama4j/models/chat/OllamaChatResult.java
@@ -8,12 +8,11 @@ import io.github.ollama4j.models.response.OllamaResult;
* Specific chat-API result that contains the chat history sent to the model and appends the answer as {@link OllamaChatResult} given by the
* {@link OllamaChatMessageRole#ASSISTANT} role.
*/
-public class OllamaChatResult extends OllamaResult{
+public class OllamaChatResult extends OllamaResult {
private List chatHistory;
- public OllamaChatResult(String response, long responseTime, int httpStatusCode,
- List chatHistory) {
+ public OllamaChatResult(String response, long responseTime, int httpStatusCode, List chatHistory) {
super(response, responseTime, httpStatusCode);
this.chatHistory = chatHistory;
appendAnswerToChatHistory(response);
@@ -21,12 +20,10 @@ public class OllamaChatResult extends OllamaResult{
public List getChatHistory() {
return chatHistory;
- }
+ }
- private void appendAnswerToChatHistory(String answer){
+ private void appendAnswerToChatHistory(String answer) {
OllamaChatMessage assistantMessage = new OllamaChatMessage(OllamaChatMessageRole.ASSISTANT, answer);
this.chatHistory.add(assistantMessage);
}
-
-
}
diff --git a/src/test/java/io/github/ollama4j/unittests/TestMockedAPIs.java b/src/test/java/io/github/ollama4j/unittests/TestMockedAPIs.java
index 921ccf7..c775a89 100644
--- a/src/test/java/io/github/ollama4j/unittests/TestMockedAPIs.java
+++ b/src/test/java/io/github/ollama4j/unittests/TestMockedAPIs.java
@@ -2,6 +2,10 @@ package io.github.ollama4j.unittests;
import io.github.ollama4j.OllamaAPI;
import io.github.ollama4j.exceptions.OllamaBaseException;
+import io.github.ollama4j.exceptions.RoleNotFoundException;
+import io.github.ollama4j.models.chat.OllamaChatMessageRole;
+import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel;
+import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel;
import io.github.ollama4j.models.response.ModelDetail;
import io.github.ollama4j.models.response.OllamaAsyncResultStreamer;
import io.github.ollama4j.models.response.OllamaResult;
@@ -14,7 +18,9 @@ import java.io.IOException;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.Collections;
+import java.util.List;
+import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.*;
class TestMockedAPIs {
@@ -97,6 +103,34 @@ class TestMockedAPIs {
}
}
+ @Test
+ void testEmbed() {
+ OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
+ String model = OllamaModelType.LLAMA2;
+ List inputs = List.of("some prompt text");
+ try {
+ when(ollamaAPI.embed(model, inputs)).thenReturn(new OllamaEmbedResponseModel());
+ ollamaAPI.embed(model, inputs);
+ verify(ollamaAPI, times(1)).embed(model, inputs);
+ } catch (IOException | OllamaBaseException | InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Test
+ void testEmbedWithEmbedRequestModel() {
+ OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
+ String model = OllamaModelType.LLAMA2;
+ List inputs = List.of("some prompt text");
+ try {
+ when(ollamaAPI.embed(new OllamaEmbedRequestModel(model, inputs))).thenReturn(new OllamaEmbedResponseModel());
+ ollamaAPI.embed(new OllamaEmbedRequestModel(model, inputs));
+ verify(ollamaAPI, times(1)).embed(new OllamaEmbedRequestModel(model, inputs));
+ } catch (IOException | OllamaBaseException | InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
@Test
void testAsk() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
@@ -161,4 +195,68 @@ class TestMockedAPIs {
ollamaAPI.generateAsync(model, prompt, false);
verify(ollamaAPI, times(1)).generateAsync(model, prompt, false);
}
+
+ @Test
+ void testAddCustomRole() {
+ OllamaAPI ollamaAPI = mock(OllamaAPI.class);
+ String roleName = "custom-role";
+ OllamaChatMessageRole expectedRole = OllamaChatMessageRole.newCustomRole(roleName);
+ when(ollamaAPI.addCustomRole(roleName)).thenReturn(expectedRole);
+ OllamaChatMessageRole customRole = ollamaAPI.addCustomRole(roleName);
+ assertEquals(expectedRole, customRole);
+ verify(ollamaAPI, times(1)).addCustomRole(roleName);
+ }
+
+ @Test
+ void testListRoles() {
+ OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
+ OllamaChatMessageRole role1 = OllamaChatMessageRole.newCustomRole("role1");
+ OllamaChatMessageRole role2 = OllamaChatMessageRole.newCustomRole("role2");
+ List expectedRoles = List.of(role1, role2);
+ when(ollamaAPI.listRoles()).thenReturn(expectedRoles);
+ List actualRoles = ollamaAPI.listRoles();
+ assertEquals(expectedRoles, actualRoles);
+ verify(ollamaAPI, times(1)).listRoles();
+ }
+
+ @Test
+ void testGetRoleNotFound() {
+ OllamaAPI ollamaAPI = mock(OllamaAPI.class);
+ String roleName = "non-existing-role";
+ try {
+ when(ollamaAPI.getRole(roleName)).thenThrow(new RoleNotFoundException("Role not found"));
+ } catch (RoleNotFoundException exception) {
+ throw new RuntimeException("Failed to run test: testGetRoleNotFound");
+ }
+ try {
+ ollamaAPI.getRole(roleName);
+ fail("Expected RoleNotFoundException not thrown");
+ } catch (RoleNotFoundException exception) {
+ assertEquals("Role not found", exception.getMessage());
+ }
+ try {
+ verify(ollamaAPI, times(1)).getRole(roleName);
+ } catch (RoleNotFoundException exception) {
+ throw new RuntimeException("Failed to run test: testGetRoleNotFound");
+ }
+ }
+
+ @Test
+ void testGetRoleFound() {
+ OllamaAPI ollamaAPI = mock(OllamaAPI.class);
+ String roleName = "existing-role";
+ OllamaChatMessageRole expectedRole = OllamaChatMessageRole.newCustomRole(roleName);
+ try {
+ when(ollamaAPI.getRole(roleName)).thenReturn(expectedRole);
+ } catch (RoleNotFoundException exception) {
+ throw new RuntimeException("Failed to run test: testGetRoleFound");
+ }
+ try {
+ OllamaChatMessageRole actualRole = ollamaAPI.getRole(roleName);
+ assertEquals(expectedRole, actualRole);
+ verify(ollamaAPI, times(1)).getRole(roleName);
+ } catch (RoleNotFoundException exception) {
+ throw new RuntimeException("Failed to run test: testGetRoleFound");
+ }
+ }
}