Update OllamaAPI.java
All checks were successful
Mark stale issues / stale (push) Successful in 14s

This commit is contained in:
amithkoujalgi 2025-02-03 08:56:49 +05:30
parent 9a12cebb68
commit e409ff1cf9
No known key found for this signature in database
GPG Key ID: E29A37746AF94B70

View File

@ -610,9 +610,9 @@ public class OllamaAPI {
OllamaToolsResult toolResult = new OllamaToolsResult(); OllamaToolsResult toolResult = new OllamaToolsResult();
Map<ToolFunctionCallSpec, Object> toolResults = new HashMap<>(); Map<ToolFunctionCallSpec, Object> toolResults = new HashMap<>();
if(!prompt.startsWith("[AVAILABLE_TOOLS]")){ if (!prompt.startsWith("[AVAILABLE_TOOLS]")) {
final Tools.PromptBuilder promptBuilder = new Tools.PromptBuilder(); final Tools.PromptBuilder promptBuilder = new Tools.PromptBuilder();
for(Tools.ToolSpecification spec : toolRegistry.getRegisteredSpecs()) { for (Tools.ToolSpecification spec : toolRegistry.getRegisteredSpecs()) {
promptBuilder.withToolSpecification(spec); promptBuilder.withToolSpecification(spec);
} }
promptBuilder.withPrompt(prompt); promptBuilder.withPrompt(prompt);
@ -821,13 +821,13 @@ public class OllamaAPI {
// check if toolCallIsWanted // check if toolCallIsWanted
List<OllamaChatToolCalls> toolCalls = result.getResponseModel().getMessage().getToolCalls(); List<OllamaChatToolCalls> toolCalls = result.getResponseModel().getMessage().getToolCalls();
int toolCallTries = 0; int toolCallTries = 0;
while(toolCalls != null && !toolCalls.isEmpty() && toolCallTries < maxChatToolCallRetries){ while (toolCalls != null && !toolCalls.isEmpty() && toolCallTries < maxChatToolCallRetries) {
for (OllamaChatToolCalls toolCall : toolCalls){ for (OllamaChatToolCalls toolCall : toolCalls) {
String toolName = toolCall.getFunction().getName(); String toolName = toolCall.getFunction().getName();
ToolFunction toolFunction = toolRegistry.getToolFunction(toolName); ToolFunction toolFunction = toolRegistry.getToolFunction(toolName);
Map<String, Object> arguments = toolCall.getFunction().getArguments(); Map<String, Object> arguments = toolCall.getFunction().getArguments();
Object res = toolFunction.apply(arguments); Object res = toolFunction.apply(arguments);
request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL,"[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() +") : " + res + "[/TOOL_RESULTS]")); request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL, "[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() + ") : " + res + "[/TOOL_RESULTS]"));
} }
if (tokenHandler != null) { if (tokenHandler != null) {
@ -842,11 +842,37 @@ public class OllamaAPI {
return result; return result;
} }
/**
* Registers a single tool in the tool registry using the provided tool specification.
*
* @param toolSpecification the specification of the tool to register. It contains the
* tool's function name and other relevant information.
*/
public void registerTool(Tools.ToolSpecification toolSpecification) { public void registerTool(Tools.ToolSpecification toolSpecification) {
toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification); toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification);
} }
/**
* Registers multiple tools in the tool registry using a list of tool specifications.
* Iterates over the list and adds each tool specification to the registry.
*
* @param toolSpecifications a list of tool specifications to register. Each specification
* contains information about a tool, such as its function name.
*/
public void registerTools(List<Tools.ToolSpecification> toolSpecifications) {
for (Tools.ToolSpecification toolSpecification : toolSpecifications) {
toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification);
}
}
/**
* Registers tools based on the annotations found on the methods of the caller's class and its providers.
* This method scans the caller's class for the {@link OllamaToolService} annotation and recursively registers
* annotated tools from all the providers specified in the annotation.
*
* @throws IllegalStateException if the caller's class is not annotated with {@link OllamaToolService}.
* @throws RuntimeException if any reflection-based instantiation or invocation fails.
*/
public void registerAnnotatedTools() { public void registerAnnotatedTools() {
try { try {
Class<?> callerClass = null; Class<?> callerClass = null;
@ -865,70 +891,52 @@ public class OllamaAPI {
for (Class<?> provider : providers) { for (Class<?> provider : providers) {
registerAnnotatedTools(provider.getDeclaredConstructor().newInstance()); registerAnnotatedTools(provider.getDeclaredConstructor().newInstance());
} }
} catch (InstantiationException | NoSuchMethodException | IllegalAccessException | InvocationTargetException e) { } catch (InstantiationException | NoSuchMethodException | IllegalAccessException |
InvocationTargetException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
} }
/**
* Registers tools based on the annotations found on the methods of the provided object.
* This method scans the methods of the given object and registers tools using the {@link ToolSpec} annotation
* and associated {@link ToolProperty} annotations. It constructs tool specifications and stores them in a tool registry.
*
* @param object the object whose methods are to be inspected for annotated tools.
* @throws RuntimeException if any reflection-based instantiation or invocation fails.
*/
public void registerAnnotatedTools(Object object) { public void registerAnnotatedTools(Object object) {
Class<?> objectClass = object.getClass(); Class<?> objectClass = object.getClass();
Method[] methods = objectClass.getMethods(); Method[] methods = objectClass.getMethods();
for(Method m : methods) { for (Method m : methods) {
ToolSpec toolSpec = m.getDeclaredAnnotation(ToolSpec.class); ToolSpec toolSpec = m.getDeclaredAnnotation(ToolSpec.class);
if(toolSpec == null){ if (toolSpec == null) {
continue; continue;
} }
String operationName = !toolSpec.name().isBlank() ? toolSpec.name() : m.getName(); String operationName = !toolSpec.name().isBlank() ? toolSpec.name() : m.getName();
String operationDesc = !toolSpec.desc().isBlank() ? toolSpec.desc() : operationName; String operationDesc = !toolSpec.desc().isBlank() ? toolSpec.desc() : operationName;
final Tools.PropsBuilder propsBuilder = new Tools.PropsBuilder(); final Tools.PropsBuilder propsBuilder = new Tools.PropsBuilder();
LinkedHashMap<String,String> methodParams = new LinkedHashMap<>(); LinkedHashMap<String, String> methodParams = new LinkedHashMap<>();
for (Parameter parameter : m.getParameters()) { for (Parameter parameter : m.getParameters()) {
final ToolProperty toolPropertyAnn = parameter.getDeclaredAnnotation(ToolProperty.class); final ToolProperty toolPropertyAnn = parameter.getDeclaredAnnotation(ToolProperty.class);
String propType = parameter.getType().getTypeName(); String propType = parameter.getType().getTypeName();
if(toolPropertyAnn == null) { if (toolPropertyAnn == null) {
methodParams.put(parameter.getName(),null); methodParams.put(parameter.getName(), null);
continue; continue;
} }
String propName = !toolPropertyAnn.name().isBlank() ? toolPropertyAnn.name() : parameter.getName(); String propName = !toolPropertyAnn.name().isBlank() ? toolPropertyAnn.name() : parameter.getName();
methodParams.put(propName,propType); methodParams.put(propName, propType);
propsBuilder.withProperty(propName,Tools.PromptFuncDefinition.Property.builder() propsBuilder.withProperty(propName, Tools.PromptFuncDefinition.Property.builder().type(propType).description(toolPropertyAnn.desc()).required(toolPropertyAnn.required()).build());
.type(propType)
.description(toolPropertyAnn.desc())
.required(toolPropertyAnn.required())
.build());
} }
final Map<String, Tools.PromptFuncDefinition.Property> params = propsBuilder.build(); final Map<String, Tools.PromptFuncDefinition.Property> params = propsBuilder.build();
List<String> reqProps = params.entrySet().stream() List<String> reqProps = params.entrySet().stream().filter(e -> e.getValue().isRequired()).map(Map.Entry::getKey).collect(Collectors.toList());
.filter(e -> e.getValue().isRequired())
.map(Map.Entry::getKey)
.collect(Collectors.toList());
Tools.ToolSpecification toolSpecification = Tools.ToolSpecification.builder() Tools.ToolSpecification toolSpecification = Tools.ToolSpecification.builder().functionName(operationName).functionDescription(operationDesc).toolPrompt(Tools.PromptFuncDefinition.builder().type("function").function(Tools.PromptFuncDefinition.PromptFuncSpec.builder().name(operationName).description(operationDesc).parameters(Tools.PromptFuncDefinition.Parameters.builder().type("object").properties(params).required(reqProps).build()).build()).build()).build();
.functionName(operationName)
.functionDescription(operationDesc)
.toolPrompt(
Tools.PromptFuncDefinition.builder().type("function").function(
Tools.PromptFuncDefinition.PromptFuncSpec.builder()
.name(operationName)
.description(operationDesc)
.parameters(
Tools.PromptFuncDefinition.Parameters.builder()
.type("object")
.properties(
params
)
.required(reqProps)
.build()
).build()
).build()
)
.build();
ReflectionalToolFunction reflectionalToolFunction = ReflectionalToolFunction reflectionalToolFunction = new ReflectionalToolFunction(object, m, methodParams);
new ReflectionalToolFunction(object, m, methodParams);
toolSpecification.setToolFunction(reflectionalToolFunction); toolSpecification.setToolFunction(reflectionalToolFunction);
toolRegistry.addTool(toolSpecification.getFunctionName(),toolSpecification); toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification);
} }
} }
@ -966,14 +974,39 @@ public class OllamaAPI {
// technical private methods // // technical private methods //
/**
* Utility method to encode a file into a Base64 encoded string.
*
* @param file the file to be encoded into Base64.
* @return a Base64 encoded string representing the contents of the file.
* @throws IOException if an I/O error occurs during reading the file.
*/
private static String encodeFileToBase64(File file) throws IOException { private static String encodeFileToBase64(File file) throws IOException {
return Base64.getEncoder().encodeToString(Files.readAllBytes(file.toPath())); return Base64.getEncoder().encodeToString(Files.readAllBytes(file.toPath()));
} }
/**
* Utility method to encode a byte array into a Base64 encoded string.
*
* @param bytes the byte array to be encoded into Base64.
* @return a Base64 encoded string representing the byte array.
*/
private static String encodeByteArrayToBase64(byte[] bytes) { private static String encodeByteArrayToBase64(byte[] bytes) {
return Base64.getEncoder().encodeToString(bytes); return Base64.getEncoder().encodeToString(bytes);
} }
/**
* Generates a request for the Ollama API and returns the result.
* This method synchronously calls the Ollama API. If a stream handler is provided,
* the request will be streamed; otherwise, a regular synchronous request will be made.
*
* @param ollamaRequestModel the request model containing necessary parameters for the Ollama API request.
* @param streamHandler the stream handler to process streaming responses, or null for non-streaming requests.
* @return the result of the Ollama API request.
* @throws OllamaBaseException if the request fails due to an issue with the Ollama API.
* @throws IOException if an I/O error occurs during the request process.
* @throws InterruptedException if the thread is interrupted during the request.
*/
private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequest ollamaRequestModel, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequest ollamaRequestModel, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose); OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
OllamaResult result; OllamaResult result;
@ -986,6 +1019,7 @@ public class OllamaAPI {
return result; return result;
} }
/** /**
* Get default request builder. * Get default request builder.
* *