Custom roles support

Adds support for custom roles using `OllamaChatMessageRole`
This commit is contained in:
Amith Koujalgi 2024-10-31 16:15:21 +05:30
parent bedfec6bf9
commit 921f745435
No known key found for this signature in database
GPG Key ID: E29A37746AF94B70
5 changed files with 171 additions and 22 deletions

View File

@ -1,12 +1,10 @@
package io.github.ollama4j; package io.github.ollama4j;
import io.github.ollama4j.exceptions.OllamaBaseException; import io.github.ollama4j.exceptions.OllamaBaseException;
import io.github.ollama4j.exceptions.RoleNotFoundException;
import io.github.ollama4j.exceptions.ToolInvocationException; import io.github.ollama4j.exceptions.ToolInvocationException;
import io.github.ollama4j.exceptions.ToolNotFoundException; import io.github.ollama4j.exceptions.ToolNotFoundException;
import io.github.ollama4j.models.chat.OllamaChatMessage; import io.github.ollama4j.models.chat.*;
import io.github.ollama4j.models.chat.OllamaChatRequest;
import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
import io.github.ollama4j.models.chat.OllamaChatResult;
import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel; import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel;
import io.github.ollama4j.models.embeddings.OllamaEmbeddingResponseModel; import io.github.ollama4j.models.embeddings.OllamaEmbeddingResponseModel;
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel; import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
@ -166,7 +164,6 @@ public class OllamaAPI {
} }
} }
/** /**
* Pull a model on the Ollama server from the list of <a * Pull a model on the Ollama server from the list of <a
* href="https://ollama.ai/library">available models</a>. * href="https://ollama.ai/library">available models</a>.
@ -471,7 +468,6 @@ public class OllamaAPI {
return toolResult; return toolResult;
} }
/** /**
* Generate response for a question to a model running on Ollama server and get a callback handle * Generate response for a question to a model running on Ollama server and get a callback handle
* that can be used to check for status and get the response from the model later. This would be * that can be used to check for status and get the response from the model later. This would be
@ -570,7 +566,6 @@ public class OllamaAPI {
return generateWithImageURLs(model, prompt, imageURLs, options, null); return generateWithImageURLs(model, prompt, imageURLs, options, null);
} }
/** /**
* Ask a question to a model based on a given message stack (i.e. a chat history). Creates a synchronous call to the api * Ask a question to a model based on a given message stack (i.e. a chat history). Creates a synchronous call to the api
* 'api/chat'. * 'api/chat'.
@ -639,6 +634,25 @@ public class OllamaAPI {
toolRegistry.addFunction(toolSpecification.getFunctionName(), toolSpecification.getToolDefinition()); toolRegistry.addFunction(toolSpecification.getFunctionName(), toolSpecification.getToolDefinition());
} }
/**
* @param roleName - Custom role to be added
* @return OllamaChatMessageRole
*/
public OllamaChatMessageRole addCustomRole(String roleName) {
return OllamaChatMessageRole.newCustomRole(roleName);
}
/**
* @return - available roles
*/
public List<OllamaChatMessageRole> listRoles() {
return OllamaChatMessageRole.getRoles();
}
public OllamaChatMessageRole getRole(String roleName) throws RoleNotFoundException {
return OllamaChatMessageRole.getRole(roleName);
}
// technical private methods // // technical private methods //
private static String encodeFileToBase64(File file) throws IOException { private static String encodeFileToBase64(File file) throws IOException {
@ -694,7 +708,6 @@ public class OllamaAPI {
return basicAuth != null; return basicAuth != null;
} }
private Object invokeTool(ToolFunctionCallSpec toolFunctionCallSpec) throws ToolInvocationException { private Object invokeTool(ToolFunctionCallSpec toolFunctionCallSpec) throws ToolInvocationException {
try { try {
String methodName = toolFunctionCallSpec.getName(); String methodName = toolFunctionCallSpec.getName();

View File

@ -0,0 +1,8 @@
package io.github.ollama4j.exceptions;
public class RoleNotFoundException extends Exception {
public RoleNotFoundException(String s) {
super(s);
}
}

View File

@ -1,20 +1,53 @@
package io.github.ollama4j.models.chat; package io.github.ollama4j.models.chat;
import com.fasterxml.jackson.annotation.JsonValue; import com.fasterxml.jackson.annotation.JsonValue;
import io.github.ollama4j.exceptions.RoleNotFoundException;
import lombok.Getter;
import java.util.ArrayList;
import java.util.List;
/** /**
* Defines the possible Chat Message roles. * Defines the possible Chat Message roles.
*/ */
public enum OllamaChatMessageRole { @Getter
SYSTEM("system"), public class OllamaChatMessageRole {
USER("user"), private static final List<OllamaChatMessageRole> ROLES = new ArrayList<>();
ASSISTANT("assistant"),
TOOL("tool"); public static final OllamaChatMessageRole SYSTEM = new OllamaChatMessageRole("system");
public static final OllamaChatMessageRole USER = new OllamaChatMessageRole("user");
public static final OllamaChatMessageRole ASSISTANT = new OllamaChatMessageRole("assistant");
public static final OllamaChatMessageRole TOOL = new OllamaChatMessageRole("tool");
@JsonValue @JsonValue
private String roleName; private final String roleName;
private OllamaChatMessageRole(String roleName) { private OllamaChatMessageRole(String roleName) {
this.roleName = roleName; this.roleName = roleName;
ROLES.add(this);
}
public static OllamaChatMessageRole newCustomRole(String roleName) {
OllamaChatMessageRole customRole = new OllamaChatMessageRole(roleName);
ROLES.add(customRole);
return customRole;
}
public static List<OllamaChatMessageRole> getRoles() {
return new ArrayList<>(ROLES);
}
public static OllamaChatMessageRole getRole(String roleName) throws RoleNotFoundException {
for (OllamaChatMessageRole role : ROLES) {
if (role.roleName.equals(roleName)) {
return role;
}
}
throw new RoleNotFoundException("Invalid role name: " + roleName);
}
@Override
public String toString() {
return roleName;
} }
} }

View File

@ -12,8 +12,7 @@ public class OllamaChatResult extends OllamaResult{
private List<OllamaChatMessage> chatHistory; private List<OllamaChatMessage> chatHistory;
public OllamaChatResult(String response, long responseTime, int httpStatusCode, public OllamaChatResult(String response, long responseTime, int httpStatusCode, List<OllamaChatMessage> chatHistory) {
List<OllamaChatMessage> chatHistory) {
super(response, responseTime, httpStatusCode); super(response, responseTime, httpStatusCode);
this.chatHistory = chatHistory; this.chatHistory = chatHistory;
appendAnswerToChatHistory(response); appendAnswerToChatHistory(response);
@ -27,6 +26,4 @@ public class OllamaChatResult extends OllamaResult{
OllamaChatMessage assistantMessage = new OllamaChatMessage(OllamaChatMessageRole.ASSISTANT, answer); OllamaChatMessage assistantMessage = new OllamaChatMessage(OllamaChatMessageRole.ASSISTANT, answer);
this.chatHistory.add(assistantMessage); this.chatHistory.add(assistantMessage);
} }
} }

View File

@ -2,6 +2,10 @@ package io.github.ollama4j.unittests;
import io.github.ollama4j.OllamaAPI; import io.github.ollama4j.OllamaAPI;
import io.github.ollama4j.exceptions.OllamaBaseException; import io.github.ollama4j.exceptions.OllamaBaseException;
import io.github.ollama4j.exceptions.RoleNotFoundException;
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel;
import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel;
import io.github.ollama4j.models.response.ModelDetail; import io.github.ollama4j.models.response.ModelDetail;
import io.github.ollama4j.models.response.OllamaAsyncResultStreamer; import io.github.ollama4j.models.response.OllamaAsyncResultStreamer;
import io.github.ollama4j.models.response.OllamaResult; import io.github.ollama4j.models.response.OllamaResult;
@ -14,7 +18,9 @@ import java.io.IOException;
import java.net.URISyntaxException; import java.net.URISyntaxException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.*; import static org.mockito.Mockito.*;
class TestMockedAPIs { class TestMockedAPIs {
@ -97,6 +103,34 @@ class TestMockedAPIs {
} }
} }
@Test
void testEmbed() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
List<String> inputs = List.of("some prompt text");
try {
when(ollamaAPI.embed(model, inputs)).thenReturn(new OllamaEmbedResponseModel());
ollamaAPI.embed(model, inputs);
verify(ollamaAPI, times(1)).embed(model, inputs);
} catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e);
}
}
@Test
void testEmbedWithEmbedRequestModel() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
List<String> inputs = List.of("some prompt text");
try {
when(ollamaAPI.embed(new OllamaEmbedRequestModel(model, inputs))).thenReturn(new OllamaEmbedResponseModel());
ollamaAPI.embed(new OllamaEmbedRequestModel(model, inputs));
verify(ollamaAPI, times(1)).embed(new OllamaEmbedRequestModel(model, inputs));
} catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e);
}
}
@Test @Test
void testAsk() { void testAsk() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
@ -161,4 +195,68 @@ class TestMockedAPIs {
ollamaAPI.generateAsync(model, prompt, false); ollamaAPI.generateAsync(model, prompt, false);
verify(ollamaAPI, times(1)).generateAsync(model, prompt, false); verify(ollamaAPI, times(1)).generateAsync(model, prompt, false);
} }
@Test
void testAddCustomRole() {
OllamaAPI ollamaAPI = mock(OllamaAPI.class);
String roleName = "custom-role";
OllamaChatMessageRole expectedRole = OllamaChatMessageRole.newCustomRole(roleName);
when(ollamaAPI.addCustomRole(roleName)).thenReturn(expectedRole);
OllamaChatMessageRole customRole = ollamaAPI.addCustomRole(roleName);
assertEquals(expectedRole, customRole);
verify(ollamaAPI, times(1)).addCustomRole(roleName);
}
@Test
void testListRoles() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
OllamaChatMessageRole role1 = OllamaChatMessageRole.newCustomRole("role1");
OllamaChatMessageRole role2 = OllamaChatMessageRole.newCustomRole("role2");
List<OllamaChatMessageRole> expectedRoles = List.of(role1, role2);
when(ollamaAPI.listRoles()).thenReturn(expectedRoles);
List<OllamaChatMessageRole> actualRoles = ollamaAPI.listRoles();
assertEquals(expectedRoles, actualRoles);
verify(ollamaAPI, times(1)).listRoles();
}
@Test
void testGetRoleNotFound() {
OllamaAPI ollamaAPI = mock(OllamaAPI.class);
String roleName = "non-existing-role";
try {
when(ollamaAPI.getRole(roleName)).thenThrow(new RoleNotFoundException("Role not found"));
} catch (RoleNotFoundException exception) {
throw new RuntimeException("Failed to run test: testGetRoleNotFound");
}
try {
ollamaAPI.getRole(roleName);
fail("Expected RoleNotFoundException not thrown");
} catch (RoleNotFoundException exception) {
assertEquals("Role not found", exception.getMessage());
}
try {
verify(ollamaAPI, times(1)).getRole(roleName);
} catch (RoleNotFoundException exception) {
throw new RuntimeException("Failed to run test: testGetRoleNotFound");
}
}
@Test
void testGetRoleFound() {
OllamaAPI ollamaAPI = mock(OllamaAPI.class);
String roleName = "existing-role";
OllamaChatMessageRole expectedRole = OllamaChatMessageRole.newCustomRole(roleName);
try {
when(ollamaAPI.getRole(roleName)).thenReturn(expectedRole);
} catch (RoleNotFoundException exception) {
throw new RuntimeException("Failed to run test: testGetRoleFound");
}
try {
OllamaChatMessageRole actualRole = ollamaAPI.getRole(roleName);
assertEquals(expectedRole, actualRole);
verify(ollamaAPI, times(1)).getRole(roleName);
} catch (RoleNotFoundException exception) {
throw new RuntimeException("Failed to run test: testGetRoleFound");
}
}
} }