diff --git a/src/main/java/io/github/ollama4j/OllamaAPI.java b/src/main/java/io/github/ollama4j/OllamaAPI.java index e636471..3d3e05f 100644 --- a/src/main/java/io/github/ollama4j/OllamaAPI.java +++ b/src/main/java/io/github/ollama4j/OllamaAPI.java @@ -1,12 +1,10 @@ package io.github.ollama4j; import io.github.ollama4j.exceptions.OllamaBaseException; +import io.github.ollama4j.exceptions.RoleNotFoundException; import io.github.ollama4j.exceptions.ToolInvocationException; import io.github.ollama4j.exceptions.ToolNotFoundException; -import io.github.ollama4j.models.chat.OllamaChatMessage; -import io.github.ollama4j.models.chat.OllamaChatRequest; -import io.github.ollama4j.models.chat.OllamaChatRequestBuilder; -import io.github.ollama4j.models.chat.OllamaChatResult; +import io.github.ollama4j.models.chat.*; import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel; import io.github.ollama4j.models.embeddings.OllamaEmbeddingResponseModel; import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel; @@ -166,7 +164,6 @@ public class OllamaAPI { } } - /** * Pull a model on the Ollama server from the list of available models. @@ -471,7 +468,6 @@ public class OllamaAPI { return toolResult; } - /** * Generate response for a question to a model running on Ollama server and get a callback handle * that can be used to check for status and get the response from the model later. This would be @@ -570,7 +566,6 @@ public class OllamaAPI { return generateWithImageURLs(model, prompt, imageURLs, options, null); } - /** * Ask a question to a model based on a given message stack (i.e. a chat history). Creates a synchronous call to the api * 'api/chat'. @@ -639,6 +634,25 @@ public class OllamaAPI { toolRegistry.addFunction(toolSpecification.getFunctionName(), toolSpecification.getToolDefinition()); } + /** + * @param roleName - Custom role to be added + * @return OllamaChatMessageRole + */ + public OllamaChatMessageRole addCustomRole(String roleName) { + return OllamaChatMessageRole.newCustomRole(roleName); + } + + /** + * @return - available roles + */ + public List listRoles() { + return OllamaChatMessageRole.getRoles(); + } + + public OllamaChatMessageRole getRole(String roleName) throws RoleNotFoundException { + return OllamaChatMessageRole.getRole(roleName); + } + // technical private methods // private static String encodeFileToBase64(File file) throws IOException { @@ -694,7 +708,6 @@ public class OllamaAPI { return basicAuth != null; } - private Object invokeTool(ToolFunctionCallSpec toolFunctionCallSpec) throws ToolInvocationException { try { String methodName = toolFunctionCallSpec.getName(); diff --git a/src/main/java/io/github/ollama4j/exceptions/RoleNotFoundException.java b/src/main/java/io/github/ollama4j/exceptions/RoleNotFoundException.java new file mode 100644 index 0000000..a7d1d18 --- /dev/null +++ b/src/main/java/io/github/ollama4j/exceptions/RoleNotFoundException.java @@ -0,0 +1,8 @@ +package io.github.ollama4j.exceptions; + +public class RoleNotFoundException extends Exception { + + public RoleNotFoundException(String s) { + super(s); + } +} diff --git a/src/main/java/io/github/ollama4j/models/chat/OllamaChatMessageRole.java b/src/main/java/io/github/ollama4j/models/chat/OllamaChatMessageRole.java index 1eb4f30..5432cf1 100644 --- a/src/main/java/io/github/ollama4j/models/chat/OllamaChatMessageRole.java +++ b/src/main/java/io/github/ollama4j/models/chat/OllamaChatMessageRole.java @@ -1,20 +1,53 @@ package io.github.ollama4j.models.chat; import com.fasterxml.jackson.annotation.JsonValue; +import io.github.ollama4j.exceptions.RoleNotFoundException; +import lombok.Getter; + +import java.util.ArrayList; +import java.util.List; /** * Defines the possible Chat Message roles. */ -public enum OllamaChatMessageRole { - SYSTEM("system"), - USER("user"), - ASSISTANT("assistant"), - TOOL("tool"); +@Getter +public class OllamaChatMessageRole { + private static final List ROLES = new ArrayList<>(); + + public static final OllamaChatMessageRole SYSTEM = new OllamaChatMessageRole("system"); + public static final OllamaChatMessageRole USER = new OllamaChatMessageRole("user"); + public static final OllamaChatMessageRole ASSISTANT = new OllamaChatMessageRole("assistant"); + public static final OllamaChatMessageRole TOOL = new OllamaChatMessageRole("tool"); @JsonValue - private String roleName; + private final String roleName; - private OllamaChatMessageRole(String roleName){ + private OllamaChatMessageRole(String roleName) { this.roleName = roleName; + ROLES.add(this); + } + + public static OllamaChatMessageRole newCustomRole(String roleName) { + OllamaChatMessageRole customRole = new OllamaChatMessageRole(roleName); + ROLES.add(customRole); + return customRole; + } + + public static List getRoles() { + return new ArrayList<>(ROLES); + } + + public static OllamaChatMessageRole getRole(String roleName) throws RoleNotFoundException { + for (OllamaChatMessageRole role : ROLES) { + if (role.roleName.equals(roleName)) { + return role; + } + } + throw new RoleNotFoundException("Invalid role name: " + roleName); + } + + @Override + public String toString() { + return roleName; } } diff --git a/src/main/java/io/github/ollama4j/models/chat/OllamaChatResult.java b/src/main/java/io/github/ollama4j/models/chat/OllamaChatResult.java index b105e81..b9616f3 100644 --- a/src/main/java/io/github/ollama4j/models/chat/OllamaChatResult.java +++ b/src/main/java/io/github/ollama4j/models/chat/OllamaChatResult.java @@ -8,12 +8,11 @@ import io.github.ollama4j.models.response.OllamaResult; * Specific chat-API result that contains the chat history sent to the model and appends the answer as {@link OllamaChatResult} given by the * {@link OllamaChatMessageRole#ASSISTANT} role. */ -public class OllamaChatResult extends OllamaResult{ +public class OllamaChatResult extends OllamaResult { private List chatHistory; - public OllamaChatResult(String response, long responseTime, int httpStatusCode, - List chatHistory) { + public OllamaChatResult(String response, long responseTime, int httpStatusCode, List chatHistory) { super(response, responseTime, httpStatusCode); this.chatHistory = chatHistory; appendAnswerToChatHistory(response); @@ -21,12 +20,10 @@ public class OllamaChatResult extends OllamaResult{ public List getChatHistory() { return chatHistory; - } + } - private void appendAnswerToChatHistory(String answer){ + private void appendAnswerToChatHistory(String answer) { OllamaChatMessage assistantMessage = new OllamaChatMessage(OllamaChatMessageRole.ASSISTANT, answer); this.chatHistory.add(assistantMessage); } - - } diff --git a/src/test/java/io/github/ollama4j/unittests/TestMockedAPIs.java b/src/test/java/io/github/ollama4j/unittests/TestMockedAPIs.java index 921ccf7..c775a89 100644 --- a/src/test/java/io/github/ollama4j/unittests/TestMockedAPIs.java +++ b/src/test/java/io/github/ollama4j/unittests/TestMockedAPIs.java @@ -2,6 +2,10 @@ package io.github.ollama4j.unittests; import io.github.ollama4j.OllamaAPI; import io.github.ollama4j.exceptions.OllamaBaseException; +import io.github.ollama4j.exceptions.RoleNotFoundException; +import io.github.ollama4j.models.chat.OllamaChatMessageRole; +import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel; +import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel; import io.github.ollama4j.models.response.ModelDetail; import io.github.ollama4j.models.response.OllamaAsyncResultStreamer; import io.github.ollama4j.models.response.OllamaResult; @@ -14,7 +18,9 @@ import java.io.IOException; import java.net.URISyntaxException; import java.util.ArrayList; import java.util.Collections; +import java.util.List; +import static org.junit.jupiter.api.Assertions.*; import static org.mockito.Mockito.*; class TestMockedAPIs { @@ -97,6 +103,34 @@ class TestMockedAPIs { } } + @Test + void testEmbed() { + OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); + String model = OllamaModelType.LLAMA2; + List inputs = List.of("some prompt text"); + try { + when(ollamaAPI.embed(model, inputs)).thenReturn(new OllamaEmbedResponseModel()); + ollamaAPI.embed(model, inputs); + verify(ollamaAPI, times(1)).embed(model, inputs); + } catch (IOException | OllamaBaseException | InterruptedException e) { + throw new RuntimeException(e); + } + } + + @Test + void testEmbedWithEmbedRequestModel() { + OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); + String model = OllamaModelType.LLAMA2; + List inputs = List.of("some prompt text"); + try { + when(ollamaAPI.embed(new OllamaEmbedRequestModel(model, inputs))).thenReturn(new OllamaEmbedResponseModel()); + ollamaAPI.embed(new OllamaEmbedRequestModel(model, inputs)); + verify(ollamaAPI, times(1)).embed(new OllamaEmbedRequestModel(model, inputs)); + } catch (IOException | OllamaBaseException | InterruptedException e) { + throw new RuntimeException(e); + } + } + @Test void testAsk() { OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); @@ -161,4 +195,68 @@ class TestMockedAPIs { ollamaAPI.generateAsync(model, prompt, false); verify(ollamaAPI, times(1)).generateAsync(model, prompt, false); } + + @Test + void testAddCustomRole() { + OllamaAPI ollamaAPI = mock(OllamaAPI.class); + String roleName = "custom-role"; + OllamaChatMessageRole expectedRole = OllamaChatMessageRole.newCustomRole(roleName); + when(ollamaAPI.addCustomRole(roleName)).thenReturn(expectedRole); + OllamaChatMessageRole customRole = ollamaAPI.addCustomRole(roleName); + assertEquals(expectedRole, customRole); + verify(ollamaAPI, times(1)).addCustomRole(roleName); + } + + @Test + void testListRoles() { + OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); + OllamaChatMessageRole role1 = OllamaChatMessageRole.newCustomRole("role1"); + OllamaChatMessageRole role2 = OllamaChatMessageRole.newCustomRole("role2"); + List expectedRoles = List.of(role1, role2); + when(ollamaAPI.listRoles()).thenReturn(expectedRoles); + List actualRoles = ollamaAPI.listRoles(); + assertEquals(expectedRoles, actualRoles); + verify(ollamaAPI, times(1)).listRoles(); + } + + @Test + void testGetRoleNotFound() { + OllamaAPI ollamaAPI = mock(OllamaAPI.class); + String roleName = "non-existing-role"; + try { + when(ollamaAPI.getRole(roleName)).thenThrow(new RoleNotFoundException("Role not found")); + } catch (RoleNotFoundException exception) { + throw new RuntimeException("Failed to run test: testGetRoleNotFound"); + } + try { + ollamaAPI.getRole(roleName); + fail("Expected RoleNotFoundException not thrown"); + } catch (RoleNotFoundException exception) { + assertEquals("Role not found", exception.getMessage()); + } + try { + verify(ollamaAPI, times(1)).getRole(roleName); + } catch (RoleNotFoundException exception) { + throw new RuntimeException("Failed to run test: testGetRoleNotFound"); + } + } + + @Test + void testGetRoleFound() { + OllamaAPI ollamaAPI = mock(OllamaAPI.class); + String roleName = "existing-role"; + OllamaChatMessageRole expectedRole = OllamaChatMessageRole.newCustomRole(roleName); + try { + when(ollamaAPI.getRole(roleName)).thenReturn(expectedRole); + } catch (RoleNotFoundException exception) { + throw new RuntimeException("Failed to run test: testGetRoleFound"); + } + try { + OllamaChatMessageRole actualRole = ollamaAPI.getRole(roleName); + assertEquals(expectedRole, actualRole); + verify(ollamaAPI, times(1)).getRole(roleName); + } catch (RoleNotFoundException exception) { + throw new RuntimeException("Failed to run test: testGetRoleFound"); + } + } }