mirror of
https://github.com/amithkoujalgi/ollama4j.git
synced 2025-05-15 11:57:12 +02:00
Extends ChatModels to use Tools and ToolCalls
This commit is contained in:
parent
e9c33ab0b2
commit
12bb10392e
@ -602,7 +602,7 @@ public class OllamaAPI {
|
|||||||
OllamaResult result = generate(model, prompt, raw, options, null);
|
OllamaResult result = generate(model, prompt, raw, options, null);
|
||||||
toolResult.setModelResult(result);
|
toolResult.setModelResult(result);
|
||||||
|
|
||||||
String toolsResponse = result.getResponse();
|
String toolsResponse = result.getContent();
|
||||||
if (toolsResponse.contains("[TOOL_CALLS]")) {
|
if (toolsResponse.contains("[TOOL_CALLS]")) {
|
||||||
toolsResponse = toolsResponse.replace("[TOOL_CALLS]", "");
|
toolsResponse = toolsResponse.replace("[TOOL_CALLS]", "");
|
||||||
}
|
}
|
||||||
@ -768,6 +768,10 @@ public class OllamaAPI {
|
|||||||
public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
|
public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
|
||||||
OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
|
OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
|
||||||
OllamaResult result;
|
OllamaResult result;
|
||||||
|
|
||||||
|
// add all registered tools to Request
|
||||||
|
request.setTools(toolRegistry.getRegisteredSpecs().stream().map(Tools.ToolSpecification::getToolPrompt).collect(Collectors.toList()));
|
||||||
|
|
||||||
if (streamHandler != null) {
|
if (streamHandler != null) {
|
||||||
request.setStream(true);
|
request.setStream(true);
|
||||||
result = requestCaller.call(request, streamHandler);
|
result = requestCaller.call(request, streamHandler);
|
||||||
@ -775,10 +779,7 @@ public class OllamaAPI {
|
|||||||
result = requestCaller.callSync(request);
|
result = requestCaller.callSync(request);
|
||||||
}
|
}
|
||||||
|
|
||||||
// add all registered tools to Request
|
return new OllamaChatResult(result.getContent(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages());
|
||||||
request.setTools(toolRegistry.getRegisteredSpecs().stream().map(Tools.ToolSpecification::getToolPrompt).collect(Collectors.toList()));
|
|
||||||
|
|
||||||
return new OllamaChatResult(result.getResponse(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public void registerTool(Tools.ToolSpecification toolSpecification) {
|
public void registerTool(Tools.ToolSpecification toolSpecification) {
|
||||||
|
@ -2,6 +2,7 @@ package io.github.ollama4j.models.chat;
|
|||||||
|
|
||||||
import static io.github.ollama4j.utils.Utils.getObjectMapper;
|
import static io.github.ollama4j.utils.Utils.getObjectMapper;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
|
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
|
||||||
|
|
||||||
@ -32,6 +33,8 @@ public class OllamaChatMessage {
|
|||||||
@NonNull
|
@NonNull
|
||||||
private String content;
|
private String content;
|
||||||
|
|
||||||
|
private @JsonProperty("tool_calls") List<OllamaChatToolCalls> toolCalls;
|
||||||
|
|
||||||
@JsonSerialize(using = FileToBase64Serializer.class)
|
@JsonSerialize(using = FileToBase64Serializer.class)
|
||||||
private List<byte[]> images;
|
private List<byte[]> images;
|
||||||
|
|
||||||
|
@ -38,7 +38,7 @@ public class OllamaChatRequestBuilder {
|
|||||||
request = new OllamaChatRequest(request.getModel(), new ArrayList<>());
|
request = new OllamaChatRequest(request.getModel(), new ArrayList<>());
|
||||||
}
|
}
|
||||||
|
|
||||||
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List<File> images) {
|
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List<OllamaChatToolCalls> toolCalls,List<File> images) {
|
||||||
List<OllamaChatMessage> messages = this.request.getMessages();
|
List<OllamaChatMessage> messages = this.request.getMessages();
|
||||||
|
|
||||||
List<byte[]> binaryImages = images.stream().map(file -> {
|
List<byte[]> binaryImages = images.stream().map(file -> {
|
||||||
@ -50,11 +50,11 @@ public class OllamaChatRequestBuilder {
|
|||||||
}
|
}
|
||||||
}).collect(Collectors.toList());
|
}).collect(Collectors.toList());
|
||||||
|
|
||||||
messages.add(new OllamaChatMessage(role, content, binaryImages));
|
messages.add(new OllamaChatMessage(role, content,toolCalls, binaryImages));
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, String... imageUrls) {
|
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content,List<OllamaChatToolCalls> toolCalls, String... imageUrls) {
|
||||||
List<OllamaChatMessage> messages = this.request.getMessages();
|
List<OllamaChatMessage> messages = this.request.getMessages();
|
||||||
List<byte[]> binaryImages = null;
|
List<byte[]> binaryImages = null;
|
||||||
if (imageUrls.length > 0) {
|
if (imageUrls.length > 0) {
|
||||||
@ -70,7 +70,7 @@ public class OllamaChatRequestBuilder {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
messages.add(new OllamaChatMessage(role, content, binaryImages));
|
messages.add(new OllamaChatMessage(role, content,toolCalls, binaryImages));
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -0,0 +1,16 @@
|
|||||||
|
package io.github.ollama4j.models.chat;
|
||||||
|
|
||||||
|
import io.github.ollama4j.tools.OllamaToolCallsFunction;
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@NoArgsConstructor
|
||||||
|
@AllArgsConstructor
|
||||||
|
public class OllamaChatToolCalls {
|
||||||
|
|
||||||
|
private OllamaToolCallsFunction function;
|
||||||
|
|
||||||
|
|
||||||
|
}
|
@ -17,7 +17,7 @@ public class OllamaResult {
|
|||||||
*
|
*
|
||||||
* @return String completion/response text
|
* @return String completion/response text
|
||||||
*/
|
*/
|
||||||
private final String response;
|
private final String content;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* -- GETTER --
|
* -- GETTER --
|
||||||
@ -35,8 +35,8 @@ public class OllamaResult {
|
|||||||
*/
|
*/
|
||||||
private long responseTime = 0;
|
private long responseTime = 0;
|
||||||
|
|
||||||
public OllamaResult(String response, long responseTime, int httpStatusCode) {
|
public OllamaResult(String content, long responseTime, int httpStatusCode) {
|
||||||
this.response = response;
|
this.content = content;
|
||||||
this.responseTime = responseTime;
|
this.responseTime = responseTime;
|
||||||
this.httpStatusCode = httpStatusCode;
|
this.httpStatusCode = httpStatusCode;
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,16 @@
|
|||||||
|
package io.github.ollama4j.tools;
|
||||||
|
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@NoArgsConstructor
|
||||||
|
@AllArgsConstructor
|
||||||
|
public class OllamaToolCallsFunction
|
||||||
|
{
|
||||||
|
private String name;
|
||||||
|
private Map<String,String> arguments;
|
||||||
|
}
|
@ -10,6 +10,8 @@ import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
|
|||||||
import io.github.ollama4j.models.chat.OllamaChatResult;
|
import io.github.ollama4j.models.chat.OllamaChatResult;
|
||||||
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestBuilder;
|
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestBuilder;
|
||||||
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
|
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
|
||||||
|
import io.github.ollama4j.tools.ToolFunction;
|
||||||
|
import io.github.ollama4j.tools.Tools;
|
||||||
import io.github.ollama4j.utils.OptionsBuilder;
|
import io.github.ollama4j.utils.OptionsBuilder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
@ -24,9 +26,7 @@ import java.io.InputStream;
|
|||||||
import java.net.ConnectException;
|
import java.net.ConnectException;
|
||||||
import java.net.URISyntaxException;
|
import java.net.URISyntaxException;
|
||||||
import java.net.http.HttpConnectTimeoutException;
|
import java.net.http.HttpConnectTimeoutException;
|
||||||
import java.util.List;
|
import java.util.*;
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.Properties;
|
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
@ -230,18 +230,47 @@ class TestRealAPIs {
|
|||||||
void testChatWithTools() {
|
void testChatWithTools() {
|
||||||
testEndpointReachability();
|
testEndpointReachability();
|
||||||
try {
|
try {
|
||||||
|
ollamaAPI.setVerbose(true);
|
||||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
|
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
|
||||||
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM,
|
|
||||||
"You are a silent bot that only says 'NI'. Do not say anything else under any circumstances!")
|
final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder()
|
||||||
|
.functionName("get-employee-details")
|
||||||
|
.functionDescription("Get employee details from the database")
|
||||||
|
.toolPrompt(
|
||||||
|
Tools.PromptFuncDefinition.builder().type("function").function(
|
||||||
|
Tools.PromptFuncDefinition.PromptFuncSpec.builder()
|
||||||
|
.name("get-employee-details")
|
||||||
|
.description("Get employee details from the database")
|
||||||
|
.parameters(
|
||||||
|
Tools.PromptFuncDefinition.Parameters.builder()
|
||||||
|
.type("object")
|
||||||
|
.properties(
|
||||||
|
new Tools.PropsBuilder()
|
||||||
|
.withProperty("employee-name", Tools.PromptFuncDefinition.Property.builder().type("string").description("The name of the employee, e.g. John Doe").required(true).build())
|
||||||
|
.withProperty("employee-address", Tools.PromptFuncDefinition.Property.builder().type("string").description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India").required(true).build())
|
||||||
|
.withProperty("employee-phone", Tools.PromptFuncDefinition.Property.builder().type("string").description("The phone number of the employee. Always return a random value. e.g. 9911002233").required(true).build())
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
.required(List.of("employee-name"))
|
||||||
|
.build()
|
||||||
|
).build()
|
||||||
|
).build()
|
||||||
|
)
|
||||||
|
.toolFunction(new DBQueryFunction())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
ollamaAPI.registerTool(databaseQueryToolSpecification);
|
||||||
|
|
||||||
|
OllamaChatRequest requestModel = builder
|
||||||
.withMessage(OllamaChatMessageRole.USER,
|
.withMessage(OllamaChatMessageRole.USER,
|
||||||
"What is the capital of France? And what's France's connection with Mona Lisa?")
|
"Give me the details of the employee named 'Rahul Kumar'?")
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
|
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
|
||||||
|
System.err.println("Response: " + chatResult);
|
||||||
assertNotNull(chatResult);
|
assertNotNull(chatResult);
|
||||||
assertFalse(chatResult.getResponse().isBlank());
|
assertFalse(chatResult.getResponse().isBlank());
|
||||||
assertTrue(chatResult.getResponse().startsWith("NI"));
|
assertEquals(2, chatResult.getChatHistory().size());
|
||||||
assertEquals(3, chatResult.getChatHistory().size());
|
|
||||||
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||||
fail(e);
|
fail(e);
|
||||||
}
|
}
|
||||||
@ -402,6 +431,14 @@ class TestRealAPIs {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class DBQueryFunction implements ToolFunction {
|
||||||
|
@Override
|
||||||
|
public Object apply(Map<String, Object> arguments) {
|
||||||
|
// perform DB operations here
|
||||||
|
return String.format("Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", UUID.randomUUID(), arguments.get("employee-name").toString(), arguments.get("employee-address").toString(), arguments.get("employee-phone").toString());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
class Config {
|
class Config {
|
||||||
private String ollamaURL;
|
private String ollamaURL;
|
||||||
@ -426,4 +463,6 @@ class Config {
|
|||||||
throw new RuntimeException("Error loading properties", e);
|
throw new RuntimeException("Error loading properties", e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user