Custom roles support

Adds support for custom roles using `OllamaChatMessageRole`
This commit is contained in:
Amith Koujalgi
2024-10-31 16:15:21 +05:30
parent bedfec6bf9
commit 921f745435
5 changed files with 171 additions and 22 deletions

View File

@@ -1,12 +1,10 @@
package io.github.ollama4j;
import io.github.ollama4j.exceptions.OllamaBaseException;
import io.github.ollama4j.exceptions.RoleNotFoundException;
import io.github.ollama4j.exceptions.ToolInvocationException;
import io.github.ollama4j.exceptions.ToolNotFoundException;
import io.github.ollama4j.models.chat.OllamaChatMessage;
import io.github.ollama4j.models.chat.OllamaChatRequest;
import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
import io.github.ollama4j.models.chat.OllamaChatResult;
import io.github.ollama4j.models.chat.*;
import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel;
import io.github.ollama4j.models.embeddings.OllamaEmbeddingResponseModel;
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
@@ -166,7 +164,6 @@ public class OllamaAPI {
}
}
/**
* Pull a model on the Ollama server from the list of <a
* href="https://ollama.ai/library">available models</a>.
@@ -471,7 +468,6 @@ public class OllamaAPI {
return toolResult;
}
/**
* Generate response for a question to a model running on Ollama server and get a callback handle
* that can be used to check for status and get the response from the model later. This would be
@@ -570,7 +566,6 @@ public class OllamaAPI {
return generateWithImageURLs(model, prompt, imageURLs, options, null);
}
/**
* Ask a question to a model based on a given message stack (i.e. a chat history). Creates a synchronous call to the api
* 'api/chat'.
@@ -639,6 +634,25 @@ public class OllamaAPI {
toolRegistry.addFunction(toolSpecification.getFunctionName(), toolSpecification.getToolDefinition());
}
/**
* @param roleName - Custom role to be added
* @return OllamaChatMessageRole
*/
public OllamaChatMessageRole addCustomRole(String roleName) {
return OllamaChatMessageRole.newCustomRole(roleName);
}
/**
* @return - available roles
*/
public List<OllamaChatMessageRole> listRoles() {
return OllamaChatMessageRole.getRoles();
}
public OllamaChatMessageRole getRole(String roleName) throws RoleNotFoundException {
return OllamaChatMessageRole.getRole(roleName);
}
// technical private methods //
private static String encodeFileToBase64(File file) throws IOException {
@@ -694,7 +708,6 @@ public class OllamaAPI {
return basicAuth != null;
}
private Object invokeTool(ToolFunctionCallSpec toolFunctionCallSpec) throws ToolInvocationException {
try {
String methodName = toolFunctionCallSpec.getName();

View File

@@ -0,0 +1,8 @@
package io.github.ollama4j.exceptions;
public class RoleNotFoundException extends Exception {
public RoleNotFoundException(String s) {
super(s);
}
}

View File

@@ -1,20 +1,53 @@
package io.github.ollama4j.models.chat;
import com.fasterxml.jackson.annotation.JsonValue;
import io.github.ollama4j.exceptions.RoleNotFoundException;
import lombok.Getter;
import java.util.ArrayList;
import java.util.List;
/**
* Defines the possible Chat Message roles.
*/
public enum OllamaChatMessageRole {
SYSTEM("system"),
USER("user"),
ASSISTANT("assistant"),
TOOL("tool");
@Getter
public class OllamaChatMessageRole {
private static final List<OllamaChatMessageRole> ROLES = new ArrayList<>();
public static final OllamaChatMessageRole SYSTEM = new OllamaChatMessageRole("system");
public static final OllamaChatMessageRole USER = new OllamaChatMessageRole("user");
public static final OllamaChatMessageRole ASSISTANT = new OllamaChatMessageRole("assistant");
public static final OllamaChatMessageRole TOOL = new OllamaChatMessageRole("tool");
@JsonValue
private String roleName;
private final String roleName;
private OllamaChatMessageRole(String roleName){
private OllamaChatMessageRole(String roleName) {
this.roleName = roleName;
ROLES.add(this);
}
public static OllamaChatMessageRole newCustomRole(String roleName) {
OllamaChatMessageRole customRole = new OllamaChatMessageRole(roleName);
ROLES.add(customRole);
return customRole;
}
public static List<OllamaChatMessageRole> getRoles() {
return new ArrayList<>(ROLES);
}
public static OllamaChatMessageRole getRole(String roleName) throws RoleNotFoundException {
for (OllamaChatMessageRole role : ROLES) {
if (role.roleName.equals(roleName)) {
return role;
}
}
throw new RoleNotFoundException("Invalid role name: " + roleName);
}
@Override
public String toString() {
return roleName;
}
}

View File

@@ -8,12 +8,11 @@ import io.github.ollama4j.models.response.OllamaResult;
* Specific chat-API result that contains the chat history sent to the model and appends the answer as {@link OllamaChatResult} given by the
* {@link OllamaChatMessageRole#ASSISTANT} role.
*/
public class OllamaChatResult extends OllamaResult{
public class OllamaChatResult extends OllamaResult {
private List<OllamaChatMessage> chatHistory;
public OllamaChatResult(String response, long responseTime, int httpStatusCode,
List<OllamaChatMessage> chatHistory) {
public OllamaChatResult(String response, long responseTime, int httpStatusCode, List<OllamaChatMessage> chatHistory) {
super(response, responseTime, httpStatusCode);
this.chatHistory = chatHistory;
appendAnswerToChatHistory(response);
@@ -21,12 +20,10 @@ public class OllamaChatResult extends OllamaResult{
public List<OllamaChatMessage> getChatHistory() {
return chatHistory;
}
}
private void appendAnswerToChatHistory(String answer){
private void appendAnswerToChatHistory(String answer) {
OllamaChatMessage assistantMessage = new OllamaChatMessage(OllamaChatMessageRole.ASSISTANT, answer);
this.chatHistory.add(assistantMessage);
}
}