forked from Mirror/ollama4j
		
	Custom roles support
Adds support for custom roles using `OllamaChatMessageRole`
This commit is contained in:
		@@ -1,12 +1,10 @@
 | 
				
			|||||||
package io.github.ollama4j;
 | 
					package io.github.ollama4j;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import io.github.ollama4j.exceptions.OllamaBaseException;
 | 
					import io.github.ollama4j.exceptions.OllamaBaseException;
 | 
				
			||||||
 | 
					import io.github.ollama4j.exceptions.RoleNotFoundException;
 | 
				
			||||||
import io.github.ollama4j.exceptions.ToolInvocationException;
 | 
					import io.github.ollama4j.exceptions.ToolInvocationException;
 | 
				
			||||||
import io.github.ollama4j.exceptions.ToolNotFoundException;
 | 
					import io.github.ollama4j.exceptions.ToolNotFoundException;
 | 
				
			||||||
import io.github.ollama4j.models.chat.OllamaChatMessage;
 | 
					import io.github.ollama4j.models.chat.*;
 | 
				
			||||||
import io.github.ollama4j.models.chat.OllamaChatRequest;
 | 
					 | 
				
			||||||
import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
 | 
					 | 
				
			||||||
import io.github.ollama4j.models.chat.OllamaChatResult;
 | 
					 | 
				
			||||||
import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel;
 | 
					import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel;
 | 
				
			||||||
import io.github.ollama4j.models.embeddings.OllamaEmbeddingResponseModel;
 | 
					import io.github.ollama4j.models.embeddings.OllamaEmbeddingResponseModel;
 | 
				
			||||||
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
 | 
					import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
 | 
				
			||||||
@@ -166,7 +164,6 @@ public class OllamaAPI {
 | 
				
			|||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
    /**
 | 
					    /**
 | 
				
			||||||
     * Pull a model on the Ollama server from the list of <a
 | 
					     * Pull a model on the Ollama server from the list of <a
 | 
				
			||||||
     * href="https://ollama.ai/library">available models</a>.
 | 
					     * href="https://ollama.ai/library">available models</a>.
 | 
				
			||||||
@@ -471,7 +468,6 @@ public class OllamaAPI {
 | 
				
			|||||||
        return toolResult;
 | 
					        return toolResult;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
    /**
 | 
					    /**
 | 
				
			||||||
     * Generate response for a question to a model running on Ollama server and get a callback handle
 | 
					     * Generate response for a question to a model running on Ollama server and get a callback handle
 | 
				
			||||||
     * that can be used to check for status and get the response from the model later. This would be
 | 
					     * that can be used to check for status and get the response from the model later. This would be
 | 
				
			||||||
@@ -570,7 +566,6 @@ public class OllamaAPI {
 | 
				
			|||||||
        return generateWithImageURLs(model, prompt, imageURLs, options, null);
 | 
					        return generateWithImageURLs(model, prompt, imageURLs, options, null);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
    /**
 | 
					    /**
 | 
				
			||||||
     * Ask a question to a model based on a given message stack (i.e. a chat history). Creates a synchronous call to the api
 | 
					     * Ask a question to a model based on a given message stack (i.e. a chat history). Creates a synchronous call to the api
 | 
				
			||||||
     * 'api/chat'.
 | 
					     * 'api/chat'.
 | 
				
			||||||
@@ -639,6 +634,25 @@ public class OllamaAPI {
 | 
				
			|||||||
        toolRegistry.addFunction(toolSpecification.getFunctionName(), toolSpecification.getToolDefinition());
 | 
					        toolRegistry.addFunction(toolSpecification.getFunctionName(), toolSpecification.getToolDefinition());
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    /**
 | 
				
			||||||
 | 
					     * @param roleName - Custom role to be added
 | 
				
			||||||
 | 
					     * @return OllamaChatMessageRole
 | 
				
			||||||
 | 
					     */
 | 
				
			||||||
 | 
					    public OllamaChatMessageRole addCustomRole(String roleName) {
 | 
				
			||||||
 | 
					        return OllamaChatMessageRole.newCustomRole(roleName);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    /**
 | 
				
			||||||
 | 
					     * @return - available roles
 | 
				
			||||||
 | 
					     */
 | 
				
			||||||
 | 
					    public List<OllamaChatMessageRole> listRoles() {
 | 
				
			||||||
 | 
					        return OllamaChatMessageRole.getRoles();
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    public OllamaChatMessageRole getRole(String roleName) throws RoleNotFoundException {
 | 
				
			||||||
 | 
					        return OllamaChatMessageRole.getRole(roleName);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // technical private methods //
 | 
					    // technical private methods //
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    private static String encodeFileToBase64(File file) throws IOException {
 | 
					    private static String encodeFileToBase64(File file) throws IOException {
 | 
				
			||||||
@@ -694,7 +708,6 @@ public class OllamaAPI {
 | 
				
			|||||||
        return basicAuth != null;
 | 
					        return basicAuth != null;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
    private Object invokeTool(ToolFunctionCallSpec toolFunctionCallSpec) throws ToolInvocationException {
 | 
					    private Object invokeTool(ToolFunctionCallSpec toolFunctionCallSpec) throws ToolInvocationException {
 | 
				
			||||||
        try {
 | 
					        try {
 | 
				
			||||||
            String methodName = toolFunctionCallSpec.getName();
 | 
					            String methodName = toolFunctionCallSpec.getName();
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -0,0 +1,8 @@
 | 
				
			|||||||
 | 
					package io.github.ollama4j.exceptions;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					public class RoleNotFoundException extends Exception {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    public RoleNotFoundException(String s) {
 | 
				
			||||||
 | 
					        super(s);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -1,20 +1,53 @@
 | 
				
			|||||||
package io.github.ollama4j.models.chat;
 | 
					package io.github.ollama4j.models.chat;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import com.fasterxml.jackson.annotation.JsonValue;
 | 
					import com.fasterxml.jackson.annotation.JsonValue;
 | 
				
			||||||
 | 
					import io.github.ollama4j.exceptions.RoleNotFoundException;
 | 
				
			||||||
 | 
					import lombok.Getter;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import java.util.ArrayList;
 | 
				
			||||||
 | 
					import java.util.List;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/**
 | 
					/**
 | 
				
			||||||
 * Defines the possible Chat Message roles.
 | 
					 * Defines the possible Chat Message roles.
 | 
				
			||||||
 */
 | 
					 */
 | 
				
			||||||
public enum OllamaChatMessageRole {
 | 
					@Getter
 | 
				
			||||||
    SYSTEM("system"),
 | 
					public class OllamaChatMessageRole {
 | 
				
			||||||
    USER("user"),
 | 
					    private static final List<OllamaChatMessageRole> ROLES = new ArrayList<>();
 | 
				
			||||||
    ASSISTANT("assistant"),
 | 
					
 | 
				
			||||||
    TOOL("tool");
 | 
					    public static final OllamaChatMessageRole SYSTEM = new OllamaChatMessageRole("system");
 | 
				
			||||||
 | 
					    public static final OllamaChatMessageRole USER = new OllamaChatMessageRole("user");
 | 
				
			||||||
 | 
					    public static final OllamaChatMessageRole ASSISTANT = new OllamaChatMessageRole("assistant");
 | 
				
			||||||
 | 
					    public static final OllamaChatMessageRole TOOL = new OllamaChatMessageRole("tool");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @JsonValue
 | 
					    @JsonValue
 | 
				
			||||||
    private String roleName;
 | 
					    private final String roleName;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    private OllamaChatMessageRole(String roleName) {
 | 
					    private OllamaChatMessageRole(String roleName) {
 | 
				
			||||||
        this.roleName = roleName;
 | 
					        this.roleName = roleName;
 | 
				
			||||||
 | 
					        ROLES.add(this);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    public static OllamaChatMessageRole newCustomRole(String roleName) {
 | 
				
			||||||
 | 
					        OllamaChatMessageRole customRole = new OllamaChatMessageRole(roleName);
 | 
				
			||||||
 | 
					        ROLES.add(customRole);
 | 
				
			||||||
 | 
					        return customRole;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    public static List<OllamaChatMessageRole> getRoles() {
 | 
				
			||||||
 | 
					        return new ArrayList<>(ROLES);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    public static OllamaChatMessageRole getRole(String roleName) throws RoleNotFoundException {
 | 
				
			||||||
 | 
					        for (OllamaChatMessageRole role : ROLES) {
 | 
				
			||||||
 | 
					            if (role.roleName.equals(roleName)) {
 | 
				
			||||||
 | 
					                return role;
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        throw new RoleNotFoundException("Invalid role name: " + roleName);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @Override
 | 
				
			||||||
 | 
					    public String toString() {
 | 
				
			||||||
 | 
					        return roleName;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -12,8 +12,7 @@ public class OllamaChatResult extends OllamaResult{
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    private List<OllamaChatMessage> chatHistory;
 | 
					    private List<OllamaChatMessage> chatHistory;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public OllamaChatResult(String response, long responseTime, int httpStatusCode,
 | 
					    public OllamaChatResult(String response, long responseTime, int httpStatusCode, List<OllamaChatMessage> chatHistory) {
 | 
				
			||||||
            List<OllamaChatMessage> chatHistory) {
 | 
					 | 
				
			||||||
        super(response, responseTime, httpStatusCode);
 | 
					        super(response, responseTime, httpStatusCode);
 | 
				
			||||||
        this.chatHistory = chatHistory;
 | 
					        this.chatHistory = chatHistory;
 | 
				
			||||||
        appendAnswerToChatHistory(response);
 | 
					        appendAnswerToChatHistory(response);
 | 
				
			||||||
@@ -27,6 +26,4 @@ public class OllamaChatResult extends OllamaResult{
 | 
				
			|||||||
        OllamaChatMessage assistantMessage = new OllamaChatMessage(OllamaChatMessageRole.ASSISTANT, answer);
 | 
					        OllamaChatMessage assistantMessage = new OllamaChatMessage(OllamaChatMessageRole.ASSISTANT, answer);
 | 
				
			||||||
        this.chatHistory.add(assistantMessage);
 | 
					        this.chatHistory.add(assistantMessage);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    
 | 
					 | 
				
			||||||
    
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,6 +2,10 @@ package io.github.ollama4j.unittests;
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import io.github.ollama4j.OllamaAPI;
 | 
					import io.github.ollama4j.OllamaAPI;
 | 
				
			||||||
import io.github.ollama4j.exceptions.OllamaBaseException;
 | 
					import io.github.ollama4j.exceptions.OllamaBaseException;
 | 
				
			||||||
 | 
					import io.github.ollama4j.exceptions.RoleNotFoundException;
 | 
				
			||||||
 | 
					import io.github.ollama4j.models.chat.OllamaChatMessageRole;
 | 
				
			||||||
 | 
					import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel;
 | 
				
			||||||
 | 
					import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel;
 | 
				
			||||||
import io.github.ollama4j.models.response.ModelDetail;
 | 
					import io.github.ollama4j.models.response.ModelDetail;
 | 
				
			||||||
import io.github.ollama4j.models.response.OllamaAsyncResultStreamer;
 | 
					import io.github.ollama4j.models.response.OllamaAsyncResultStreamer;
 | 
				
			||||||
import io.github.ollama4j.models.response.OllamaResult;
 | 
					import io.github.ollama4j.models.response.OllamaResult;
 | 
				
			||||||
@@ -14,7 +18,9 @@ import java.io.IOException;
 | 
				
			|||||||
import java.net.URISyntaxException;
 | 
					import java.net.URISyntaxException;
 | 
				
			||||||
import java.util.ArrayList;
 | 
					import java.util.ArrayList;
 | 
				
			||||||
import java.util.Collections;
 | 
					import java.util.Collections;
 | 
				
			||||||
 | 
					import java.util.List;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import static org.junit.jupiter.api.Assertions.*;
 | 
				
			||||||
import static org.mockito.Mockito.*;
 | 
					import static org.mockito.Mockito.*;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class TestMockedAPIs {
 | 
					class TestMockedAPIs {
 | 
				
			||||||
@@ -97,6 +103,34 @@ class TestMockedAPIs {
 | 
				
			|||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @Test
 | 
				
			||||||
 | 
					    void testEmbed() {
 | 
				
			||||||
 | 
					        OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
 | 
				
			||||||
 | 
					        String model = OllamaModelType.LLAMA2;
 | 
				
			||||||
 | 
					        List<String> inputs = List.of("some prompt text");
 | 
				
			||||||
 | 
					        try {
 | 
				
			||||||
 | 
					            when(ollamaAPI.embed(model, inputs)).thenReturn(new OllamaEmbedResponseModel());
 | 
				
			||||||
 | 
					            ollamaAPI.embed(model, inputs);
 | 
				
			||||||
 | 
					            verify(ollamaAPI, times(1)).embed(model, inputs);
 | 
				
			||||||
 | 
					        } catch (IOException | OllamaBaseException | InterruptedException e) {
 | 
				
			||||||
 | 
					            throw new RuntimeException(e);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @Test
 | 
				
			||||||
 | 
					    void testEmbedWithEmbedRequestModel() {
 | 
				
			||||||
 | 
					        OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
 | 
				
			||||||
 | 
					        String model = OllamaModelType.LLAMA2;
 | 
				
			||||||
 | 
					        List<String> inputs = List.of("some prompt text");
 | 
				
			||||||
 | 
					        try {
 | 
				
			||||||
 | 
					            when(ollamaAPI.embed(new OllamaEmbedRequestModel(model, inputs))).thenReturn(new OllamaEmbedResponseModel());
 | 
				
			||||||
 | 
					            ollamaAPI.embed(new OllamaEmbedRequestModel(model, inputs));
 | 
				
			||||||
 | 
					            verify(ollamaAPI, times(1)).embed(new OllamaEmbedRequestModel(model, inputs));
 | 
				
			||||||
 | 
					        } catch (IOException | OllamaBaseException | InterruptedException e) {
 | 
				
			||||||
 | 
					            throw new RuntimeException(e);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @Test
 | 
					    @Test
 | 
				
			||||||
    void testAsk() {
 | 
					    void testAsk() {
 | 
				
			||||||
        OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
 | 
					        OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
 | 
				
			||||||
@@ -161,4 +195,68 @@ class TestMockedAPIs {
 | 
				
			|||||||
        ollamaAPI.generateAsync(model, prompt, false);
 | 
					        ollamaAPI.generateAsync(model, prompt, false);
 | 
				
			||||||
        verify(ollamaAPI, times(1)).generateAsync(model, prompt, false);
 | 
					        verify(ollamaAPI, times(1)).generateAsync(model, prompt, false);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @Test
 | 
				
			||||||
 | 
					    void testAddCustomRole() {
 | 
				
			||||||
 | 
					        OllamaAPI ollamaAPI = mock(OllamaAPI.class);
 | 
				
			||||||
 | 
					        String roleName = "custom-role";
 | 
				
			||||||
 | 
					        OllamaChatMessageRole expectedRole = OllamaChatMessageRole.newCustomRole(roleName);
 | 
				
			||||||
 | 
					        when(ollamaAPI.addCustomRole(roleName)).thenReturn(expectedRole);
 | 
				
			||||||
 | 
					        OllamaChatMessageRole customRole = ollamaAPI.addCustomRole(roleName);
 | 
				
			||||||
 | 
					        assertEquals(expectedRole, customRole);
 | 
				
			||||||
 | 
					        verify(ollamaAPI, times(1)).addCustomRole(roleName);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @Test
 | 
				
			||||||
 | 
					    void testListRoles() {
 | 
				
			||||||
 | 
					        OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
 | 
				
			||||||
 | 
					        OllamaChatMessageRole role1 = OllamaChatMessageRole.newCustomRole("role1");
 | 
				
			||||||
 | 
					        OllamaChatMessageRole role2 = OllamaChatMessageRole.newCustomRole("role2");
 | 
				
			||||||
 | 
					        List<OllamaChatMessageRole> expectedRoles = List.of(role1, role2);
 | 
				
			||||||
 | 
					        when(ollamaAPI.listRoles()).thenReturn(expectedRoles);
 | 
				
			||||||
 | 
					        List<OllamaChatMessageRole> actualRoles = ollamaAPI.listRoles();
 | 
				
			||||||
 | 
					        assertEquals(expectedRoles, actualRoles);
 | 
				
			||||||
 | 
					        verify(ollamaAPI, times(1)).listRoles();
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @Test
 | 
				
			||||||
 | 
					    void testGetRoleNotFound() {
 | 
				
			||||||
 | 
					        OllamaAPI ollamaAPI = mock(OllamaAPI.class);
 | 
				
			||||||
 | 
					        String roleName = "non-existing-role";
 | 
				
			||||||
 | 
					        try {
 | 
				
			||||||
 | 
					            when(ollamaAPI.getRole(roleName)).thenThrow(new RoleNotFoundException("Role not found"));
 | 
				
			||||||
 | 
					        } catch (RoleNotFoundException exception) {
 | 
				
			||||||
 | 
					            throw new RuntimeException("Failed to run test: testGetRoleNotFound");
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        try {
 | 
				
			||||||
 | 
					            ollamaAPI.getRole(roleName);
 | 
				
			||||||
 | 
					            fail("Expected RoleNotFoundException not thrown");
 | 
				
			||||||
 | 
					        } catch (RoleNotFoundException exception) {
 | 
				
			||||||
 | 
					            assertEquals("Role not found", exception.getMessage());
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        try {
 | 
				
			||||||
 | 
					            verify(ollamaAPI, times(1)).getRole(roleName);
 | 
				
			||||||
 | 
					        } catch (RoleNotFoundException exception) {
 | 
				
			||||||
 | 
					            throw new RuntimeException("Failed to run test: testGetRoleNotFound");
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @Test
 | 
				
			||||||
 | 
					    void testGetRoleFound() {
 | 
				
			||||||
 | 
					        OllamaAPI ollamaAPI = mock(OllamaAPI.class);
 | 
				
			||||||
 | 
					        String roleName = "existing-role";
 | 
				
			||||||
 | 
					        OllamaChatMessageRole expectedRole = OllamaChatMessageRole.newCustomRole(roleName);
 | 
				
			||||||
 | 
					        try {
 | 
				
			||||||
 | 
					            when(ollamaAPI.getRole(roleName)).thenReturn(expectedRole);
 | 
				
			||||||
 | 
					        } catch (RoleNotFoundException exception) {
 | 
				
			||||||
 | 
					            throw new RuntimeException("Failed to run test: testGetRoleFound");
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        try {
 | 
				
			||||||
 | 
					            OllamaChatMessageRole actualRole = ollamaAPI.getRole(roleName);
 | 
				
			||||||
 | 
					            assertEquals(expectedRole, actualRole);
 | 
				
			||||||
 | 
					            verify(ollamaAPI, times(1)).getRole(roleName);
 | 
				
			||||||
 | 
					        } catch (RoleNotFoundException exception) {
 | 
				
			||||||
 | 
					            throw new RuntimeException("Failed to run test: testGetRoleFound");
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user