Extends ChatModels to use Tools and ToolCalls

This commit is contained in:
Markus Klenke 2024-12-06 14:12:33 +01:00 committed by Markus Klenke
parent e9c33ab0b2
commit 12bb10392e
7 changed files with 95 additions and 20 deletions

View File

@ -602,7 +602,7 @@ public class OllamaAPI {
OllamaResult result = generate(model, prompt, raw, options, null); OllamaResult result = generate(model, prompt, raw, options, null);
toolResult.setModelResult(result); toolResult.setModelResult(result);
String toolsResponse = result.getResponse(); String toolsResponse = result.getContent();
if (toolsResponse.contains("[TOOL_CALLS]")) { if (toolsResponse.contains("[TOOL_CALLS]")) {
toolsResponse = toolsResponse.replace("[TOOL_CALLS]", ""); toolsResponse = toolsResponse.replace("[TOOL_CALLS]", "");
} }
@ -768,6 +768,10 @@ public class OllamaAPI {
public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose); OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
OllamaResult result; OllamaResult result;
// add all registered tools to Request
request.setTools(toolRegistry.getRegisteredSpecs().stream().map(Tools.ToolSpecification::getToolPrompt).collect(Collectors.toList()));
if (streamHandler != null) { if (streamHandler != null) {
request.setStream(true); request.setStream(true);
result = requestCaller.call(request, streamHandler); result = requestCaller.call(request, streamHandler);
@ -775,10 +779,7 @@ public class OllamaAPI {
result = requestCaller.callSync(request); result = requestCaller.callSync(request);
} }
// add all registered tools to Request return new OllamaChatResult(result.getContent(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages());
request.setTools(toolRegistry.getRegisteredSpecs().stream().map(Tools.ToolSpecification::getToolPrompt).collect(Collectors.toList()));
return new OllamaChatResult(result.getResponse(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages());
} }
public void registerTool(Tools.ToolSpecification toolSpecification) { public void registerTool(Tools.ToolSpecification toolSpecification) {

View File

@ -2,6 +2,7 @@ package io.github.ollama4j.models.chat;
import static io.github.ollama4j.utils.Utils.getObjectMapper; import static io.github.ollama4j.utils.Utils.getObjectMapper;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.annotation.JsonSerialize; import com.fasterxml.jackson.databind.annotation.JsonSerialize;
@ -32,6 +33,8 @@ public class OllamaChatMessage {
@NonNull @NonNull
private String content; private String content;
private @JsonProperty("tool_calls") List<OllamaChatToolCalls> toolCalls;
@JsonSerialize(using = FileToBase64Serializer.class) @JsonSerialize(using = FileToBase64Serializer.class)
private List<byte[]> images; private List<byte[]> images;

View File

@ -38,7 +38,7 @@ public class OllamaChatRequestBuilder {
request = new OllamaChatRequest(request.getModel(), new ArrayList<>()); request = new OllamaChatRequest(request.getModel(), new ArrayList<>());
} }
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List<File> images) { public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List<OllamaChatToolCalls> toolCalls,List<File> images) {
List<OllamaChatMessage> messages = this.request.getMessages(); List<OllamaChatMessage> messages = this.request.getMessages();
List<byte[]> binaryImages = images.stream().map(file -> { List<byte[]> binaryImages = images.stream().map(file -> {
@ -50,11 +50,11 @@ public class OllamaChatRequestBuilder {
} }
}).collect(Collectors.toList()); }).collect(Collectors.toList());
messages.add(new OllamaChatMessage(role, content, binaryImages)); messages.add(new OllamaChatMessage(role, content,toolCalls, binaryImages));
return this; return this;
} }
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, String... imageUrls) { public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content,List<OllamaChatToolCalls> toolCalls, String... imageUrls) {
List<OllamaChatMessage> messages = this.request.getMessages(); List<OllamaChatMessage> messages = this.request.getMessages();
List<byte[]> binaryImages = null; List<byte[]> binaryImages = null;
if (imageUrls.length > 0) { if (imageUrls.length > 0) {
@ -70,7 +70,7 @@ public class OllamaChatRequestBuilder {
} }
} }
messages.add(new OllamaChatMessage(role, content, binaryImages)); messages.add(new OllamaChatMessage(role, content,toolCalls, binaryImages));
return this; return this;
} }

View File

@ -0,0 +1,16 @@
package io.github.ollama4j.models.chat;
import io.github.ollama4j.tools.OllamaToolCallsFunction;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@NoArgsConstructor
@AllArgsConstructor
public class OllamaChatToolCalls {
private OllamaToolCallsFunction function;
}

View File

@ -17,7 +17,7 @@ public class OllamaResult {
* *
* @return String completion/response text * @return String completion/response text
*/ */
private final String response; private final String content;
/** /**
* -- GETTER -- * -- GETTER --
@ -35,8 +35,8 @@ public class OllamaResult {
*/ */
private long responseTime = 0; private long responseTime = 0;
public OllamaResult(String response, long responseTime, int httpStatusCode) { public OllamaResult(String content, long responseTime, int httpStatusCode) {
this.response = response; this.content = content;
this.responseTime = responseTime; this.responseTime = responseTime;
this.httpStatusCode = httpStatusCode; this.httpStatusCode = httpStatusCode;
} }

View File

@ -0,0 +1,16 @@
package io.github.ollama4j.tools;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.Map;
@Data
@NoArgsConstructor
@AllArgsConstructor
public class OllamaToolCallsFunction
{
private String name;
private Map<String,String> arguments;
}

View File

@ -10,6 +10,8 @@ import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
import io.github.ollama4j.models.chat.OllamaChatResult; import io.github.ollama4j.models.chat.OllamaChatResult;
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestBuilder; import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestBuilder;
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel; import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
import io.github.ollama4j.tools.ToolFunction;
import io.github.ollama4j.tools.Tools;
import io.github.ollama4j.utils.OptionsBuilder; import io.github.ollama4j.utils.OptionsBuilder;
import lombok.Data; import lombok.Data;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
@ -24,9 +26,7 @@ import java.io.InputStream;
import java.net.ConnectException; import java.net.ConnectException;
import java.net.URISyntaxException; import java.net.URISyntaxException;
import java.net.http.HttpConnectTimeoutException; import java.net.http.HttpConnectTimeoutException;
import java.util.List; import java.util.*;
import java.util.Objects;
import java.util.Properties;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
@ -230,18 +230,47 @@ class TestRealAPIs {
void testChatWithTools() { void testChatWithTools() {
testEndpointReachability(); testEndpointReachability();
try { try {
ollamaAPI.setVerbose(true);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel()); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM,
"You are a silent bot that only says 'NI'. Do not say anything else under any circumstances!") final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder()
.functionName("get-employee-details")
.functionDescription("Get employee details from the database")
.toolPrompt(
Tools.PromptFuncDefinition.builder().type("function").function(
Tools.PromptFuncDefinition.PromptFuncSpec.builder()
.name("get-employee-details")
.description("Get employee details from the database")
.parameters(
Tools.PromptFuncDefinition.Parameters.builder()
.type("object")
.properties(
new Tools.PropsBuilder()
.withProperty("employee-name", Tools.PromptFuncDefinition.Property.builder().type("string").description("The name of the employee, e.g. John Doe").required(true).build())
.withProperty("employee-address", Tools.PromptFuncDefinition.Property.builder().type("string").description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India").required(true).build())
.withProperty("employee-phone", Tools.PromptFuncDefinition.Property.builder().type("string").description("The phone number of the employee. Always return a random value. e.g. 9911002233").required(true).build())
.build()
)
.required(List.of("employee-name"))
.build()
).build()
).build()
)
.toolFunction(new DBQueryFunction())
.build();
ollamaAPI.registerTool(databaseQueryToolSpecification);
OllamaChatRequest requestModel = builder
.withMessage(OllamaChatMessageRole.USER, .withMessage(OllamaChatMessageRole.USER,
"What is the capital of France? And what's France's connection with Mona Lisa?") "Give me the details of the employee named 'Rahul Kumar'?")
.build(); .build();
OllamaChatResult chatResult = ollamaAPI.chat(requestModel); OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
System.err.println("Response: " + chatResult);
assertNotNull(chatResult); assertNotNull(chatResult);
assertFalse(chatResult.getResponse().isBlank()); assertFalse(chatResult.getResponse().isBlank());
assertTrue(chatResult.getResponse().startsWith("NI")); assertEquals(2, chatResult.getChatHistory().size());
assertEquals(3, chatResult.getChatHistory().size());
} catch (IOException | OllamaBaseException | InterruptedException e) { } catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e); fail(e);
} }
@ -402,6 +431,14 @@ class TestRealAPIs {
} }
} }
class DBQueryFunction implements ToolFunction {
@Override
public Object apply(Map<String, Object> arguments) {
// perform DB operations here
return String.format("Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", UUID.randomUUID(), arguments.get("employee-name").toString(), arguments.get("employee-address").toString(), arguments.get("employee-phone").toString());
}
}
@Data @Data
class Config { class Config {
private String ollamaURL; private String ollamaURL;
@ -426,4 +463,6 @@ class Config {
throw new RuntimeException("Error loading properties", e); throw new RuntimeException("Error loading properties", e);
} }
} }
} }