Extends ChatModels to use Tools and ToolCalls

This commit is contained in:
Markus Klenke
2024-12-06 14:12:33 +01:00
committed by Markus Klenke
parent e9c33ab0b2
commit 12bb10392e
7 changed files with 95 additions and 20 deletions

View File

@@ -602,7 +602,7 @@ public class OllamaAPI {
OllamaResult result = generate(model, prompt, raw, options, null);
toolResult.setModelResult(result);
String toolsResponse = result.getResponse();
String toolsResponse = result.getContent();
if (toolsResponse.contains("[TOOL_CALLS]")) {
toolsResponse = toolsResponse.replace("[TOOL_CALLS]", "");
}
@@ -768,6 +768,10 @@ public class OllamaAPI {
public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
OllamaResult result;
// add all registered tools to Request
request.setTools(toolRegistry.getRegisteredSpecs().stream().map(Tools.ToolSpecification::getToolPrompt).collect(Collectors.toList()));
if (streamHandler != null) {
request.setStream(true);
result = requestCaller.call(request, streamHandler);
@@ -775,10 +779,7 @@ public class OllamaAPI {
result = requestCaller.callSync(request);
}
// add all registered tools to Request
request.setTools(toolRegistry.getRegisteredSpecs().stream().map(Tools.ToolSpecification::getToolPrompt).collect(Collectors.toList()));
return new OllamaChatResult(result.getResponse(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages());
return new OllamaChatResult(result.getContent(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages());
}
public void registerTool(Tools.ToolSpecification toolSpecification) {

View File

@@ -2,6 +2,7 @@ package io.github.ollama4j.models.chat;
import static io.github.ollama4j.utils.Utils.getObjectMapper;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
@@ -32,6 +33,8 @@ public class OllamaChatMessage {
@NonNull
private String content;
private @JsonProperty("tool_calls") List<OllamaChatToolCalls> toolCalls;
@JsonSerialize(using = FileToBase64Serializer.class)
private List<byte[]> images;

View File

@@ -38,7 +38,7 @@ public class OllamaChatRequestBuilder {
request = new OllamaChatRequest(request.getModel(), new ArrayList<>());
}
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List<File> images) {
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List<OllamaChatToolCalls> toolCalls,List<File> images) {
List<OllamaChatMessage> messages = this.request.getMessages();
List<byte[]> binaryImages = images.stream().map(file -> {
@@ -50,11 +50,11 @@ public class OllamaChatRequestBuilder {
}
}).collect(Collectors.toList());
messages.add(new OllamaChatMessage(role, content, binaryImages));
messages.add(new OllamaChatMessage(role, content,toolCalls, binaryImages));
return this;
}
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, String... imageUrls) {
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content,List<OllamaChatToolCalls> toolCalls, String... imageUrls) {
List<OllamaChatMessage> messages = this.request.getMessages();
List<byte[]> binaryImages = null;
if (imageUrls.length > 0) {
@@ -70,7 +70,7 @@ public class OllamaChatRequestBuilder {
}
}
messages.add(new OllamaChatMessage(role, content, binaryImages));
messages.add(new OllamaChatMessage(role, content,toolCalls, binaryImages));
return this;
}

View File

@@ -0,0 +1,16 @@
package io.github.ollama4j.models.chat;
import io.github.ollama4j.tools.OllamaToolCallsFunction;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@NoArgsConstructor
@AllArgsConstructor
public class OllamaChatToolCalls {
private OllamaToolCallsFunction function;
}

View File

@@ -17,7 +17,7 @@ public class OllamaResult {
*
* @return String completion/response text
*/
private final String response;
private final String content;
/**
* -- GETTER --
@@ -35,8 +35,8 @@ public class OllamaResult {
*/
private long responseTime = 0;
public OllamaResult(String response, long responseTime, int httpStatusCode) {
this.response = response;
public OllamaResult(String content, long responseTime, int httpStatusCode) {
this.content = content;
this.responseTime = responseTime;
this.httpStatusCode = httpStatusCode;
}

View File

@@ -0,0 +1,16 @@
package io.github.ollama4j.tools;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.Map;
@Data
@NoArgsConstructor
@AllArgsConstructor
public class OllamaToolCallsFunction
{
private String name;
private Map<String,String> arguments;
}