forked from Mirror/ollama4j
		
	Custom roles support
Adds support for custom roles using `OllamaChatMessageRole`
This commit is contained in:
		@@ -2,6 +2,10 @@ package io.github.ollama4j.unittests;
 | 
			
		||||
 | 
			
		||||
import io.github.ollama4j.OllamaAPI;
 | 
			
		||||
import io.github.ollama4j.exceptions.OllamaBaseException;
 | 
			
		||||
import io.github.ollama4j.exceptions.RoleNotFoundException;
 | 
			
		||||
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
 | 
			
		||||
import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel;
 | 
			
		||||
import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel;
 | 
			
		||||
import io.github.ollama4j.models.response.ModelDetail;
 | 
			
		||||
import io.github.ollama4j.models.response.OllamaAsyncResultStreamer;
 | 
			
		||||
import io.github.ollama4j.models.response.OllamaResult;
 | 
			
		||||
@@ -14,7 +18,9 @@ import java.io.IOException;
 | 
			
		||||
import java.net.URISyntaxException;
 | 
			
		||||
import java.util.ArrayList;
 | 
			
		||||
import java.util.Collections;
 | 
			
		||||
import java.util.List;
 | 
			
		||||
 | 
			
		||||
import static org.junit.jupiter.api.Assertions.*;
 | 
			
		||||
import static org.mockito.Mockito.*;
 | 
			
		||||
 | 
			
		||||
class TestMockedAPIs {
 | 
			
		||||
@@ -97,6 +103,34 @@ class TestMockedAPIs {
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    void testEmbed() {
 | 
			
		||||
        OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
 | 
			
		||||
        String model = OllamaModelType.LLAMA2;
 | 
			
		||||
        List<String> inputs = List.of("some prompt text");
 | 
			
		||||
        try {
 | 
			
		||||
            when(ollamaAPI.embed(model, inputs)).thenReturn(new OllamaEmbedResponseModel());
 | 
			
		||||
            ollamaAPI.embed(model, inputs);
 | 
			
		||||
            verify(ollamaAPI, times(1)).embed(model, inputs);
 | 
			
		||||
        } catch (IOException | OllamaBaseException | InterruptedException e) {
 | 
			
		||||
            throw new RuntimeException(e);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    void testEmbedWithEmbedRequestModel() {
 | 
			
		||||
        OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
 | 
			
		||||
        String model = OllamaModelType.LLAMA2;
 | 
			
		||||
        List<String> inputs = List.of("some prompt text");
 | 
			
		||||
        try {
 | 
			
		||||
            when(ollamaAPI.embed(new OllamaEmbedRequestModel(model, inputs))).thenReturn(new OllamaEmbedResponseModel());
 | 
			
		||||
            ollamaAPI.embed(new OllamaEmbedRequestModel(model, inputs));
 | 
			
		||||
            verify(ollamaAPI, times(1)).embed(new OllamaEmbedRequestModel(model, inputs));
 | 
			
		||||
        } catch (IOException | OllamaBaseException | InterruptedException e) {
 | 
			
		||||
            throw new RuntimeException(e);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    void testAsk() {
 | 
			
		||||
        OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
 | 
			
		||||
@@ -161,4 +195,68 @@ class TestMockedAPIs {
 | 
			
		||||
        ollamaAPI.generateAsync(model, prompt, false);
 | 
			
		||||
        verify(ollamaAPI, times(1)).generateAsync(model, prompt, false);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    void testAddCustomRole() {
 | 
			
		||||
        OllamaAPI ollamaAPI = mock(OllamaAPI.class);
 | 
			
		||||
        String roleName = "custom-role";
 | 
			
		||||
        OllamaChatMessageRole expectedRole = OllamaChatMessageRole.newCustomRole(roleName);
 | 
			
		||||
        when(ollamaAPI.addCustomRole(roleName)).thenReturn(expectedRole);
 | 
			
		||||
        OllamaChatMessageRole customRole = ollamaAPI.addCustomRole(roleName);
 | 
			
		||||
        assertEquals(expectedRole, customRole);
 | 
			
		||||
        verify(ollamaAPI, times(1)).addCustomRole(roleName);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    void testListRoles() {
 | 
			
		||||
        OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
 | 
			
		||||
        OllamaChatMessageRole role1 = OllamaChatMessageRole.newCustomRole("role1");
 | 
			
		||||
        OllamaChatMessageRole role2 = OllamaChatMessageRole.newCustomRole("role2");
 | 
			
		||||
        List<OllamaChatMessageRole> expectedRoles = List.of(role1, role2);
 | 
			
		||||
        when(ollamaAPI.listRoles()).thenReturn(expectedRoles);
 | 
			
		||||
        List<OllamaChatMessageRole> actualRoles = ollamaAPI.listRoles();
 | 
			
		||||
        assertEquals(expectedRoles, actualRoles);
 | 
			
		||||
        verify(ollamaAPI, times(1)).listRoles();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    void testGetRoleNotFound() {
 | 
			
		||||
        OllamaAPI ollamaAPI = mock(OllamaAPI.class);
 | 
			
		||||
        String roleName = "non-existing-role";
 | 
			
		||||
        try {
 | 
			
		||||
            when(ollamaAPI.getRole(roleName)).thenThrow(new RoleNotFoundException("Role not found"));
 | 
			
		||||
        } catch (RoleNotFoundException exception) {
 | 
			
		||||
            throw new RuntimeException("Failed to run test: testGetRoleNotFound");
 | 
			
		||||
        }
 | 
			
		||||
        try {
 | 
			
		||||
            ollamaAPI.getRole(roleName);
 | 
			
		||||
            fail("Expected RoleNotFoundException not thrown");
 | 
			
		||||
        } catch (RoleNotFoundException exception) {
 | 
			
		||||
            assertEquals("Role not found", exception.getMessage());
 | 
			
		||||
        }
 | 
			
		||||
        try {
 | 
			
		||||
            verify(ollamaAPI, times(1)).getRole(roleName);
 | 
			
		||||
        } catch (RoleNotFoundException exception) {
 | 
			
		||||
            throw new RuntimeException("Failed to run test: testGetRoleNotFound");
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    void testGetRoleFound() {
 | 
			
		||||
        OllamaAPI ollamaAPI = mock(OllamaAPI.class);
 | 
			
		||||
        String roleName = "existing-role";
 | 
			
		||||
        OllamaChatMessageRole expectedRole = OllamaChatMessageRole.newCustomRole(roleName);
 | 
			
		||||
        try {
 | 
			
		||||
            when(ollamaAPI.getRole(roleName)).thenReturn(expectedRole);
 | 
			
		||||
        } catch (RoleNotFoundException exception) {
 | 
			
		||||
            throw new RuntimeException("Failed to run test: testGetRoleFound");
 | 
			
		||||
        }
 | 
			
		||||
        try {
 | 
			
		||||
            OllamaChatMessageRole actualRole = ollamaAPI.getRole(roleName);
 | 
			
		||||
            assertEquals(expectedRole, actualRole);
 | 
			
		||||
            verify(ollamaAPI, times(1)).getRole(roleName);
 | 
			
		||||
        } catch (RoleNotFoundException exception) {
 | 
			
		||||
            throw new RuntimeException("Failed to run test: testGetRoleFound");
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user