diff --git a/src/main/java/io/github/ollama4j/OllamaAPI.java b/src/main/java/io/github/ollama4j/OllamaAPI.java index 77d6e62..790478b 100644 --- a/src/main/java/io/github/ollama4j/OllamaAPI.java +++ b/src/main/java/io/github/ollama4j/OllamaAPI.java @@ -610,9 +610,9 @@ public class OllamaAPI { OllamaToolsResult toolResult = new OllamaToolsResult(); Map toolResults = new HashMap<>(); - if(!prompt.startsWith("[AVAILABLE_TOOLS]")){ + if (!prompt.startsWith("[AVAILABLE_TOOLS]")) { final Tools.PromptBuilder promptBuilder = new Tools.PromptBuilder(); - for(Tools.ToolSpecification spec : toolRegistry.getRegisteredSpecs()) { + for (Tools.ToolSpecification spec : toolRegistry.getRegisteredSpecs()) { promptBuilder.withToolSpecification(spec); } promptBuilder.withPrompt(prompt); @@ -794,8 +794,8 @@ public class OllamaAPI { *

* Hint: the OllamaChatRequestModel#getStream() property is not implemented. * - * @param request request object to be sent to the server - * @param tokenHandler callback handler to handle the last token from stream (caution: all previous messages from stream will be concatenated) + * @param request request object to be sent to the server + * @param tokenHandler callback handler to handle the last token from stream (caution: all previous messages from stream will be concatenated) * @return {@link OllamaChatResult} * @throws OllamaBaseException any response code than 200 has been returned * @throws IOException in case the responseStream can not be read @@ -821,13 +821,13 @@ public class OllamaAPI { // check if toolCallIsWanted List toolCalls = result.getResponseModel().getMessage().getToolCalls(); int toolCallTries = 0; - while(toolCalls != null && !toolCalls.isEmpty() && toolCallTries < maxChatToolCallRetries){ - for (OllamaChatToolCalls toolCall : toolCalls){ + while (toolCalls != null && !toolCalls.isEmpty() && toolCallTries < maxChatToolCallRetries) { + for (OllamaChatToolCalls toolCall : toolCalls) { String toolName = toolCall.getFunction().getName(); ToolFunction toolFunction = toolRegistry.getToolFunction(toolName); Map arguments = toolCall.getFunction().getArguments(); Object res = toolFunction.apply(arguments); - request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL,"[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() +") : " + res + "[/TOOL_RESULTS]")); + request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL, "[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() + ") : " + res + "[/TOOL_RESULTS]")); } if (tokenHandler != null) { @@ -842,12 +842,38 @@ public class OllamaAPI { return result; } + /** + * Registers a single tool in the tool registry using the provided tool specification. + * + * @param toolSpecification the specification of the tool to register. It contains the + * tool's function name and other relevant information. + */ public void registerTool(Tools.ToolSpecification toolSpecification) { toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification); } + /** + * Registers multiple tools in the tool registry using a list of tool specifications. + * Iterates over the list and adds each tool specification to the registry. + * + * @param toolSpecifications a list of tool specifications to register. Each specification + * contains information about a tool, such as its function name. + */ + public void registerTools(List toolSpecifications) { + for (Tools.ToolSpecification toolSpecification : toolSpecifications) { + toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification); + } + } - public void registerAnnotatedTools() { + /** + * Registers tools based on the annotations found on the methods of the caller's class and its providers. + * This method scans the caller's class for the {@link OllamaToolService} annotation and recursively registers + * annotated tools from all the providers specified in the annotation. + * + * @throws IllegalStateException if the caller's class is not annotated with {@link OllamaToolService}. + * @throws RuntimeException if any reflection-based instantiation or invocation fails. + */ + public void registerAnnotatedTools() { try { Class callerClass = null; try { @@ -865,70 +891,52 @@ public class OllamaAPI { for (Class provider : providers) { registerAnnotatedTools(provider.getDeclaredConstructor().newInstance()); } - } catch (InstantiationException | NoSuchMethodException | IllegalAccessException | InvocationTargetException e) { + } catch (InstantiationException | NoSuchMethodException | IllegalAccessException | + InvocationTargetException e) { throw new RuntimeException(e); } } - public void registerAnnotatedTools(Object object) { + /** + * Registers tools based on the annotations found on the methods of the provided object. + * This method scans the methods of the given object and registers tools using the {@link ToolSpec} annotation + * and associated {@link ToolProperty} annotations. It constructs tool specifications and stores them in a tool registry. + * + * @param object the object whose methods are to be inspected for annotated tools. + * @throws RuntimeException if any reflection-based instantiation or invocation fails. + */ + public void registerAnnotatedTools(Object object) { Class objectClass = object.getClass(); Method[] methods = objectClass.getMethods(); - for(Method m : methods) { + for (Method m : methods) { ToolSpec toolSpec = m.getDeclaredAnnotation(ToolSpec.class); - if(toolSpec == null){ + if (toolSpec == null) { continue; } String operationName = !toolSpec.name().isBlank() ? toolSpec.name() : m.getName(); String operationDesc = !toolSpec.desc().isBlank() ? toolSpec.desc() : operationName; final Tools.PropsBuilder propsBuilder = new Tools.PropsBuilder(); - LinkedHashMap methodParams = new LinkedHashMap<>(); + LinkedHashMap methodParams = new LinkedHashMap<>(); for (Parameter parameter : m.getParameters()) { final ToolProperty toolPropertyAnn = parameter.getDeclaredAnnotation(ToolProperty.class); String propType = parameter.getType().getTypeName(); - if(toolPropertyAnn == null) { - methodParams.put(parameter.getName(),null); + if (toolPropertyAnn == null) { + methodParams.put(parameter.getName(), null); continue; } String propName = !toolPropertyAnn.name().isBlank() ? toolPropertyAnn.name() : parameter.getName(); - methodParams.put(propName,propType); - propsBuilder.withProperty(propName,Tools.PromptFuncDefinition.Property.builder() - .type(propType) - .description(toolPropertyAnn.desc()) - .required(toolPropertyAnn.required()) - .build()); + methodParams.put(propName, propType); + propsBuilder.withProperty(propName, Tools.PromptFuncDefinition.Property.builder().type(propType).description(toolPropertyAnn.desc()).required(toolPropertyAnn.required()).build()); } final Map params = propsBuilder.build(); - List reqProps = params.entrySet().stream() - .filter(e -> e.getValue().isRequired()) - .map(Map.Entry::getKey) - .collect(Collectors.toList()); + List reqProps = params.entrySet().stream().filter(e -> e.getValue().isRequired()).map(Map.Entry::getKey).collect(Collectors.toList()); - Tools.ToolSpecification toolSpecification = Tools.ToolSpecification.builder() - .functionName(operationName) - .functionDescription(operationDesc) - .toolPrompt( - Tools.PromptFuncDefinition.builder().type("function").function( - Tools.PromptFuncDefinition.PromptFuncSpec.builder() - .name(operationName) - .description(operationDesc) - .parameters( - Tools.PromptFuncDefinition.Parameters.builder() - .type("object") - .properties( - params - ) - .required(reqProps) - .build() - ).build() - ).build() - ) - .build(); + Tools.ToolSpecification toolSpecification = Tools.ToolSpecification.builder().functionName(operationName).functionDescription(operationDesc).toolPrompt(Tools.PromptFuncDefinition.builder().type("function").function(Tools.PromptFuncDefinition.PromptFuncSpec.builder().name(operationName).description(operationDesc).parameters(Tools.PromptFuncDefinition.Parameters.builder().type("object").properties(params).required(reqProps).build()).build()).build()).build(); - ReflectionalToolFunction reflectionalToolFunction = - new ReflectionalToolFunction(object, m, methodParams); + ReflectionalToolFunction reflectionalToolFunction = new ReflectionalToolFunction(object, m, methodParams); toolSpecification.setToolFunction(reflectionalToolFunction); - toolRegistry.addTool(toolSpecification.getFunctionName(),toolSpecification); + toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification); } } @@ -966,14 +974,39 @@ public class OllamaAPI { // technical private methods // + /** + * Utility method to encode a file into a Base64 encoded string. + * + * @param file the file to be encoded into Base64. + * @return a Base64 encoded string representing the contents of the file. + * @throws IOException if an I/O error occurs during reading the file. + */ private static String encodeFileToBase64(File file) throws IOException { return Base64.getEncoder().encodeToString(Files.readAllBytes(file.toPath())); } + /** + * Utility method to encode a byte array into a Base64 encoded string. + * + * @param bytes the byte array to be encoded into Base64. + * @return a Base64 encoded string representing the byte array. + */ private static String encodeByteArrayToBase64(byte[] bytes) { return Base64.getEncoder().encodeToString(bytes); } + /** + * Generates a request for the Ollama API and returns the result. + * This method synchronously calls the Ollama API. If a stream handler is provided, + * the request will be streamed; otherwise, a regular synchronous request will be made. + * + * @param ollamaRequestModel the request model containing necessary parameters for the Ollama API request. + * @param streamHandler the stream handler to process streaming responses, or null for non-streaming requests. + * @return the result of the Ollama API request. + * @throws OllamaBaseException if the request fails due to an issue with the Ollama API. + * @throws IOException if an I/O error occurs during the request process. + * @throws InterruptedException if the thread is interrupted during the request. + */ private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequest ollamaRequestModel, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose); OllamaResult result; @@ -986,6 +1019,7 @@ public class OllamaAPI { return result; } + /** * Get default request builder. *