forked from Mirror/ollama4j
		
	Update OllamaAPI.java
This commit is contained in:
		@@ -610,9 +610,9 @@ public class OllamaAPI {
 | 
			
		||||
        OllamaToolsResult toolResult = new OllamaToolsResult();
 | 
			
		||||
        Map<ToolFunctionCallSpec, Object> toolResults = new HashMap<>();
 | 
			
		||||
 | 
			
		||||
        if(!prompt.startsWith("[AVAILABLE_TOOLS]")){
 | 
			
		||||
        if (!prompt.startsWith("[AVAILABLE_TOOLS]")) {
 | 
			
		||||
            final Tools.PromptBuilder promptBuilder = new Tools.PromptBuilder();
 | 
			
		||||
            for(Tools.ToolSpecification spec :  toolRegistry.getRegisteredSpecs()) {
 | 
			
		||||
            for (Tools.ToolSpecification spec : toolRegistry.getRegisteredSpecs()) {
 | 
			
		||||
                promptBuilder.withToolSpecification(spec);
 | 
			
		||||
            }
 | 
			
		||||
            promptBuilder.withPrompt(prompt);
 | 
			
		||||
@@ -794,8 +794,8 @@ public class OllamaAPI {
 | 
			
		||||
     * <p>
 | 
			
		||||
     * Hint: the OllamaChatRequestModel#getStream() property is not implemented.
 | 
			
		||||
     *
 | 
			
		||||
     * @param request       request object to be sent to the server
 | 
			
		||||
     * @param tokenHandler  callback handler to handle the last token from stream (caution: all previous messages from stream will be concatenated)
 | 
			
		||||
     * @param request      request object to be sent to the server
 | 
			
		||||
     * @param tokenHandler callback handler to handle the last token from stream (caution: all previous messages from stream will be concatenated)
 | 
			
		||||
     * @return {@link OllamaChatResult}
 | 
			
		||||
     * @throws OllamaBaseException  any response code than 200 has been returned
 | 
			
		||||
     * @throws IOException          in case the responseStream can not be read
 | 
			
		||||
@@ -821,13 +821,13 @@ public class OllamaAPI {
 | 
			
		||||
        // check if toolCallIsWanted
 | 
			
		||||
        List<OllamaChatToolCalls> toolCalls = result.getResponseModel().getMessage().getToolCalls();
 | 
			
		||||
        int toolCallTries = 0;
 | 
			
		||||
        while(toolCalls != null && !toolCalls.isEmpty() && toolCallTries < maxChatToolCallRetries){
 | 
			
		||||
            for (OllamaChatToolCalls toolCall : toolCalls){
 | 
			
		||||
        while (toolCalls != null && !toolCalls.isEmpty() && toolCallTries < maxChatToolCallRetries) {
 | 
			
		||||
            for (OllamaChatToolCalls toolCall : toolCalls) {
 | 
			
		||||
                String toolName = toolCall.getFunction().getName();
 | 
			
		||||
                ToolFunction toolFunction = toolRegistry.getToolFunction(toolName);
 | 
			
		||||
                Map<String, Object> arguments = toolCall.getFunction().getArguments();
 | 
			
		||||
                Object res = toolFunction.apply(arguments);
 | 
			
		||||
                request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL,"[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() +") : " + res + "[/TOOL_RESULTS]"));
 | 
			
		||||
                request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL, "[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() + ") : " + res + "[/TOOL_RESULTS]"));
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            if (tokenHandler != null) {
 | 
			
		||||
@@ -842,12 +842,38 @@ public class OllamaAPI {
 | 
			
		||||
        return result;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Registers a single tool in the tool registry using the provided tool specification.
 | 
			
		||||
     *
 | 
			
		||||
     * @param toolSpecification the specification of the tool to register. It contains the
 | 
			
		||||
     *                          tool's function name and other relevant information.
 | 
			
		||||
     */
 | 
			
		||||
    public void registerTool(Tools.ToolSpecification toolSpecification) {
 | 
			
		||||
        toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Registers multiple tools in the tool registry using a list of tool specifications.
 | 
			
		||||
     * Iterates over the list and adds each tool specification to the registry.
 | 
			
		||||
     *
 | 
			
		||||
     * @param toolSpecifications a list of tool specifications to register. Each specification
 | 
			
		||||
     *                           contains information about a tool, such as its function name.
 | 
			
		||||
     */
 | 
			
		||||
    public void registerTools(List<Tools.ToolSpecification> toolSpecifications) {
 | 
			
		||||
        for (Tools.ToolSpecification toolSpecification : toolSpecifications) {
 | 
			
		||||
            toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public void registerAnnotatedTools()  {
 | 
			
		||||
    /**
 | 
			
		||||
     * Registers tools based on the annotations found on the methods of the caller's class and its providers.
 | 
			
		||||
     * This method scans the caller's class for the {@link OllamaToolService} annotation and recursively registers
 | 
			
		||||
     * annotated tools from all the providers specified in the annotation.
 | 
			
		||||
     *
 | 
			
		||||
     * @throws IllegalStateException if the caller's class is not annotated with {@link OllamaToolService}.
 | 
			
		||||
     * @throws RuntimeException      if any reflection-based instantiation or invocation fails.
 | 
			
		||||
     */
 | 
			
		||||
    public void registerAnnotatedTools() {
 | 
			
		||||
        try {
 | 
			
		||||
            Class<?> callerClass = null;
 | 
			
		||||
            try {
 | 
			
		||||
@@ -865,70 +891,52 @@ public class OllamaAPI {
 | 
			
		||||
            for (Class<?> provider : providers) {
 | 
			
		||||
                registerAnnotatedTools(provider.getDeclaredConstructor().newInstance());
 | 
			
		||||
            }
 | 
			
		||||
        } catch (InstantiationException | NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
 | 
			
		||||
        } catch (InstantiationException | NoSuchMethodException | IllegalAccessException |
 | 
			
		||||
                 InvocationTargetException e) {
 | 
			
		||||
            throw new RuntimeException(e);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public void registerAnnotatedTools(Object object)  {
 | 
			
		||||
    /**
 | 
			
		||||
     * Registers tools based on the annotations found on the methods of the provided object.
 | 
			
		||||
     * This method scans the methods of the given object and registers tools using the {@link ToolSpec} annotation
 | 
			
		||||
     * and associated {@link ToolProperty} annotations. It constructs tool specifications and stores them in a tool registry.
 | 
			
		||||
     *
 | 
			
		||||
     * @param object the object whose methods are to be inspected for annotated tools.
 | 
			
		||||
     * @throws RuntimeException if any reflection-based instantiation or invocation fails.
 | 
			
		||||
     */
 | 
			
		||||
    public void registerAnnotatedTools(Object object) {
 | 
			
		||||
        Class<?> objectClass = object.getClass();
 | 
			
		||||
        Method[] methods = objectClass.getMethods();
 | 
			
		||||
        for(Method m : methods) {
 | 
			
		||||
        for (Method m : methods) {
 | 
			
		||||
            ToolSpec toolSpec = m.getDeclaredAnnotation(ToolSpec.class);
 | 
			
		||||
            if(toolSpec == null){
 | 
			
		||||
            if (toolSpec == null) {
 | 
			
		||||
                continue;
 | 
			
		||||
            }
 | 
			
		||||
            String operationName = !toolSpec.name().isBlank() ? toolSpec.name() : m.getName();
 | 
			
		||||
            String operationDesc = !toolSpec.desc().isBlank() ? toolSpec.desc() : operationName;
 | 
			
		||||
 | 
			
		||||
            final Tools.PropsBuilder propsBuilder = new Tools.PropsBuilder();
 | 
			
		||||
            LinkedHashMap<String,String> methodParams = new LinkedHashMap<>();
 | 
			
		||||
            LinkedHashMap<String, String> methodParams = new LinkedHashMap<>();
 | 
			
		||||
            for (Parameter parameter : m.getParameters()) {
 | 
			
		||||
                final ToolProperty toolPropertyAnn = parameter.getDeclaredAnnotation(ToolProperty.class);
 | 
			
		||||
                String propType = parameter.getType().getTypeName();
 | 
			
		||||
                if(toolPropertyAnn == null) {
 | 
			
		||||
                    methodParams.put(parameter.getName(),null);
 | 
			
		||||
                if (toolPropertyAnn == null) {
 | 
			
		||||
                    methodParams.put(parameter.getName(), null);
 | 
			
		||||
                    continue;
 | 
			
		||||
                }
 | 
			
		||||
                String propName = !toolPropertyAnn.name().isBlank() ? toolPropertyAnn.name() : parameter.getName();
 | 
			
		||||
                methodParams.put(propName,propType);
 | 
			
		||||
                propsBuilder.withProperty(propName,Tools.PromptFuncDefinition.Property.builder()
 | 
			
		||||
                        .type(propType)
 | 
			
		||||
                        .description(toolPropertyAnn.desc())
 | 
			
		||||
                        .required(toolPropertyAnn.required())
 | 
			
		||||
                        .build());
 | 
			
		||||
                methodParams.put(propName, propType);
 | 
			
		||||
                propsBuilder.withProperty(propName, Tools.PromptFuncDefinition.Property.builder().type(propType).description(toolPropertyAnn.desc()).required(toolPropertyAnn.required()).build());
 | 
			
		||||
            }
 | 
			
		||||
            final Map<String, Tools.PromptFuncDefinition.Property> params = propsBuilder.build();
 | 
			
		||||
            List<String> reqProps = params.entrySet().stream()
 | 
			
		||||
                    .filter(e -> e.getValue().isRequired())
 | 
			
		||||
                    .map(Map.Entry::getKey)
 | 
			
		||||
                    .collect(Collectors.toList());
 | 
			
		||||
            List<String> reqProps = params.entrySet().stream().filter(e -> e.getValue().isRequired()).map(Map.Entry::getKey).collect(Collectors.toList());
 | 
			
		||||
 | 
			
		||||
            Tools.ToolSpecification toolSpecification = Tools.ToolSpecification.builder()
 | 
			
		||||
                    .functionName(operationName)
 | 
			
		||||
                    .functionDescription(operationDesc)
 | 
			
		||||
                    .toolPrompt(
 | 
			
		||||
                            Tools.PromptFuncDefinition.builder().type("function").function(
 | 
			
		||||
                                    Tools.PromptFuncDefinition.PromptFuncSpec.builder()
 | 
			
		||||
                                            .name(operationName)
 | 
			
		||||
                                            .description(operationDesc)
 | 
			
		||||
                                            .parameters(
 | 
			
		||||
                                                    Tools.PromptFuncDefinition.Parameters.builder()
 | 
			
		||||
                                                            .type("object")
 | 
			
		||||
                                                            .properties(
 | 
			
		||||
                                                                    params
 | 
			
		||||
                                                            )
 | 
			
		||||
                                                            .required(reqProps)
 | 
			
		||||
                                                            .build()
 | 
			
		||||
                                            ).build()
 | 
			
		||||
                            ).build()
 | 
			
		||||
                    )
 | 
			
		||||
                    .build();
 | 
			
		||||
            Tools.ToolSpecification toolSpecification = Tools.ToolSpecification.builder().functionName(operationName).functionDescription(operationDesc).toolPrompt(Tools.PromptFuncDefinition.builder().type("function").function(Tools.PromptFuncDefinition.PromptFuncSpec.builder().name(operationName).description(operationDesc).parameters(Tools.PromptFuncDefinition.Parameters.builder().type("object").properties(params).required(reqProps).build()).build()).build()).build();
 | 
			
		||||
 | 
			
		||||
            ReflectionalToolFunction reflectionalToolFunction =
 | 
			
		||||
                    new ReflectionalToolFunction(object, m, methodParams);
 | 
			
		||||
            ReflectionalToolFunction reflectionalToolFunction = new ReflectionalToolFunction(object, m, methodParams);
 | 
			
		||||
            toolSpecification.setToolFunction(reflectionalToolFunction);
 | 
			
		||||
            toolRegistry.addTool(toolSpecification.getFunctionName(),toolSpecification);
 | 
			
		||||
            toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
    }
 | 
			
		||||
@@ -966,14 +974,39 @@ public class OllamaAPI {
 | 
			
		||||
 | 
			
		||||
    // technical private methods //
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Utility method to encode a file into a Base64 encoded string.
 | 
			
		||||
     *
 | 
			
		||||
     * @param file the file to be encoded into Base64.
 | 
			
		||||
     * @return a Base64 encoded string representing the contents of the file.
 | 
			
		||||
     * @throws IOException if an I/O error occurs during reading the file.
 | 
			
		||||
     */
 | 
			
		||||
    private static String encodeFileToBase64(File file) throws IOException {
 | 
			
		||||
        return Base64.getEncoder().encodeToString(Files.readAllBytes(file.toPath()));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Utility method to encode a byte array into a Base64 encoded string.
 | 
			
		||||
     *
 | 
			
		||||
     * @param bytes the byte array to be encoded into Base64.
 | 
			
		||||
     * @return a Base64 encoded string representing the byte array.
 | 
			
		||||
     */
 | 
			
		||||
    private static String encodeByteArrayToBase64(byte[] bytes) {
 | 
			
		||||
        return Base64.getEncoder().encodeToString(bytes);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Generates a request for the Ollama API and returns the result.
 | 
			
		||||
     * This method synchronously calls the Ollama API. If a stream handler is provided,
 | 
			
		||||
     * the request will be streamed; otherwise, a regular synchronous request will be made.
 | 
			
		||||
     *
 | 
			
		||||
     * @param ollamaRequestModel the request model containing necessary parameters for the Ollama API request.
 | 
			
		||||
     * @param streamHandler      the stream handler to process streaming responses, or null for non-streaming requests.
 | 
			
		||||
     * @return the result of the Ollama API request.
 | 
			
		||||
     * @throws OllamaBaseException  if the request fails due to an issue with the Ollama API.
 | 
			
		||||
     * @throws IOException          if an I/O error occurs during the request process.
 | 
			
		||||
     * @throws InterruptedException if the thread is interrupted during the request.
 | 
			
		||||
     */
 | 
			
		||||
    private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequest ollamaRequestModel, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
 | 
			
		||||
        OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
 | 
			
		||||
        OllamaResult result;
 | 
			
		||||
@@ -986,6 +1019,7 @@ public class OllamaAPI {
 | 
			
		||||
        return result;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Get default request builder.
 | 
			
		||||
     *
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user