mirror of
https://github.com/amithkoujalgi/ollama4j.git
synced 2025-05-15 03:47:13 +02:00
This commit is contained in:
parent
9a12cebb68
commit
e409ff1cf9
@ -610,9 +610,9 @@ public class OllamaAPI {
|
||||
OllamaToolsResult toolResult = new OllamaToolsResult();
|
||||
Map<ToolFunctionCallSpec, Object> toolResults = new HashMap<>();
|
||||
|
||||
if(!prompt.startsWith("[AVAILABLE_TOOLS]")){
|
||||
if (!prompt.startsWith("[AVAILABLE_TOOLS]")) {
|
||||
final Tools.PromptBuilder promptBuilder = new Tools.PromptBuilder();
|
||||
for(Tools.ToolSpecification spec : toolRegistry.getRegisteredSpecs()) {
|
||||
for (Tools.ToolSpecification spec : toolRegistry.getRegisteredSpecs()) {
|
||||
promptBuilder.withToolSpecification(spec);
|
||||
}
|
||||
promptBuilder.withPrompt(prompt);
|
||||
@ -794,8 +794,8 @@ public class OllamaAPI {
|
||||
* <p>
|
||||
* Hint: the OllamaChatRequestModel#getStream() property is not implemented.
|
||||
*
|
||||
* @param request request object to be sent to the server
|
||||
* @param tokenHandler callback handler to handle the last token from stream (caution: all previous messages from stream will be concatenated)
|
||||
* @param request request object to be sent to the server
|
||||
* @param tokenHandler callback handler to handle the last token from stream (caution: all previous messages from stream will be concatenated)
|
||||
* @return {@link OllamaChatResult}
|
||||
* @throws OllamaBaseException any response code than 200 has been returned
|
||||
* @throws IOException in case the responseStream can not be read
|
||||
@ -821,13 +821,13 @@ public class OllamaAPI {
|
||||
// check if toolCallIsWanted
|
||||
List<OllamaChatToolCalls> toolCalls = result.getResponseModel().getMessage().getToolCalls();
|
||||
int toolCallTries = 0;
|
||||
while(toolCalls != null && !toolCalls.isEmpty() && toolCallTries < maxChatToolCallRetries){
|
||||
for (OllamaChatToolCalls toolCall : toolCalls){
|
||||
while (toolCalls != null && !toolCalls.isEmpty() && toolCallTries < maxChatToolCallRetries) {
|
||||
for (OllamaChatToolCalls toolCall : toolCalls) {
|
||||
String toolName = toolCall.getFunction().getName();
|
||||
ToolFunction toolFunction = toolRegistry.getToolFunction(toolName);
|
||||
Map<String, Object> arguments = toolCall.getFunction().getArguments();
|
||||
Object res = toolFunction.apply(arguments);
|
||||
request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL,"[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() +") : " + res + "[/TOOL_RESULTS]"));
|
||||
request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL, "[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() + ") : " + res + "[/TOOL_RESULTS]"));
|
||||
}
|
||||
|
||||
if (tokenHandler != null) {
|
||||
@ -842,12 +842,38 @@ public class OllamaAPI {
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Registers a single tool in the tool registry using the provided tool specification.
|
||||
*
|
||||
* @param toolSpecification the specification of the tool to register. It contains the
|
||||
* tool's function name and other relevant information.
|
||||
*/
|
||||
public void registerTool(Tools.ToolSpecification toolSpecification) {
|
||||
toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification);
|
||||
}
|
||||
|
||||
/**
|
||||
* Registers multiple tools in the tool registry using a list of tool specifications.
|
||||
* Iterates over the list and adds each tool specification to the registry.
|
||||
*
|
||||
* @param toolSpecifications a list of tool specifications to register. Each specification
|
||||
* contains information about a tool, such as its function name.
|
||||
*/
|
||||
public void registerTools(List<Tools.ToolSpecification> toolSpecifications) {
|
||||
for (Tools.ToolSpecification toolSpecification : toolSpecifications) {
|
||||
toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification);
|
||||
}
|
||||
}
|
||||
|
||||
public void registerAnnotatedTools() {
|
||||
/**
|
||||
* Registers tools based on the annotations found on the methods of the caller's class and its providers.
|
||||
* This method scans the caller's class for the {@link OllamaToolService} annotation and recursively registers
|
||||
* annotated tools from all the providers specified in the annotation.
|
||||
*
|
||||
* @throws IllegalStateException if the caller's class is not annotated with {@link OllamaToolService}.
|
||||
* @throws RuntimeException if any reflection-based instantiation or invocation fails.
|
||||
*/
|
||||
public void registerAnnotatedTools() {
|
||||
try {
|
||||
Class<?> callerClass = null;
|
||||
try {
|
||||
@ -865,70 +891,52 @@ public class OllamaAPI {
|
||||
for (Class<?> provider : providers) {
|
||||
registerAnnotatedTools(provider.getDeclaredConstructor().newInstance());
|
||||
}
|
||||
} catch (InstantiationException | NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
|
||||
} catch (InstantiationException | NoSuchMethodException | IllegalAccessException |
|
||||
InvocationTargetException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
public void registerAnnotatedTools(Object object) {
|
||||
/**
|
||||
* Registers tools based on the annotations found on the methods of the provided object.
|
||||
* This method scans the methods of the given object and registers tools using the {@link ToolSpec} annotation
|
||||
* and associated {@link ToolProperty} annotations. It constructs tool specifications and stores them in a tool registry.
|
||||
*
|
||||
* @param object the object whose methods are to be inspected for annotated tools.
|
||||
* @throws RuntimeException if any reflection-based instantiation or invocation fails.
|
||||
*/
|
||||
public void registerAnnotatedTools(Object object) {
|
||||
Class<?> objectClass = object.getClass();
|
||||
Method[] methods = objectClass.getMethods();
|
||||
for(Method m : methods) {
|
||||
for (Method m : methods) {
|
||||
ToolSpec toolSpec = m.getDeclaredAnnotation(ToolSpec.class);
|
||||
if(toolSpec == null){
|
||||
if (toolSpec == null) {
|
||||
continue;
|
||||
}
|
||||
String operationName = !toolSpec.name().isBlank() ? toolSpec.name() : m.getName();
|
||||
String operationDesc = !toolSpec.desc().isBlank() ? toolSpec.desc() : operationName;
|
||||
|
||||
final Tools.PropsBuilder propsBuilder = new Tools.PropsBuilder();
|
||||
LinkedHashMap<String,String> methodParams = new LinkedHashMap<>();
|
||||
LinkedHashMap<String, String> methodParams = new LinkedHashMap<>();
|
||||
for (Parameter parameter : m.getParameters()) {
|
||||
final ToolProperty toolPropertyAnn = parameter.getDeclaredAnnotation(ToolProperty.class);
|
||||
String propType = parameter.getType().getTypeName();
|
||||
if(toolPropertyAnn == null) {
|
||||
methodParams.put(parameter.getName(),null);
|
||||
if (toolPropertyAnn == null) {
|
||||
methodParams.put(parameter.getName(), null);
|
||||
continue;
|
||||
}
|
||||
String propName = !toolPropertyAnn.name().isBlank() ? toolPropertyAnn.name() : parameter.getName();
|
||||
methodParams.put(propName,propType);
|
||||
propsBuilder.withProperty(propName,Tools.PromptFuncDefinition.Property.builder()
|
||||
.type(propType)
|
||||
.description(toolPropertyAnn.desc())
|
||||
.required(toolPropertyAnn.required())
|
||||
.build());
|
||||
methodParams.put(propName, propType);
|
||||
propsBuilder.withProperty(propName, Tools.PromptFuncDefinition.Property.builder().type(propType).description(toolPropertyAnn.desc()).required(toolPropertyAnn.required()).build());
|
||||
}
|
||||
final Map<String, Tools.PromptFuncDefinition.Property> params = propsBuilder.build();
|
||||
List<String> reqProps = params.entrySet().stream()
|
||||
.filter(e -> e.getValue().isRequired())
|
||||
.map(Map.Entry::getKey)
|
||||
.collect(Collectors.toList());
|
||||
List<String> reqProps = params.entrySet().stream().filter(e -> e.getValue().isRequired()).map(Map.Entry::getKey).collect(Collectors.toList());
|
||||
|
||||
Tools.ToolSpecification toolSpecification = Tools.ToolSpecification.builder()
|
||||
.functionName(operationName)
|
||||
.functionDescription(operationDesc)
|
||||
.toolPrompt(
|
||||
Tools.PromptFuncDefinition.builder().type("function").function(
|
||||
Tools.PromptFuncDefinition.PromptFuncSpec.builder()
|
||||
.name(operationName)
|
||||
.description(operationDesc)
|
||||
.parameters(
|
||||
Tools.PromptFuncDefinition.Parameters.builder()
|
||||
.type("object")
|
||||
.properties(
|
||||
params
|
||||
)
|
||||
.required(reqProps)
|
||||
.build()
|
||||
).build()
|
||||
).build()
|
||||
)
|
||||
.build();
|
||||
Tools.ToolSpecification toolSpecification = Tools.ToolSpecification.builder().functionName(operationName).functionDescription(operationDesc).toolPrompt(Tools.PromptFuncDefinition.builder().type("function").function(Tools.PromptFuncDefinition.PromptFuncSpec.builder().name(operationName).description(operationDesc).parameters(Tools.PromptFuncDefinition.Parameters.builder().type("object").properties(params).required(reqProps).build()).build()).build()).build();
|
||||
|
||||
ReflectionalToolFunction reflectionalToolFunction =
|
||||
new ReflectionalToolFunction(object, m, methodParams);
|
||||
ReflectionalToolFunction reflectionalToolFunction = new ReflectionalToolFunction(object, m, methodParams);
|
||||
toolSpecification.setToolFunction(reflectionalToolFunction);
|
||||
toolRegistry.addTool(toolSpecification.getFunctionName(),toolSpecification);
|
||||
toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification);
|
||||
}
|
||||
|
||||
}
|
||||
@ -966,14 +974,39 @@ public class OllamaAPI {
|
||||
|
||||
// technical private methods //
|
||||
|
||||
/**
|
||||
* Utility method to encode a file into a Base64 encoded string.
|
||||
*
|
||||
* @param file the file to be encoded into Base64.
|
||||
* @return a Base64 encoded string representing the contents of the file.
|
||||
* @throws IOException if an I/O error occurs during reading the file.
|
||||
*/
|
||||
private static String encodeFileToBase64(File file) throws IOException {
|
||||
return Base64.getEncoder().encodeToString(Files.readAllBytes(file.toPath()));
|
||||
}
|
||||
|
||||
/**
|
||||
* Utility method to encode a byte array into a Base64 encoded string.
|
||||
*
|
||||
* @param bytes the byte array to be encoded into Base64.
|
||||
* @return a Base64 encoded string representing the byte array.
|
||||
*/
|
||||
private static String encodeByteArrayToBase64(byte[] bytes) {
|
||||
return Base64.getEncoder().encodeToString(bytes);
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates a request for the Ollama API and returns the result.
|
||||
* This method synchronously calls the Ollama API. If a stream handler is provided,
|
||||
* the request will be streamed; otherwise, a regular synchronous request will be made.
|
||||
*
|
||||
* @param ollamaRequestModel the request model containing necessary parameters for the Ollama API request.
|
||||
* @param streamHandler the stream handler to process streaming responses, or null for non-streaming requests.
|
||||
* @return the result of the Ollama API request.
|
||||
* @throws OllamaBaseException if the request fails due to an issue with the Ollama API.
|
||||
* @throws IOException if an I/O error occurs during the request process.
|
||||
* @throws InterruptedException if the thread is interrupted during the request.
|
||||
*/
|
||||
private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequest ollamaRequestModel, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
|
||||
OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
|
||||
OllamaResult result;
|
||||
@ -986,6 +1019,7 @@ public class OllamaAPI {
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Get default request builder.
|
||||
*
|
||||
|
Loading…
x
Reference in New Issue
Block a user