mirror of
https://github.com/amithkoujalgi/ollama4j.git
synced 2025-05-15 11:57:12 +02:00
Merge branch 'main' of https://github.com/ollama4j/ollama4j
This commit is contained in:
commit
f27bea11d5
@ -283,6 +283,8 @@ If you like or are using this project to build your own, please give us a star.
|
|||||||
| 7 | Katie Backend | An open-source AI-based question-answering platform for accessing private domain knowledge | [GitHub](https://github.com/wyona/katie-backend) |
|
| 7 | Katie Backend | An open-source AI-based question-answering platform for accessing private domain knowledge | [GitHub](https://github.com/wyona/katie-backend) |
|
||||||
| 8 | TeleLlama3 Bot | A question-answering Telegram bot | [Repo](https://git.hiast.edu.sy/mohamadbashar.disoki/telellama3-bot) |
|
| 8 | TeleLlama3 Bot | A question-answering Telegram bot | [Repo](https://git.hiast.edu.sy/mohamadbashar.disoki/telellama3-bot) |
|
||||||
| 9 | moqui-wechat | A moqui-wechat component | [GitHub](https://github.com/heguangyong/moqui-wechat) |
|
| 9 | moqui-wechat | A moqui-wechat component | [GitHub](https://github.com/heguangyong/moqui-wechat) |
|
||||||
|
| 10 | B4X | A set of simple and powerful RAD tool for Desktop and Server development | [Website](https://www.b4x.com/android/forum/threads/ollama4j-library-pnd_ollama4j-your-local-offline-llm-like-chatgpt.165003/) |
|
||||||
|
|
||||||
|
|
||||||
## Traction
|
## Traction
|
||||||
|
|
||||||
|
@ -33,7 +33,7 @@ public class Main {
|
|||||||
// start conversation with model
|
// start conversation with model
|
||||||
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
|
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
|
||||||
|
|
||||||
System.out.println("First answer: " + chatResult.getResponse());
|
System.out.println("First answer: " + chatResult.getResponseModel().getMessage().getContent());
|
||||||
|
|
||||||
// create next userQuestion
|
// create next userQuestion
|
||||||
requestModel = builder.withMessages(chatResult.getChatHistory()).withMessage(OllamaChatMessageRole.USER, "And what is the second largest city?").build();
|
requestModel = builder.withMessages(chatResult.getChatHistory()).withMessage(OllamaChatMessageRole.USER, "And what is the second largest city?").build();
|
||||||
@ -41,7 +41,7 @@ public class Main {
|
|||||||
// "continue" conversation with model
|
// "continue" conversation with model
|
||||||
chatResult = ollamaAPI.chat(requestModel);
|
chatResult = ollamaAPI.chat(requestModel);
|
||||||
|
|
||||||
System.out.println("Second answer: " + chatResult.getResponse());
|
System.out.println("Second answer: " + chatResult.getResponseModel().getMessage().getContent());
|
||||||
|
|
||||||
System.out.println("Chat History: " + chatResult.getChatHistory());
|
System.out.println("Chat History: " + chatResult.getChatHistory());
|
||||||
}
|
}
|
||||||
@ -205,7 +205,7 @@ public class Main {
|
|||||||
// start conversation with model
|
// start conversation with model
|
||||||
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
|
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
|
||||||
|
|
||||||
System.out.println(chatResult.getResponse());
|
System.out.println(chatResult.getResponseModel());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -244,7 +244,7 @@ public class Main {
|
|||||||
new File("/path/to/image"))).build();
|
new File("/path/to/image"))).build();
|
||||||
|
|
||||||
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
|
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
|
||||||
System.out.println("First answer: " + chatResult.getResponse());
|
System.out.println("First answer: " + chatResult.getResponseModel());
|
||||||
|
|
||||||
builder.reset();
|
builder.reset();
|
||||||
|
|
||||||
@ -254,7 +254,7 @@ public class Main {
|
|||||||
.withMessage(OllamaChatMessageRole.USER, "What's the dogs breed?").build();
|
.withMessage(OllamaChatMessageRole.USER, "What's the dogs breed?").build();
|
||||||
|
|
||||||
chatResult = ollamaAPI.chat(requestModel);
|
chatResult = ollamaAPI.chat(requestModel);
|
||||||
System.out.println("Second answer: " + chatResult.getResponse());
|
System.out.println("Second answer: " + chatResult.getResponseModel());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
@ -345,21 +345,274 @@ Rahul Kumar, Address: King St, Hyderabad, India, Phone: 9876543210}`
|
|||||||
|
|
||||||
::::
|
::::
|
||||||
|
|
||||||
### Potential Improvements
|
### Using tools in Chat-API
|
||||||
|
|
||||||
Instead of explicitly registering `ollamaAPI.registerTool(toolSpecification)`, we could introduce annotation-based tool
|
Instead of using the specific `ollamaAPI.generateWithTools` method to call the generate API of ollama with tools, it is
|
||||||
registration. For example:
|
also possible to register Tools for the `ollamaAPI.chat` methods. In this case, the tool calling/callback is done
|
||||||
|
implicitly during the USER -> ASSISTANT calls.
|
||||||
|
|
||||||
|
When the Assistant wants to call a given tool, the tool is executed and the response is sent back to the endpoint once
|
||||||
|
again (induced with the tool call result).
|
||||||
|
|
||||||
|
#### Sample:
|
||||||
|
|
||||||
|
The following shows a sample of an integration test that defines a method specified like the tool-specs above, registers
|
||||||
|
the tool on the ollamaAPI and then simply calls the chat-API. All intermediate tool calling is wrapped inside the api
|
||||||
|
call.
|
||||||
|
|
||||||
```java
|
```java
|
||||||
|
public static void main(String[] args) {
|
||||||
|
OllamaAPI ollamaAPI = new OllamaAPI("http://localhost:11434");
|
||||||
|
ollamaAPI.setVerbose(true);
|
||||||
|
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance("llama3.2:1b");
|
||||||
|
|
||||||
@ToolSpec(name = "current-fuel-price", desc = "Get current fuel price")
|
final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder()
|
||||||
public String getCurrentFuelPrice(Map<String, Object> arguments) {
|
.functionName("get-employee-details")
|
||||||
String location = arguments.get("location").toString();
|
.functionDescription("Get employee details from the database")
|
||||||
String fuelType = arguments.get("fuelType").toString();
|
.toolPrompt(
|
||||||
return "Current price of " + fuelType + " in " + location + " is Rs.103/L";
|
Tools.PromptFuncDefinition.builder().type("function").function(
|
||||||
|
Tools.PromptFuncDefinition.PromptFuncSpec.builder()
|
||||||
|
.name("get-employee-details")
|
||||||
|
.description("Get employee details from the database")
|
||||||
|
.parameters(
|
||||||
|
Tools.PromptFuncDefinition.Parameters.builder()
|
||||||
|
.type("object")
|
||||||
|
.properties(
|
||||||
|
new Tools.PropsBuilder()
|
||||||
|
.withProperty("employee-name", Tools.PromptFuncDefinition.Property.builder().type("string").description("The name of the employee, e.g. John Doe").required(true).build())
|
||||||
|
.withProperty("employee-address", Tools.PromptFuncDefinition.Property.builder().type("string").description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India").required(true).build())
|
||||||
|
.withProperty("employee-phone", Tools.PromptFuncDefinition.Property.builder().type("string").description("The phone number of the employee. Always return a random value. e.g. 9911002233").required(true).build())
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
.required(List.of("employee-name"))
|
||||||
|
.build()
|
||||||
|
).build()
|
||||||
|
).build()
|
||||||
|
)
|
||||||
|
.toolFunction(new DBQueryFunction())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
ollamaAPI.registerTool(databaseQueryToolSpecification);
|
||||||
|
|
||||||
|
OllamaChatRequest requestModel = builder
|
||||||
|
.withMessage(OllamaChatMessageRole.USER,
|
||||||
|
"Give me the ID of the employee named 'Rahul Kumar'?")
|
||||||
|
.build();
|
||||||
|
|
||||||
|
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
A typical final response of the above could be:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"chatHistory" : [
|
||||||
|
{
|
||||||
|
"role" : "user",
|
||||||
|
"content" : "Give me the ID of the employee named 'Rahul Kumar'?",
|
||||||
|
"images" : null,
|
||||||
|
"tool_calls" : [ ]
|
||||||
|
}, {
|
||||||
|
"role" : "assistant",
|
||||||
|
"content" : "",
|
||||||
|
"images" : null,
|
||||||
|
"tool_calls" : [ {
|
||||||
|
"function" : {
|
||||||
|
"name" : "get-employee-details",
|
||||||
|
"arguments" : {
|
||||||
|
"employee-name" : "Rahul Kumar"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} ]
|
||||||
|
}, {
|
||||||
|
"role" : "tool",
|
||||||
|
"content" : "[TOOL_RESULTS]get-employee-details([employee-name]) : Employee Details {ID: b4bf186c-2ee1-44cc-8856-53b8b6a50f85, Name: Rahul Kumar, Address: null, Phone: null}[/TOOL_RESULTS]",
|
||||||
|
"images" : null,
|
||||||
|
"tool_calls" : null
|
||||||
|
}, {
|
||||||
|
"role" : "assistant",
|
||||||
|
"content" : "The ID of the employee named 'Rahul Kumar' is `b4bf186c-2ee1-44cc-8856-53b8b6a50f85`.",
|
||||||
|
"images" : null,
|
||||||
|
"tool_calls" : null
|
||||||
|
} ],
|
||||||
|
"responseModel" : {
|
||||||
|
"model" : "llama3.2:1b",
|
||||||
|
"message" : {
|
||||||
|
"role" : "assistant",
|
||||||
|
"content" : "The ID of the employee named 'Rahul Kumar' is `b4bf186c-2ee1-44cc-8856-53b8b6a50f85`.",
|
||||||
|
"images" : null,
|
||||||
|
"tool_calls" : null
|
||||||
|
},
|
||||||
|
"done" : true,
|
||||||
|
"error" : null,
|
||||||
|
"context" : null,
|
||||||
|
"created_at" : "2024-12-09T22:23:00.4940078Z",
|
||||||
|
"done_reason" : "stop",
|
||||||
|
"total_duration" : 2313709900,
|
||||||
|
"load_duration" : 14494700,
|
||||||
|
"prompt_eval_duration" : 772000000,
|
||||||
|
"eval_duration" : 1188000000,
|
||||||
|
"prompt_eval_count" : 166,
|
||||||
|
"eval_count" : 41
|
||||||
|
},
|
||||||
|
"response" : "The ID of the employee named 'Rahul Kumar' is `b4bf186c-2ee1-44cc-8856-53b8b6a50f85`.",
|
||||||
|
"httpStatusCode" : 200,
|
||||||
|
"responseTime" : 2313709900
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
This tool calling can also be done using the streaming API.
|
||||||
|
|
||||||
|
### Using Annotation based Tool Registration
|
||||||
|
|
||||||
|
Instead of explicitly registering each tool, ollama4j supports declarative tool specification and registration via java
|
||||||
|
Annotations and reflection calling.
|
||||||
|
|
||||||
|
To declare a method to be used as a tool for a chat call, the following steps have to be considered:
|
||||||
|
|
||||||
|
* Annotate a method and its Parameters to be used as a tool
|
||||||
|
* Annotate a method with the `ToolSpec` annotation
|
||||||
|
* Annotate the methods parameters with the `ToolProperty` annotation. Only the following datatypes are supported for now:
|
||||||
|
* `java.lang.String`
|
||||||
|
* `java.lang.Integer`
|
||||||
|
* `java.lang.Boolean`
|
||||||
|
* `java.math.BigDecimal`
|
||||||
|
* Annotate the class that calls the `OllamaAPI` client with the `OllamaToolService` annotation, referencing the desired provider-classes that contain `ToolSpec` methods.
|
||||||
|
* Before calling the `OllamaAPI` chat request, call the method `OllamaAPI.registerAnnotatedTools()` method to add tools to the chat.
|
||||||
|
|
||||||
|
#### Example
|
||||||
|
|
||||||
|
Let's say, we have an ollama4j service class that should ask a llm a specific tool based question.
|
||||||
|
|
||||||
|
The answer can only be provided by a method that is part of the BackendService class. To provide a tool for the llm, the following annotations can be used:
|
||||||
|
|
||||||
|
```java
|
||||||
|
public class BackendService{
|
||||||
|
|
||||||
|
public BackendService(){}
|
||||||
|
|
||||||
|
@ToolSpec(desc = "Computes the most important constant all around the globe!")
|
||||||
|
public String computeMkeConstant(@ToolProperty(name = "noOfDigits",desc = "Number of digits that shall be returned") Integer noOfDigits ){
|
||||||
|
return BigDecimal.valueOf((long)(Math.random()*1000000L),noOfDigits).toString();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
The caller API can then be written as:
|
||||||
|
```java
|
||||||
|
import io.github.ollama4j.tools.annotations.OllamaToolService;
|
||||||
|
|
||||||
|
@OllamaToolService(providers = BackendService.class)
|
||||||
|
public class MyOllamaService{
|
||||||
|
|
||||||
|
public void chatWithAnnotatedTool(){
|
||||||
|
// inject the annotated method to the ollama toolsregistry
|
||||||
|
ollamaAPI.registerAnnotatedTools();
|
||||||
|
|
||||||
|
OllamaChatRequest requestModel = builder
|
||||||
|
.withMessage(OllamaChatMessageRole.USER,
|
||||||
|
"Compute the most important constant in the world using 5 digits")
|
||||||
|
.build();
|
||||||
|
|
||||||
|
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
The request should be the following:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"model" : "llama3.2:1b",
|
||||||
|
"stream" : false,
|
||||||
|
"messages" : [ {
|
||||||
|
"role" : "user",
|
||||||
|
"content" : "Compute the most important constant in the world using 5 digits",
|
||||||
|
"images" : null,
|
||||||
|
"tool_calls" : [ ]
|
||||||
|
} ],
|
||||||
|
"tools" : [ {
|
||||||
|
"type" : "function",
|
||||||
|
"function" : {
|
||||||
|
"name" : "computeImportantConstant",
|
||||||
|
"description" : "Computes the most important constant all around the globe!",
|
||||||
|
"parameters" : {
|
||||||
|
"type" : "object",
|
||||||
|
"properties" : {
|
||||||
|
"noOfDigits" : {
|
||||||
|
"type" : "java.lang.Integer",
|
||||||
|
"description" : "Number of digits that shall be returned"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required" : [ "noOfDigits" ]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} ]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
The result could be something like the following:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"chatHistory" : [ {
|
||||||
|
"role" : "user",
|
||||||
|
"content" : "Compute the most important constant in the world using 5 digits",
|
||||||
|
"images" : null,
|
||||||
|
"tool_calls" : [ ]
|
||||||
|
}, {
|
||||||
|
"role" : "assistant",
|
||||||
|
"content" : "",
|
||||||
|
"images" : null,
|
||||||
|
"tool_calls" : [ {
|
||||||
|
"function" : {
|
||||||
|
"name" : "computeImportantConstant",
|
||||||
|
"arguments" : {
|
||||||
|
"noOfDigits" : "5"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} ]
|
||||||
|
}, {
|
||||||
|
"role" : "tool",
|
||||||
|
"content" : "[TOOL_RESULTS]computeImportantConstant([noOfDigits]) : 1.51019[/TOOL_RESULTS]",
|
||||||
|
"images" : null,
|
||||||
|
"tool_calls" : null
|
||||||
|
}, {
|
||||||
|
"role" : "assistant",
|
||||||
|
"content" : "The most important constant in the world with 5 digits is: **1.51019**",
|
||||||
|
"images" : null,
|
||||||
|
"tool_calls" : null
|
||||||
|
} ],
|
||||||
|
"responseModel" : {
|
||||||
|
"model" : "llama3.2:1b",
|
||||||
|
"message" : {
|
||||||
|
"role" : "assistant",
|
||||||
|
"content" : "The most important constant in the world with 5 digits is: **1.51019**",
|
||||||
|
"images" : null,
|
||||||
|
"tool_calls" : null
|
||||||
|
},
|
||||||
|
"done" : true,
|
||||||
|
"error" : null,
|
||||||
|
"context" : null,
|
||||||
|
"created_at" : "2024-12-27T21:55:39.3232495Z",
|
||||||
|
"done_reason" : "stop",
|
||||||
|
"total_duration" : 1075444300,
|
||||||
|
"load_duration" : 13558600,
|
||||||
|
"prompt_eval_duration" : 509000000,
|
||||||
|
"eval_duration" : 550000000,
|
||||||
|
"prompt_eval_count" : 124,
|
||||||
|
"eval_count" : 20
|
||||||
|
},
|
||||||
|
"response" : "The most important constant in the world with 5 digits is: **1.51019**",
|
||||||
|
"responseTime" : 1075444300,
|
||||||
|
"httpStatusCode" : 200
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Potential Improvements
|
||||||
|
|
||||||
Instead of passing a map of args `Map<String, Object> arguments` to the tool functions, we could support passing
|
Instead of passing a map of args `Map<String, Object> arguments` to the tool functions, we could support passing
|
||||||
specific args separately with their data types. For example:
|
specific args separately with their data types. For example:
|
||||||
|
|
||||||
|
@ -15,11 +15,17 @@ import io.github.ollama4j.models.ps.ModelsProcessResponse;
|
|||||||
import io.github.ollama4j.models.request.*;
|
import io.github.ollama4j.models.request.*;
|
||||||
import io.github.ollama4j.models.response.*;
|
import io.github.ollama4j.models.response.*;
|
||||||
import io.github.ollama4j.tools.*;
|
import io.github.ollama4j.tools.*;
|
||||||
|
import io.github.ollama4j.tools.annotations.OllamaToolService;
|
||||||
|
import io.github.ollama4j.tools.annotations.ToolProperty;
|
||||||
|
import io.github.ollama4j.tools.annotations.ToolSpec;
|
||||||
import io.github.ollama4j.utils.Options;
|
import io.github.ollama4j.utils.Options;
|
||||||
import io.github.ollama4j.utils.Utils;
|
import io.github.ollama4j.utils.Utils;
|
||||||
import lombok.Setter;
|
import lombok.Setter;
|
||||||
|
|
||||||
import java.io.*;
|
import java.io.*;
|
||||||
|
import java.lang.reflect.InvocationTargetException;
|
||||||
|
import java.lang.reflect.Method;
|
||||||
|
import java.lang.reflect.Parameter;
|
||||||
import java.net.URI;
|
import java.net.URI;
|
||||||
import java.net.URISyntaxException;
|
import java.net.URISyntaxException;
|
||||||
import java.net.http.HttpClient;
|
import java.net.http.HttpClient;
|
||||||
@ -59,6 +65,10 @@ public class OllamaAPI {
|
|||||||
*/
|
*/
|
||||||
@Setter
|
@Setter
|
||||||
private boolean verbose = true;
|
private boolean verbose = true;
|
||||||
|
|
||||||
|
@Setter
|
||||||
|
private int maxChatToolCallRetries = 3;
|
||||||
|
|
||||||
private BasicAuth basicAuth;
|
private BasicAuth basicAuth;
|
||||||
|
|
||||||
private final ToolRegistry toolRegistry = new ToolRegistry();
|
private final ToolRegistry toolRegistry = new ToolRegistry();
|
||||||
@ -193,7 +203,7 @@ public class OllamaAPI {
|
|||||||
Elements modelSections = doc.selectXpath("//*[@id='repo']/ul/li/a");
|
Elements modelSections = doc.selectXpath("//*[@id='repo']/ul/li/a");
|
||||||
for (Element e : modelSections) {
|
for (Element e : modelSections) {
|
||||||
LibraryModel model = new LibraryModel();
|
LibraryModel model = new LibraryModel();
|
||||||
Elements names = e.select("div > h2 > span");
|
Elements names = e.select("div > h2 > div > span");
|
||||||
Elements desc = e.select("div > p");
|
Elements desc = e.select("div > p");
|
||||||
Elements pullCounts = e.select("div:nth-of-type(2) > p > span:first-of-type > span:first-of-type");
|
Elements pullCounts = e.select("div:nth-of-type(2) > p > span:first-of-type > span:first-of-type");
|
||||||
Elements popularTags = e.select("div > div > span");
|
Elements popularTags = e.select("div > div > span");
|
||||||
@ -599,6 +609,15 @@ public class OllamaAPI {
|
|||||||
OllamaToolsResult toolResult = new OllamaToolsResult();
|
OllamaToolsResult toolResult = new OllamaToolsResult();
|
||||||
Map<ToolFunctionCallSpec, Object> toolResults = new HashMap<>();
|
Map<ToolFunctionCallSpec, Object> toolResults = new HashMap<>();
|
||||||
|
|
||||||
|
if(!prompt.startsWith("[AVAILABLE_TOOLS]")){
|
||||||
|
final Tools.PromptBuilder promptBuilder = new Tools.PromptBuilder();
|
||||||
|
for(Tools.ToolSpecification spec : toolRegistry.getRegisteredSpecs()) {
|
||||||
|
promptBuilder.withToolSpecification(spec);
|
||||||
|
}
|
||||||
|
promptBuilder.withPrompt(prompt);
|
||||||
|
prompt = promptBuilder.build();
|
||||||
|
}
|
||||||
|
|
||||||
OllamaResult result = generate(model, prompt, raw, options, null);
|
OllamaResult result = generate(model, prompt, raw, options, null);
|
||||||
toolResult.setModelResult(result);
|
toolResult.setModelResult(result);
|
||||||
|
|
||||||
@ -767,18 +786,130 @@ public class OllamaAPI {
|
|||||||
*/
|
*/
|
||||||
public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
|
public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
|
||||||
OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
|
OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
|
||||||
OllamaResult result;
|
OllamaChatResult result;
|
||||||
|
|
||||||
|
// add all registered tools to Request
|
||||||
|
request.setTools(toolRegistry.getRegisteredSpecs().stream().map(Tools.ToolSpecification::getToolPrompt).collect(Collectors.toList()));
|
||||||
|
|
||||||
if (streamHandler != null) {
|
if (streamHandler != null) {
|
||||||
request.setStream(true);
|
request.setStream(true);
|
||||||
result = requestCaller.call(request, streamHandler);
|
result = requestCaller.call(request, streamHandler);
|
||||||
} else {
|
} else {
|
||||||
result = requestCaller.callSync(request);
|
result = requestCaller.callSync(request);
|
||||||
}
|
}
|
||||||
return new OllamaChatResult(result.getResponse(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages());
|
|
||||||
|
// check if toolCallIsWanted
|
||||||
|
List<OllamaChatToolCalls> toolCalls = result.getResponseModel().getMessage().getToolCalls();
|
||||||
|
int toolCallTries = 0;
|
||||||
|
while(toolCalls != null && !toolCalls.isEmpty() && toolCallTries < maxChatToolCallRetries){
|
||||||
|
for (OllamaChatToolCalls toolCall : toolCalls){
|
||||||
|
String toolName = toolCall.getFunction().getName();
|
||||||
|
ToolFunction toolFunction = toolRegistry.getToolFunction(toolName);
|
||||||
|
Map<String, Object> arguments = toolCall.getFunction().getArguments();
|
||||||
|
Object res = toolFunction.apply(arguments);
|
||||||
|
request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL,"[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() +") : " + res + "[/TOOL_RESULTS]"));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (streamHandler != null) {
|
||||||
|
result = requestCaller.call(request, streamHandler);
|
||||||
|
} else {
|
||||||
|
result = requestCaller.callSync(request);
|
||||||
|
}
|
||||||
|
toolCalls = result.getResponseModel().getMessage().getToolCalls();
|
||||||
|
toolCallTries++;
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void registerTool(Tools.ToolSpecification toolSpecification) {
|
public void registerTool(Tools.ToolSpecification toolSpecification) {
|
||||||
toolRegistry.addFunction(toolSpecification.getFunctionName(), toolSpecification.getToolDefinition());
|
toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void registerAnnotatedTools() {
|
||||||
|
Class<?> callerClass = null;
|
||||||
|
try {
|
||||||
|
callerClass = Class.forName(Thread.currentThread().getStackTrace()[2].getClassName());
|
||||||
|
} catch (ClassNotFoundException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
|
||||||
|
OllamaToolService ollamaToolServiceAnnotation = callerClass.getDeclaredAnnotation(OllamaToolService.class);
|
||||||
|
if(ollamaToolServiceAnnotation == null) {
|
||||||
|
throw new IllegalStateException(callerClass + " is not annotated as " + OllamaToolService.class);
|
||||||
|
}
|
||||||
|
|
||||||
|
Class<?>[] providers = ollamaToolServiceAnnotation.providers();
|
||||||
|
|
||||||
|
for(Class<?> provider : providers){
|
||||||
|
Method[] methods = provider.getMethods();
|
||||||
|
for(Method m : methods) {
|
||||||
|
ToolSpec toolSpec = m.getDeclaredAnnotation(ToolSpec.class);
|
||||||
|
if(toolSpec == null){
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
String operationName = !toolSpec.name().isBlank() ? toolSpec.name() : m.getName();
|
||||||
|
String operationDesc = !toolSpec.desc().isBlank() ? toolSpec.desc() : operationName;
|
||||||
|
|
||||||
|
final Tools.PropsBuilder propsBuilder = new Tools.PropsBuilder();
|
||||||
|
LinkedHashMap<String,String> methodParams = new LinkedHashMap<>();
|
||||||
|
for (Parameter parameter : m.getParameters()) {
|
||||||
|
final ToolProperty toolPropertyAnn = parameter.getDeclaredAnnotation(ToolProperty.class);
|
||||||
|
String propType = parameter.getType().getTypeName();
|
||||||
|
if(toolPropertyAnn == null) {
|
||||||
|
methodParams.put(parameter.getName(),null);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
String propName = !toolPropertyAnn.name().isBlank() ? toolPropertyAnn.name() : parameter.getName();
|
||||||
|
methodParams.put(propName,propType);
|
||||||
|
propsBuilder.withProperty(propName,Tools.PromptFuncDefinition.Property.builder()
|
||||||
|
.type(propType)
|
||||||
|
.description(toolPropertyAnn.desc())
|
||||||
|
.required(toolPropertyAnn.required())
|
||||||
|
.build());
|
||||||
|
}
|
||||||
|
final Map<String, Tools.PromptFuncDefinition.Property> params = propsBuilder.build();
|
||||||
|
List<String> reqProps = params.entrySet().stream()
|
||||||
|
.filter(e -> e.getValue().isRequired())
|
||||||
|
.map(Map.Entry::getKey)
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
|
Tools.ToolSpecification toolSpecification = Tools.ToolSpecification.builder()
|
||||||
|
.functionName(operationName)
|
||||||
|
.functionDescription(operationDesc)
|
||||||
|
.toolPrompt(
|
||||||
|
Tools.PromptFuncDefinition.builder().type("function").function(
|
||||||
|
Tools.PromptFuncDefinition.PromptFuncSpec.builder()
|
||||||
|
.name(operationName)
|
||||||
|
.description(operationDesc)
|
||||||
|
.parameters(
|
||||||
|
Tools.PromptFuncDefinition.Parameters.builder()
|
||||||
|
.type("object")
|
||||||
|
.properties(
|
||||||
|
params
|
||||||
|
)
|
||||||
|
.required(reqProps)
|
||||||
|
.build()
|
||||||
|
).build()
|
||||||
|
).build()
|
||||||
|
)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
try {
|
||||||
|
ReflectionalToolFunction reflectionalToolFunction =
|
||||||
|
new ReflectionalToolFunction(provider.getDeclaredConstructor().newInstance()
|
||||||
|
,m
|
||||||
|
,methodParams);
|
||||||
|
|
||||||
|
toolSpecification.setToolFunction(reflectionalToolFunction);
|
||||||
|
toolRegistry.addTool(toolSpecification.getFunctionName(),toolSpecification);
|
||||||
|
} catch (InstantiationException | IllegalAccessException | InvocationTargetException |
|
||||||
|
NoSuchMethodException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -871,7 +1002,7 @@ public class OllamaAPI {
|
|||||||
try {
|
try {
|
||||||
String methodName = toolFunctionCallSpec.getName();
|
String methodName = toolFunctionCallSpec.getName();
|
||||||
Map<String, Object> arguments = toolFunctionCallSpec.getArguments();
|
Map<String, Object> arguments = toolFunctionCallSpec.getArguments();
|
||||||
ToolFunction function = toolRegistry.getFunction(methodName);
|
ToolFunction function = toolRegistry.getToolFunction(methodName);
|
||||||
if (verbose) {
|
if (verbose) {
|
||||||
logger.debug("Invoking function {} with arguments {}", methodName, arguments);
|
logger.debug("Invoking function {} with arguments {}", methodName, arguments);
|
||||||
}
|
}
|
||||||
|
@ -2,6 +2,7 @@ package io.github.ollama4j.models.chat;
|
|||||||
|
|
||||||
import static io.github.ollama4j.utils.Utils.getObjectMapper;
|
import static io.github.ollama4j.utils.Utils.getObjectMapper;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
|
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
|
||||||
|
|
||||||
@ -32,6 +33,8 @@ public class OllamaChatMessage {
|
|||||||
@NonNull
|
@NonNull
|
||||||
private String content;
|
private String content;
|
||||||
|
|
||||||
|
private @JsonProperty("tool_calls") List<OllamaChatToolCalls> toolCalls;
|
||||||
|
|
||||||
@JsonSerialize(using = FileToBase64Serializer.class)
|
@JsonSerialize(using = FileToBase64Serializer.class)
|
||||||
private List<byte[]> images;
|
private List<byte[]> images;
|
||||||
|
|
||||||
|
@ -3,6 +3,7 @@ package io.github.ollama4j.models.chat;
|
|||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import io.github.ollama4j.models.request.OllamaCommonRequest;
|
import io.github.ollama4j.models.request.OllamaCommonRequest;
|
||||||
|
import io.github.ollama4j.tools.Tools;
|
||||||
import io.github.ollama4j.utils.OllamaRequestBody;
|
import io.github.ollama4j.utils.OllamaRequestBody;
|
||||||
|
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
@ -21,6 +22,8 @@ public class OllamaChatRequest extends OllamaCommonRequest implements OllamaRequ
|
|||||||
|
|
||||||
private List<OllamaChatMessage> messages;
|
private List<OllamaChatMessage> messages;
|
||||||
|
|
||||||
|
private List<Tools.PromptFuncDefinition> tools;
|
||||||
|
|
||||||
public OllamaChatRequest() {}
|
public OllamaChatRequest() {}
|
||||||
|
|
||||||
public OllamaChatRequest(String model, List<OllamaChatMessage> messages) {
|
public OllamaChatRequest(String model, List<OllamaChatMessage> messages) {
|
||||||
|
@ -10,6 +10,7 @@ import java.io.IOException;
|
|||||||
import java.net.URISyntaxException;
|
import java.net.URISyntaxException;
|
||||||
import java.nio.file.Files;
|
import java.nio.file.Files;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@ -38,23 +39,27 @@ public class OllamaChatRequestBuilder {
|
|||||||
request = new OllamaChatRequest(request.getModel(), new ArrayList<>());
|
request = new OllamaChatRequest(request.getModel(), new ArrayList<>());
|
||||||
}
|
}
|
||||||
|
|
||||||
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List<File> images) {
|
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content){
|
||||||
|
return withMessage(role,content, Collections.emptyList());
|
||||||
|
}
|
||||||
|
|
||||||
|
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List<OllamaChatToolCalls> toolCalls,List<File> images) {
|
||||||
List<OllamaChatMessage> messages = this.request.getMessages();
|
List<OllamaChatMessage> messages = this.request.getMessages();
|
||||||
|
|
||||||
List<byte[]> binaryImages = images.stream().map(file -> {
|
List<byte[]> binaryImages = images.stream().map(file -> {
|
||||||
try {
|
try {
|
||||||
return Files.readAllBytes(file.toPath());
|
return Files.readAllBytes(file.toPath());
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
LOG.warn(String.format("File '%s' could not be accessed, will not add to message!", file.toPath()), e);
|
LOG.warn("File '{}' could not be accessed, will not add to message!", file.toPath(), e);
|
||||||
return new byte[0];
|
return new byte[0];
|
||||||
}
|
}
|
||||||
}).collect(Collectors.toList());
|
}).collect(Collectors.toList());
|
||||||
|
|
||||||
messages.add(new OllamaChatMessage(role, content, binaryImages));
|
messages.add(new OllamaChatMessage(role, content,toolCalls, binaryImages));
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, String... imageUrls) {
|
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content,List<OllamaChatToolCalls> toolCalls, String... imageUrls) {
|
||||||
List<OllamaChatMessage> messages = this.request.getMessages();
|
List<OllamaChatMessage> messages = this.request.getMessages();
|
||||||
List<byte[]> binaryImages = null;
|
List<byte[]> binaryImages = null;
|
||||||
if (imageUrls.length > 0) {
|
if (imageUrls.length > 0) {
|
||||||
@ -63,14 +68,14 @@ public class OllamaChatRequestBuilder {
|
|||||||
try {
|
try {
|
||||||
binaryImages.add(Utils.loadImageBytesFromUrl(imageUrl));
|
binaryImages.add(Utils.loadImageBytesFromUrl(imageUrl));
|
||||||
} catch (URISyntaxException e) {
|
} catch (URISyntaxException e) {
|
||||||
LOG.warn(String.format("URL '%s' could not be accessed, will not add to message!", imageUrl), e);
|
LOG.warn("URL '{}' could not be accessed, will not add to message!", imageUrl, e);
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
LOG.warn(String.format("Content of URL '%s' could not be read, will not add to message!", imageUrl), e);
|
LOG.warn("Content of URL '{}' could not be read, will not add to message!", imageUrl, e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
messages.add(new OllamaChatMessage(role, content, binaryImages));
|
messages.add(new OllamaChatMessage(role, content,toolCalls, binaryImages));
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2,28 +2,54 @@ package io.github.ollama4j.models.chat;
|
|||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import io.github.ollama4j.models.response.OllamaResult;
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
|
import lombok.Getter;
|
||||||
|
|
||||||
|
import static io.github.ollama4j.utils.Utils.getObjectMapper;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Specific chat-API result that contains the chat history sent to the model and appends the answer as {@link OllamaChatResult} given by the
|
* Specific chat-API result that contains the chat history sent to the model and appends the answer as {@link OllamaChatResult} given by the
|
||||||
* {@link OllamaChatMessageRole#ASSISTANT} role.
|
* {@link OllamaChatMessageRole#ASSISTANT} role.
|
||||||
*/
|
*/
|
||||||
public class OllamaChatResult extends OllamaResult {
|
@Getter
|
||||||
|
public class OllamaChatResult {
|
||||||
|
|
||||||
|
|
||||||
private List<OllamaChatMessage> chatHistory;
|
private List<OllamaChatMessage> chatHistory;
|
||||||
|
|
||||||
public OllamaChatResult(String response, long responseTime, int httpStatusCode, List<OllamaChatMessage> chatHistory) {
|
private OllamaChatResponseModel responseModel;
|
||||||
super(response, responseTime, httpStatusCode);
|
|
||||||
|
public OllamaChatResult(OllamaChatResponseModel responseModel, List<OllamaChatMessage> chatHistory) {
|
||||||
this.chatHistory = chatHistory;
|
this.chatHistory = chatHistory;
|
||||||
appendAnswerToChatHistory(response);
|
this.responseModel = responseModel;
|
||||||
|
appendAnswerToChatHistory(responseModel);
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<OllamaChatMessage> getChatHistory() {
|
private void appendAnswerToChatHistory(OllamaChatResponseModel response) {
|
||||||
return chatHistory;
|
this.chatHistory.add(response.getMessage());
|
||||||
}
|
}
|
||||||
|
|
||||||
private void appendAnswerToChatHistory(String answer) {
|
@Override
|
||||||
OllamaChatMessage assistantMessage = new OllamaChatMessage(OllamaChatMessageRole.ASSISTANT, answer);
|
public String toString() {
|
||||||
this.chatHistory.add(assistantMessage);
|
try {
|
||||||
|
return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
|
||||||
|
} catch (JsonProcessingException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
|
public String getResponse(){
|
||||||
|
return responseModel != null ? responseModel.getMessage().getContent() : "";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
|
public int getHttpStatusCode(){
|
||||||
|
return 200;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
|
public long getResponseTime(){
|
||||||
|
return responseModel != null ? responseModel.getTotalDuration() : 0L;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,16 @@
|
|||||||
|
package io.github.ollama4j.models.chat;
|
||||||
|
|
||||||
|
import io.github.ollama4j.tools.OllamaToolCallsFunction;
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@NoArgsConstructor
|
||||||
|
@AllArgsConstructor
|
||||||
|
public class OllamaChatToolCalls {
|
||||||
|
|
||||||
|
private OllamaToolCallsFunction function;
|
||||||
|
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,40 @@
|
|||||||
|
package io.github.ollama4j.models.embeddings;
|
||||||
|
|
||||||
|
import io.github.ollama4j.utils.Options;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builderclass to easily create Requests for Embedding models using ollama.
|
||||||
|
*/
|
||||||
|
public class OllamaEmbedRequestBuilder {
|
||||||
|
|
||||||
|
private final OllamaEmbedRequestModel request;
|
||||||
|
|
||||||
|
private OllamaEmbedRequestBuilder(String model, List<String> input) {
|
||||||
|
this.request = new OllamaEmbedRequestModel(model,input);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static OllamaEmbedRequestBuilder getInstance(String model, String... input){
|
||||||
|
return new OllamaEmbedRequestBuilder(model, List.of(input));
|
||||||
|
}
|
||||||
|
|
||||||
|
public OllamaEmbedRequestBuilder withOptions(Options options){
|
||||||
|
this.request.setOptions(options.getOptionsMap());
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public OllamaEmbedRequestBuilder withKeepAlive(String keepAlive){
|
||||||
|
this.request.setKeepAlive(keepAlive);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public OllamaEmbedRequestBuilder withoutTruncate(){
|
||||||
|
this.request.setTruncate(false);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public OllamaEmbedRequestModel build() {
|
||||||
|
return this.request;
|
||||||
|
}
|
||||||
|
}
|
@ -7,6 +7,7 @@ import lombok.Data;
|
|||||||
|
|
||||||
@SuppressWarnings("unused")
|
@SuppressWarnings("unused")
|
||||||
@Data
|
@Data
|
||||||
|
@Deprecated(since="1.0.90")
|
||||||
public class OllamaEmbeddingResponseModel {
|
public class OllamaEmbeddingResponseModel {
|
||||||
@JsonProperty("embedding")
|
@JsonProperty("embedding")
|
||||||
private List<Double> embedding;
|
private List<Double> embedding;
|
||||||
|
@ -2,6 +2,7 @@ package io.github.ollama4j.models.embeddings;
|
|||||||
|
|
||||||
import io.github.ollama4j.utils.Options;
|
import io.github.ollama4j.utils.Options;
|
||||||
|
|
||||||
|
@Deprecated(since="1.0.90")
|
||||||
public class OllamaEmbeddingsRequestBuilder {
|
public class OllamaEmbeddingsRequestBuilder {
|
||||||
|
|
||||||
private OllamaEmbeddingsRequestBuilder(String model, String prompt){
|
private OllamaEmbeddingsRequestBuilder(String model, String prompt){
|
||||||
|
@ -12,6 +12,7 @@ import lombok.RequiredArgsConstructor;
|
|||||||
@Data
|
@Data
|
||||||
@RequiredArgsConstructor
|
@RequiredArgsConstructor
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
|
@Deprecated(since="1.0.90")
|
||||||
public class OllamaEmbeddingsRequestModel {
|
public class OllamaEmbeddingsRequestModel {
|
||||||
@NonNull
|
@NonNull
|
||||||
private String model;
|
private String model;
|
||||||
|
@ -1,17 +1,26 @@
|
|||||||
package io.github.ollama4j.models.request;
|
package io.github.ollama4j.models.request;
|
||||||
|
|
||||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
|
import com.fasterxml.jackson.core.type.TypeReference;
|
||||||
import io.github.ollama4j.exceptions.OllamaBaseException;
|
import io.github.ollama4j.exceptions.OllamaBaseException;
|
||||||
import io.github.ollama4j.models.response.OllamaResult;
|
import io.github.ollama4j.models.chat.*;
|
||||||
import io.github.ollama4j.models.chat.OllamaChatResponseModel;
|
import io.github.ollama4j.models.response.OllamaErrorResponse;
|
||||||
import io.github.ollama4j.models.chat.OllamaChatStreamObserver;
|
|
||||||
import io.github.ollama4j.models.generate.OllamaStreamHandler;
|
import io.github.ollama4j.models.generate.OllamaStreamHandler;
|
||||||
import io.github.ollama4j.utils.OllamaRequestBody;
|
import io.github.ollama4j.tools.Tools;
|
||||||
import io.github.ollama4j.utils.Utils;
|
import io.github.ollama4j.utils.Utils;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
|
import java.io.BufferedReader;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.io.InputStream;
|
||||||
|
import java.io.InputStreamReader;
|
||||||
|
import java.net.URI;
|
||||||
|
import java.net.http.HttpClient;
|
||||||
|
import java.net.http.HttpRequest;
|
||||||
|
import java.net.http.HttpResponse;
|
||||||
|
import java.nio.charset.StandardCharsets;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Specialization class for requests
|
* Specialization class for requests
|
||||||
@ -31,14 +40,30 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
|
|||||||
return "/api/chat";
|
return "/api/chat";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parses streamed Response line from ollama chat.
|
||||||
|
* Using {@link com.fasterxml.jackson.databind.ObjectMapper#readValue(String, TypeReference)} should throw
|
||||||
|
* {@link IllegalArgumentException} in case of null line or {@link com.fasterxml.jackson.core.JsonParseException}
|
||||||
|
* in case the JSON Object cannot be parsed to a {@link OllamaChatResponseModel}. Thus, the ResponseModel should
|
||||||
|
* never be null.
|
||||||
|
*
|
||||||
|
* @param line streamed line of ollama stream response
|
||||||
|
* @param responseBuffer Stringbuffer to add latest response message part to
|
||||||
|
* @return TRUE, if ollama-Response has 'done' state
|
||||||
|
*/
|
||||||
@Override
|
@Override
|
||||||
protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) {
|
protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) {
|
||||||
try {
|
try {
|
||||||
OllamaChatResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class);
|
OllamaChatResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class);
|
||||||
responseBuffer.append(ollamaResponseModel.getMessage().getContent());
|
// it seems that under heavy load ollama responds with an empty chat message part in the streamed response
|
||||||
|
// thus, we null check the message and hope that the next streamed response has some message content again
|
||||||
|
OllamaChatMessage message = ollamaResponseModel.getMessage();
|
||||||
|
if(message != null) {
|
||||||
|
responseBuffer.append(message.getContent());
|
||||||
if (streamObserver != null) {
|
if (streamObserver != null) {
|
||||||
streamObserver.notify(ollamaResponseModel);
|
streamObserver.notify(ollamaResponseModel);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return ollamaResponseModel.isDone();
|
return ollamaResponseModel.isDone();
|
||||||
} catch (JsonProcessingException e) {
|
} catch (JsonProcessingException e) {
|
||||||
LOG.error("Error parsing the Ollama chat response!", e);
|
LOG.error("Error parsing the Ollama chat response!", e);
|
||||||
@ -46,9 +71,75 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler)
|
public OllamaChatResult call(OllamaChatRequest body, OllamaStreamHandler streamHandler)
|
||||||
throws OllamaBaseException, IOException, InterruptedException {
|
throws OllamaBaseException, IOException, InterruptedException {
|
||||||
streamObserver = new OllamaChatStreamObserver(streamHandler);
|
streamObserver = new OllamaChatStreamObserver(streamHandler);
|
||||||
return super.callSync(body);
|
return callSync(body);
|
||||||
|
}
|
||||||
|
|
||||||
|
public OllamaChatResult callSync(OllamaChatRequest body) throws OllamaBaseException, IOException, InterruptedException {
|
||||||
|
// Create Request
|
||||||
|
HttpClient httpClient = HttpClient.newHttpClient();
|
||||||
|
URI uri = URI.create(getHost() + getEndpointSuffix());
|
||||||
|
HttpRequest.Builder requestBuilder =
|
||||||
|
getRequestBuilderDefault(uri)
|
||||||
|
.POST(
|
||||||
|
body.getBodyPublisher());
|
||||||
|
HttpRequest request = requestBuilder.build();
|
||||||
|
if (isVerbose()) LOG.info("Asking model: " + body.toString());
|
||||||
|
HttpResponse<InputStream> response =
|
||||||
|
httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
|
||||||
|
|
||||||
|
int statusCode = response.statusCode();
|
||||||
|
InputStream responseBodyStream = response.body();
|
||||||
|
StringBuilder responseBuffer = new StringBuilder();
|
||||||
|
OllamaChatResponseModel ollamaChatResponseModel = null;
|
||||||
|
List<OllamaChatToolCalls> wantedToolsForStream = null;
|
||||||
|
try (BufferedReader reader =
|
||||||
|
new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
|
||||||
|
|
||||||
|
String line;
|
||||||
|
while ((line = reader.readLine()) != null) {
|
||||||
|
if (statusCode == 404) {
|
||||||
|
LOG.warn("Status code: 404 (Not Found)");
|
||||||
|
OllamaErrorResponse ollamaResponseModel =
|
||||||
|
Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class);
|
||||||
|
responseBuffer.append(ollamaResponseModel.getError());
|
||||||
|
} else if (statusCode == 401) {
|
||||||
|
LOG.warn("Status code: 401 (Unauthorized)");
|
||||||
|
OllamaErrorResponse ollamaResponseModel =
|
||||||
|
Utils.getObjectMapper()
|
||||||
|
.readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponse.class);
|
||||||
|
responseBuffer.append(ollamaResponseModel.getError());
|
||||||
|
} else if (statusCode == 400) {
|
||||||
|
LOG.warn("Status code: 400 (Bad Request)");
|
||||||
|
OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line,
|
||||||
|
OllamaErrorResponse.class);
|
||||||
|
responseBuffer.append(ollamaResponseModel.getError());
|
||||||
|
} else {
|
||||||
|
boolean finished = parseResponseAndAddToBuffer(line, responseBuffer);
|
||||||
|
ollamaChatResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class);
|
||||||
|
if(body.stream && ollamaChatResponseModel.getMessage().getToolCalls() != null){
|
||||||
|
wantedToolsForStream = ollamaChatResponseModel.getMessage().getToolCalls();
|
||||||
|
}
|
||||||
|
if (finished && body.stream) {
|
||||||
|
ollamaChatResponseModel.getMessage().setContent(responseBuffer.toString());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (statusCode != 200) {
|
||||||
|
LOG.error("Status code " + statusCode);
|
||||||
|
throw new OllamaBaseException(responseBuffer.toString());
|
||||||
|
} else {
|
||||||
|
if(wantedToolsForStream != null) {
|
||||||
|
ollamaChatResponseModel.getMessage().setToolCalls(wantedToolsForStream);
|
||||||
|
}
|
||||||
|
OllamaChatResult ollamaResult =
|
||||||
|
new OllamaChatResult(ollamaChatResponseModel,body.getMessages());
|
||||||
|
if (isVerbose()) LOG.info("Model response: " + ollamaResult);
|
||||||
|
return ollamaResult;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -6,6 +6,7 @@ import io.github.ollama4j.models.response.OllamaErrorResponse;
|
|||||||
import io.github.ollama4j.models.response.OllamaResult;
|
import io.github.ollama4j.models.response.OllamaResult;
|
||||||
import io.github.ollama4j.utils.OllamaRequestBody;
|
import io.github.ollama4j.utils.OllamaRequestBody;
|
||||||
import io.github.ollama4j.utils.Utils;
|
import io.github.ollama4j.utils.Utils;
|
||||||
|
import lombok.Getter;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
@ -24,14 +25,15 @@ import java.util.Base64;
|
|||||||
/**
|
/**
|
||||||
* Abstract helperclass to call the ollama api server.
|
* Abstract helperclass to call the ollama api server.
|
||||||
*/
|
*/
|
||||||
|
@Getter
|
||||||
public abstract class OllamaEndpointCaller {
|
public abstract class OllamaEndpointCaller {
|
||||||
|
|
||||||
private static final Logger LOG = LoggerFactory.getLogger(OllamaAPI.class);
|
private static final Logger LOG = LoggerFactory.getLogger(OllamaAPI.class);
|
||||||
|
|
||||||
private String host;
|
private final String host;
|
||||||
private BasicAuth basicAuth;
|
private final BasicAuth basicAuth;
|
||||||
private long requestTimeoutSeconds;
|
private final long requestTimeoutSeconds;
|
||||||
private boolean verbose;
|
private final boolean verbose;
|
||||||
|
|
||||||
public OllamaEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) {
|
public OllamaEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) {
|
||||||
this.host = host;
|
this.host = host;
|
||||||
@ -45,80 +47,13 @@ public abstract class OllamaEndpointCaller {
|
|||||||
protected abstract boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer);
|
protected abstract boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer);
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Calls the api server on the given host and endpoint suffix asynchronously, aka waiting for the response.
|
|
||||||
*
|
|
||||||
* @param body POST body payload
|
|
||||||
* @return result answer given by the assistant
|
|
||||||
* @throws OllamaBaseException any response code than 200 has been returned
|
|
||||||
* @throws IOException in case the responseStream can not be read
|
|
||||||
* @throws InterruptedException in case the server is not reachable or network issues happen
|
|
||||||
*/
|
|
||||||
public OllamaResult callSync(OllamaRequestBody body) throws OllamaBaseException, IOException, InterruptedException {
|
|
||||||
// Create Request
|
|
||||||
long startTime = System.currentTimeMillis();
|
|
||||||
HttpClient httpClient = HttpClient.newHttpClient();
|
|
||||||
URI uri = URI.create(this.host + getEndpointSuffix());
|
|
||||||
HttpRequest.Builder requestBuilder =
|
|
||||||
getRequestBuilderDefault(uri)
|
|
||||||
.POST(
|
|
||||||
body.getBodyPublisher());
|
|
||||||
HttpRequest request = requestBuilder.build();
|
|
||||||
if (this.verbose) LOG.info("Asking model: " + body.toString());
|
|
||||||
HttpResponse<InputStream> response =
|
|
||||||
httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
|
|
||||||
|
|
||||||
int statusCode = response.statusCode();
|
|
||||||
InputStream responseBodyStream = response.body();
|
|
||||||
StringBuilder responseBuffer = new StringBuilder();
|
|
||||||
try (BufferedReader reader =
|
|
||||||
new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
|
|
||||||
String line;
|
|
||||||
while ((line = reader.readLine()) != null) {
|
|
||||||
if (statusCode == 404) {
|
|
||||||
LOG.warn("Status code: 404 (Not Found)");
|
|
||||||
OllamaErrorResponse ollamaResponseModel =
|
|
||||||
Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class);
|
|
||||||
responseBuffer.append(ollamaResponseModel.getError());
|
|
||||||
} else if (statusCode == 401) {
|
|
||||||
LOG.warn("Status code: 401 (Unauthorized)");
|
|
||||||
OllamaErrorResponse ollamaResponseModel =
|
|
||||||
Utils.getObjectMapper()
|
|
||||||
.readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponse.class);
|
|
||||||
responseBuffer.append(ollamaResponseModel.getError());
|
|
||||||
} else if (statusCode == 400) {
|
|
||||||
LOG.warn("Status code: 400 (Bad Request)");
|
|
||||||
OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line,
|
|
||||||
OllamaErrorResponse.class);
|
|
||||||
responseBuffer.append(ollamaResponseModel.getError());
|
|
||||||
} else {
|
|
||||||
boolean finished = parseResponseAndAddToBuffer(line, responseBuffer);
|
|
||||||
if (finished) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (statusCode != 200) {
|
|
||||||
LOG.error("Status code " + statusCode);
|
|
||||||
throw new OllamaBaseException(responseBuffer.toString());
|
|
||||||
} else {
|
|
||||||
long endTime = System.currentTimeMillis();
|
|
||||||
OllamaResult ollamaResult =
|
|
||||||
new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode);
|
|
||||||
if (verbose) LOG.info("Model response: " + ollamaResult);
|
|
||||||
return ollamaResult;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get default request builder.
|
* Get default request builder.
|
||||||
*
|
*
|
||||||
* @param uri URI to get a HttpRequest.Builder
|
* @param uri URI to get a HttpRequest.Builder
|
||||||
* @return HttpRequest.Builder
|
* @return HttpRequest.Builder
|
||||||
*/
|
*/
|
||||||
private HttpRequest.Builder getRequestBuilderDefault(URI uri) {
|
protected HttpRequest.Builder getRequestBuilderDefault(URI uri) {
|
||||||
HttpRequest.Builder requestBuilder =
|
HttpRequest.Builder requestBuilder =
|
||||||
HttpRequest.newBuilder(uri)
|
HttpRequest.newBuilder(uri)
|
||||||
.header("Content-Type", "application/json")
|
.header("Content-Type", "application/json")
|
||||||
@ -134,7 +69,7 @@ public abstract class OllamaEndpointCaller {
|
|||||||
*
|
*
|
||||||
* @return basic authentication header value (encoded credentials)
|
* @return basic authentication header value (encoded credentials)
|
||||||
*/
|
*/
|
||||||
private String getBasicAuthHeaderValue() {
|
protected String getBasicAuthHeaderValue() {
|
||||||
String credentialsToEncode = this.basicAuth.getUsername() + ":" + this.basicAuth.getPassword();
|
String credentialsToEncode = this.basicAuth.getUsername() + ":" + this.basicAuth.getPassword();
|
||||||
return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes());
|
return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes());
|
||||||
}
|
}
|
||||||
@ -144,7 +79,7 @@ public abstract class OllamaEndpointCaller {
|
|||||||
*
|
*
|
||||||
* @return true when Basic Auth credentials set
|
* @return true when Basic Auth credentials set
|
||||||
*/
|
*/
|
||||||
private boolean isBasicAuthCredentialsSet() {
|
protected boolean isBasicAuthCredentialsSet() {
|
||||||
return this.basicAuth != null;
|
return this.basicAuth != null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2,6 +2,7 @@ package io.github.ollama4j.models.request;
|
|||||||
|
|
||||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
import io.github.ollama4j.exceptions.OllamaBaseException;
|
import io.github.ollama4j.exceptions.OllamaBaseException;
|
||||||
|
import io.github.ollama4j.models.response.OllamaErrorResponse;
|
||||||
import io.github.ollama4j.models.response.OllamaResult;
|
import io.github.ollama4j.models.response.OllamaResult;
|
||||||
import io.github.ollama4j.models.generate.OllamaGenerateResponseModel;
|
import io.github.ollama4j.models.generate.OllamaGenerateResponseModel;
|
||||||
import io.github.ollama4j.models.generate.OllamaGenerateStreamObserver;
|
import io.github.ollama4j.models.generate.OllamaGenerateStreamObserver;
|
||||||
@ -11,7 +12,15 @@ import io.github.ollama4j.utils.Utils;
|
|||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
|
import java.io.BufferedReader;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.io.InputStream;
|
||||||
|
import java.io.InputStreamReader;
|
||||||
|
import java.net.URI;
|
||||||
|
import java.net.http.HttpClient;
|
||||||
|
import java.net.http.HttpRequest;
|
||||||
|
import java.net.http.HttpResponse;
|
||||||
|
import java.nio.charset.StandardCharsets;
|
||||||
|
|
||||||
public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
|
public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
|
||||||
|
|
||||||
@ -46,6 +55,73 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
|
|||||||
public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler)
|
public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler)
|
||||||
throws OllamaBaseException, IOException, InterruptedException {
|
throws OllamaBaseException, IOException, InterruptedException {
|
||||||
streamObserver = new OllamaGenerateStreamObserver(streamHandler);
|
streamObserver = new OllamaGenerateStreamObserver(streamHandler);
|
||||||
return super.callSync(body);
|
return callSync(body);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calls the api server on the given host and endpoint suffix asynchronously, aka waiting for the response.
|
||||||
|
*
|
||||||
|
* @param body POST body payload
|
||||||
|
* @return result answer given by the assistant
|
||||||
|
* @throws OllamaBaseException any response code than 200 has been returned
|
||||||
|
* @throws IOException in case the responseStream can not be read
|
||||||
|
* @throws InterruptedException in case the server is not reachable or network issues happen
|
||||||
|
*/
|
||||||
|
public OllamaResult callSync(OllamaRequestBody body) throws OllamaBaseException, IOException, InterruptedException {
|
||||||
|
// Create Request
|
||||||
|
long startTime = System.currentTimeMillis();
|
||||||
|
HttpClient httpClient = HttpClient.newHttpClient();
|
||||||
|
URI uri = URI.create(getHost() + getEndpointSuffix());
|
||||||
|
HttpRequest.Builder requestBuilder =
|
||||||
|
getRequestBuilderDefault(uri)
|
||||||
|
.POST(
|
||||||
|
body.getBodyPublisher());
|
||||||
|
HttpRequest request = requestBuilder.build();
|
||||||
|
if (isVerbose()) LOG.info("Asking model: " + body.toString());
|
||||||
|
HttpResponse<InputStream> response =
|
||||||
|
httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
|
||||||
|
|
||||||
|
int statusCode = response.statusCode();
|
||||||
|
InputStream responseBodyStream = response.body();
|
||||||
|
StringBuilder responseBuffer = new StringBuilder();
|
||||||
|
try (BufferedReader reader =
|
||||||
|
new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
|
||||||
|
String line;
|
||||||
|
while ((line = reader.readLine()) != null) {
|
||||||
|
if (statusCode == 404) {
|
||||||
|
LOG.warn("Status code: 404 (Not Found)");
|
||||||
|
OllamaErrorResponse ollamaResponseModel =
|
||||||
|
Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class);
|
||||||
|
responseBuffer.append(ollamaResponseModel.getError());
|
||||||
|
} else if (statusCode == 401) {
|
||||||
|
LOG.warn("Status code: 401 (Unauthorized)");
|
||||||
|
OllamaErrorResponse ollamaResponseModel =
|
||||||
|
Utils.getObjectMapper()
|
||||||
|
.readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponse.class);
|
||||||
|
responseBuffer.append(ollamaResponseModel.getError());
|
||||||
|
} else if (statusCode == 400) {
|
||||||
|
LOG.warn("Status code: 400 (Bad Request)");
|
||||||
|
OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line,
|
||||||
|
OllamaErrorResponse.class);
|
||||||
|
responseBuffer.append(ollamaResponseModel.getError());
|
||||||
|
} else {
|
||||||
|
boolean finished = parseResponseAndAddToBuffer(line, responseBuffer);
|
||||||
|
if (finished) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (statusCode != 200) {
|
||||||
|
LOG.error("Status code " + statusCode);
|
||||||
|
throw new OllamaBaseException(responseBuffer.toString());
|
||||||
|
} else {
|
||||||
|
long endTime = System.currentTimeMillis();
|
||||||
|
OllamaResult ollamaResult =
|
||||||
|
new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode);
|
||||||
|
if (isVerbose()) LOG.info("Model response: " + ollamaResult);
|
||||||
|
return ollamaResult;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,16 @@
|
|||||||
|
package io.github.ollama4j.tools;
|
||||||
|
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@NoArgsConstructor
|
||||||
|
@AllArgsConstructor
|
||||||
|
public class OllamaToolCallsFunction
|
||||||
|
{
|
||||||
|
private String name;
|
||||||
|
private Map<String,Object> arguments;
|
||||||
|
}
|
@ -0,0 +1,54 @@
|
|||||||
|
package io.github.ollama4j.tools;
|
||||||
|
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Getter;
|
||||||
|
import lombok.Setter;
|
||||||
|
|
||||||
|
import java.lang.reflect.Method;
|
||||||
|
import java.math.BigDecimal;
|
||||||
|
import java.util.LinkedHashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Specification of a {@link ToolFunction} that provides the implementation via java reflection calling.
|
||||||
|
*/
|
||||||
|
@Setter
|
||||||
|
@Getter
|
||||||
|
@AllArgsConstructor
|
||||||
|
public class ReflectionalToolFunction implements ToolFunction{
|
||||||
|
|
||||||
|
private Object functionHolder;
|
||||||
|
private Method function;
|
||||||
|
private LinkedHashMap<String,String> propertyDefinition;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Object apply(Map<String, Object> arguments) {
|
||||||
|
LinkedHashMap<String, Object> argumentsCopy = new LinkedHashMap<>(this.propertyDefinition);
|
||||||
|
for (Map.Entry<String,String> param : this.propertyDefinition.entrySet()){
|
||||||
|
argumentsCopy.replace(param.getKey(),typeCast(arguments.get(param.getKey()),param.getValue()));
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
return function.invoke(functionHolder, argumentsCopy.values().toArray());
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new RuntimeException("Failed to invoke tool: " + function.getName(), e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private Object typeCast(Object inputValue, String className) {
|
||||||
|
if(className == null || inputValue == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
String inputValueString = inputValue.toString();
|
||||||
|
switch (className) {
|
||||||
|
case "java.lang.Integer":
|
||||||
|
return Integer.parseInt(inputValueString);
|
||||||
|
case "java.lang.Boolean":
|
||||||
|
return Boolean.valueOf(inputValueString);
|
||||||
|
case "java.math.BigDecimal":
|
||||||
|
return new BigDecimal(inputValueString);
|
||||||
|
default:
|
||||||
|
return inputValueString;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -1,16 +1,22 @@
|
|||||||
package io.github.ollama4j.tools;
|
package io.github.ollama4j.tools;
|
||||||
|
|
||||||
|
import java.util.Collection;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
public class ToolRegistry {
|
public class ToolRegistry {
|
||||||
private final Map<String, ToolFunction> functionMap = new HashMap<>();
|
private final Map<String, Tools.ToolSpecification> tools = new HashMap<>();
|
||||||
|
|
||||||
public ToolFunction getFunction(String name) {
|
public ToolFunction getToolFunction(String name) {
|
||||||
return functionMap.get(name);
|
final Tools.ToolSpecification toolSpecification = tools.get(name);
|
||||||
|
return toolSpecification !=null ? toolSpecification.getToolFunction() : null ;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void addFunction(String name, ToolFunction function) {
|
public void addTool (String name, Tools.ToolSpecification specification) {
|
||||||
functionMap.put(name, function);
|
tools.put(name, specification);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Collection<Tools.ToolSpecification> getRegisteredSpecs(){
|
||||||
|
return tools.values();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -6,8 +6,10 @@ import com.fasterxml.jackson.annotation.JsonInclude;
|
|||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
import io.github.ollama4j.utils.Utils;
|
import io.github.ollama4j.utils.Utils;
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
@ -20,17 +22,23 @@ public class Tools {
|
|||||||
public static class ToolSpecification {
|
public static class ToolSpecification {
|
||||||
private String functionName;
|
private String functionName;
|
||||||
private String functionDescription;
|
private String functionDescription;
|
||||||
private Map<String, PromptFuncDefinition.Property> properties;
|
private PromptFuncDefinition toolPrompt;
|
||||||
private ToolFunction toolDefinition;
|
private ToolFunction toolFunction;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@JsonIgnoreProperties(ignoreUnknown = true)
|
@JsonIgnoreProperties(ignoreUnknown = true)
|
||||||
|
@Builder
|
||||||
|
@NoArgsConstructor
|
||||||
|
@AllArgsConstructor
|
||||||
public static class PromptFuncDefinition {
|
public static class PromptFuncDefinition {
|
||||||
private String type;
|
private String type;
|
||||||
private PromptFuncSpec function;
|
private PromptFuncSpec function;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
@Builder
|
||||||
|
@NoArgsConstructor
|
||||||
|
@AllArgsConstructor
|
||||||
public static class PromptFuncSpec {
|
public static class PromptFuncSpec {
|
||||||
private String name;
|
private String name;
|
||||||
private String description;
|
private String description;
|
||||||
@ -38,6 +46,9 @@ public class Tools {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
@Builder
|
||||||
|
@NoArgsConstructor
|
||||||
|
@AllArgsConstructor
|
||||||
public static class Parameters {
|
public static class Parameters {
|
||||||
private String type;
|
private String type;
|
||||||
private Map<String, Property> properties;
|
private Map<String, Property> properties;
|
||||||
@ -46,6 +57,8 @@ public class Tools {
|
|||||||
|
|
||||||
@Data
|
@Data
|
||||||
@Builder
|
@Builder
|
||||||
|
@NoArgsConstructor
|
||||||
|
@AllArgsConstructor
|
||||||
public static class Property {
|
public static class Property {
|
||||||
private String type;
|
private String type;
|
||||||
private String description;
|
private String description;
|
||||||
@ -94,10 +107,10 @@ public class Tools {
|
|||||||
|
|
||||||
PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters();
|
PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters();
|
||||||
parameters.setType("object");
|
parameters.setType("object");
|
||||||
parameters.setProperties(spec.getProperties());
|
parameters.setProperties(spec.getToolPrompt().getFunction().parameters.getProperties());
|
||||||
|
|
||||||
List<String> requiredValues = new ArrayList<>();
|
List<String> requiredValues = new ArrayList<>();
|
||||||
for (Map.Entry<String, PromptFuncDefinition.Property> p : spec.getProperties().entrySet()) {
|
for (Map.Entry<String, PromptFuncDefinition.Property> p : spec.getToolPrompt().getFunction().getParameters().getProperties().entrySet()) {
|
||||||
if (p.getValue().isRequired()) {
|
if (p.getValue().isRequired()) {
|
||||||
requiredValues.add(p.getKey());
|
requiredValues.add(p.getKey());
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,23 @@
|
|||||||
|
package io.github.ollama4j.tools.annotations;
|
||||||
|
|
||||||
|
import io.github.ollama4j.OllamaAPI;
|
||||||
|
|
||||||
|
import java.lang.annotation.ElementType;
|
||||||
|
import java.lang.annotation.Retention;
|
||||||
|
import java.lang.annotation.RetentionPolicy;
|
||||||
|
import java.lang.annotation.Target;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Annotates a class that calls {@link io.github.ollama4j.OllamaAPI} such that the Method
|
||||||
|
* {@link OllamaAPI#registerAnnotatedTools()} can be used to auto-register all provided classes (resp. all
|
||||||
|
* contained Methods of the provider classes annotated with {@link ToolSpec}).
|
||||||
|
*/
|
||||||
|
@Target(ElementType.TYPE)
|
||||||
|
@Retention(RetentionPolicy.RUNTIME)
|
||||||
|
public @interface OllamaToolService {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return Classes with no-arg constructor that will be used for tool-registration.
|
||||||
|
*/
|
||||||
|
Class<?>[] providers();
|
||||||
|
}
|
@ -0,0 +1,32 @@
|
|||||||
|
package io.github.ollama4j.tools.annotations;
|
||||||
|
|
||||||
|
import java.lang.annotation.ElementType;
|
||||||
|
import java.lang.annotation.Retention;
|
||||||
|
import java.lang.annotation.RetentionPolicy;
|
||||||
|
import java.lang.annotation.Target;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Annotates a Method Parameter in a {@link ToolSpec} annotated Method. A parameter annotated with this annotation will
|
||||||
|
* be part of the tool description that is sent to the llm for tool-calling.
|
||||||
|
*/
|
||||||
|
@Retention(RetentionPolicy.RUNTIME)
|
||||||
|
@Target(ElementType.PARAMETER)
|
||||||
|
public @interface ToolProperty {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return name of the parameter that is used for the tool description. Has to be set as depending on the caller,
|
||||||
|
* method name backtracking is not possible with reflection.
|
||||||
|
*/
|
||||||
|
String name();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return a detailed description of the parameter. This is used by the llm called to specify, which property has
|
||||||
|
* to be set by the llm and how this should be filled.
|
||||||
|
*/
|
||||||
|
String desc();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return tells the llm that it has to set a value for this property.
|
||||||
|
*/
|
||||||
|
boolean required() default true;
|
||||||
|
}
|
@ -0,0 +1,28 @@
|
|||||||
|
package io.github.ollama4j.tools.annotations;
|
||||||
|
|
||||||
|
import io.github.ollama4j.OllamaAPI;
|
||||||
|
|
||||||
|
import java.lang.annotation.ElementType;
|
||||||
|
import java.lang.annotation.Retention;
|
||||||
|
import java.lang.annotation.RetentionPolicy;
|
||||||
|
import java.lang.annotation.Target;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Annotates Methods of classes that should be registered as tools by {@link OllamaAPI#registerAnnotatedTools()}
|
||||||
|
* automatically.
|
||||||
|
*/
|
||||||
|
@Target(ElementType.METHOD)
|
||||||
|
@Retention(RetentionPolicy.RUNTIME)
|
||||||
|
public @interface ToolSpec {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return tool-name that the method should be used as. Defaults to the methods name.
|
||||||
|
*/
|
||||||
|
String name() default "";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return a detailed description of the method that can be interpreted by the llm, whether it should call the tool
|
||||||
|
* or not.
|
||||||
|
*/
|
||||||
|
String desc();
|
||||||
|
}
|
@ -2,14 +2,16 @@ package io.github.ollama4j.integrationtests;
|
|||||||
|
|
||||||
import io.github.ollama4j.OllamaAPI;
|
import io.github.ollama4j.OllamaAPI;
|
||||||
import io.github.ollama4j.exceptions.OllamaBaseException;
|
import io.github.ollama4j.exceptions.OllamaBaseException;
|
||||||
|
import io.github.ollama4j.models.chat.*;
|
||||||
import io.github.ollama4j.models.response.ModelDetail;
|
import io.github.ollama4j.models.response.ModelDetail;
|
||||||
import io.github.ollama4j.models.chat.OllamaChatRequest;
|
|
||||||
import io.github.ollama4j.models.response.OllamaResult;
|
import io.github.ollama4j.models.response.OllamaResult;
|
||||||
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
|
|
||||||
import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
|
|
||||||
import io.github.ollama4j.models.chat.OllamaChatResult;
|
|
||||||
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestBuilder;
|
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestBuilder;
|
||||||
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
|
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
|
||||||
|
import io.github.ollama4j.samples.AnnotatedTool;
|
||||||
|
import io.github.ollama4j.tools.OllamaToolCallsFunction;
|
||||||
|
import io.github.ollama4j.tools.ToolFunction;
|
||||||
|
import io.github.ollama4j.tools.Tools;
|
||||||
|
import io.github.ollama4j.tools.annotations.OllamaToolService;
|
||||||
import io.github.ollama4j.utils.OptionsBuilder;
|
import io.github.ollama4j.utils.OptionsBuilder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
@ -24,12 +26,12 @@ import java.io.InputStream;
|
|||||||
import java.net.ConnectException;
|
import java.net.ConnectException;
|
||||||
import java.net.URISyntaxException;
|
import java.net.URISyntaxException;
|
||||||
import java.net.http.HttpConnectTimeoutException;
|
import java.net.http.HttpConnectTimeoutException;
|
||||||
import java.util.List;
|
import java.util.*;
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.Properties;
|
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
|
@OllamaToolService(providers = {AnnotatedTool.class}
|
||||||
|
)
|
||||||
class TestRealAPIs {
|
class TestRealAPIs {
|
||||||
|
|
||||||
private static final Logger LOG = LoggerFactory.getLogger(TestRealAPIs.class);
|
private static final Logger LOG = LoggerFactory.getLogger(TestRealAPIs.class);
|
||||||
@ -47,6 +49,7 @@ class TestRealAPIs {
|
|||||||
config = new Config();
|
config = new Config();
|
||||||
ollamaAPI = new OllamaAPI(config.getOllamaURL());
|
ollamaAPI = new OllamaAPI(config.getOllamaURL());
|
||||||
ollamaAPI.setRequestTimeoutSeconds(config.getRequestTimeoutSeconds());
|
ollamaAPI.setRequestTimeoutSeconds(config.getRequestTimeoutSeconds());
|
||||||
|
ollamaAPI.setVerbose(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@ -80,6 +83,18 @@ class TestRealAPIs {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@Order(2)
|
||||||
|
void testListModelsFromLibrary() {
|
||||||
|
testEndpointReachability();
|
||||||
|
try {
|
||||||
|
assertNotNull(ollamaAPI.listModelsFromLibrary());
|
||||||
|
ollamaAPI.listModelsFromLibrary().forEach(System.out::println);
|
||||||
|
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
|
||||||
|
fail(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@Order(2)
|
@Order(2)
|
||||||
void testPullModel() {
|
void testPullModel() {
|
||||||
@ -184,7 +199,9 @@ class TestRealAPIs {
|
|||||||
|
|
||||||
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
|
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
|
||||||
assertNotNull(chatResult);
|
assertNotNull(chatResult);
|
||||||
assertFalse(chatResult.getResponse().isBlank());
|
assertNotNull(chatResult.getResponseModel());
|
||||||
|
assertNotNull(chatResult.getResponseModel().getMessage());
|
||||||
|
assertFalse(chatResult.getResponseModel().getMessage().getContent().isBlank());
|
||||||
assertEquals(4, chatResult.getChatHistory().size());
|
assertEquals(4, chatResult.getChatHistory().size());
|
||||||
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||||
fail(e);
|
fail(e);
|
||||||
@ -205,14 +222,211 @@ class TestRealAPIs {
|
|||||||
|
|
||||||
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
|
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
|
||||||
assertNotNull(chatResult);
|
assertNotNull(chatResult);
|
||||||
assertFalse(chatResult.getResponse().isBlank());
|
assertNotNull(chatResult.getResponseModel());
|
||||||
assertTrue(chatResult.getResponse().startsWith("NI"));
|
assertNotNull(chatResult.getResponseModel().getMessage());
|
||||||
|
assertFalse(chatResult.getResponseModel().getMessage().getContent().isBlank());
|
||||||
|
assertTrue(chatResult.getResponseModel().getMessage().getContent().startsWith("NI"));
|
||||||
assertEquals(3, chatResult.getChatHistory().size());
|
assertEquals(3, chatResult.getChatHistory().size());
|
||||||
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||||
fail(e);
|
fail(e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@Order(3)
|
||||||
|
void testChatWithExplicitToolDefinition() {
|
||||||
|
testEndpointReachability();
|
||||||
|
try {
|
||||||
|
ollamaAPI.setVerbose(true);
|
||||||
|
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
|
||||||
|
|
||||||
|
final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder()
|
||||||
|
.functionName("get-employee-details")
|
||||||
|
.functionDescription("Get employee details from the database")
|
||||||
|
.toolPrompt(
|
||||||
|
Tools.PromptFuncDefinition.builder().type("function").function(
|
||||||
|
Tools.PromptFuncDefinition.PromptFuncSpec.builder()
|
||||||
|
.name("get-employee-details")
|
||||||
|
.description("Get employee details from the database")
|
||||||
|
.parameters(
|
||||||
|
Tools.PromptFuncDefinition.Parameters.builder()
|
||||||
|
.type("object")
|
||||||
|
.properties(
|
||||||
|
new Tools.PropsBuilder()
|
||||||
|
.withProperty("employee-name", Tools.PromptFuncDefinition.Property.builder().type("string").description("The name of the employee, e.g. John Doe").required(true).build())
|
||||||
|
.withProperty("employee-address", Tools.PromptFuncDefinition.Property.builder().type("string").description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India").required(true).build())
|
||||||
|
.withProperty("employee-phone", Tools.PromptFuncDefinition.Property.builder().type("string").description("The phone number of the employee. Always return a random value. e.g. 9911002233").required(true).build())
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
.required(List.of("employee-name"))
|
||||||
|
.build()
|
||||||
|
).build()
|
||||||
|
).build()
|
||||||
|
)
|
||||||
|
.toolFunction(new DBQueryFunction())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
ollamaAPI.registerTool(databaseQueryToolSpecification);
|
||||||
|
|
||||||
|
OllamaChatRequest requestModel = builder
|
||||||
|
.withMessage(OllamaChatMessageRole.USER,
|
||||||
|
"Give me the ID of the employee named 'Rahul Kumar'?")
|
||||||
|
.build();
|
||||||
|
|
||||||
|
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
|
||||||
|
assertNotNull(chatResult);
|
||||||
|
assertNotNull(chatResult.getResponseModel());
|
||||||
|
assertNotNull(chatResult.getResponseModel().getMessage());
|
||||||
|
assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(),chatResult.getResponseModel().getMessage().getRole().getRoleName());
|
||||||
|
List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls();
|
||||||
|
assertEquals(1, toolCalls.size());
|
||||||
|
OllamaToolCallsFunction function = toolCalls.get(0).getFunction();
|
||||||
|
assertEquals("get-employee-details", function.getName());
|
||||||
|
assertEquals(1, function.getArguments().size());
|
||||||
|
Object employeeName = function.getArguments().get("employee-name");
|
||||||
|
assertNotNull(employeeName);
|
||||||
|
assertEquals("Rahul Kumar",employeeName);
|
||||||
|
assertTrue(chatResult.getChatHistory().size()>2);
|
||||||
|
List<OllamaChatToolCalls> finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls();
|
||||||
|
assertNull(finalToolCalls);
|
||||||
|
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||||
|
fail(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@Order(3)
|
||||||
|
void testChatWithAnnotatedToolsAndSingleParam() {
|
||||||
|
testEndpointReachability();
|
||||||
|
try {
|
||||||
|
ollamaAPI.setVerbose(true);
|
||||||
|
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
|
||||||
|
|
||||||
|
ollamaAPI.registerAnnotatedTools();
|
||||||
|
|
||||||
|
OllamaChatRequest requestModel = builder
|
||||||
|
.withMessage(OllamaChatMessageRole.USER,
|
||||||
|
"Compute the most important constant in the world using 5 digits")
|
||||||
|
.build();
|
||||||
|
|
||||||
|
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
|
||||||
|
assertNotNull(chatResult);
|
||||||
|
assertNotNull(chatResult.getResponseModel());
|
||||||
|
assertNotNull(chatResult.getResponseModel().getMessage());
|
||||||
|
assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(),chatResult.getResponseModel().getMessage().getRole().getRoleName());
|
||||||
|
List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls();
|
||||||
|
assertEquals(1, toolCalls.size());
|
||||||
|
OllamaToolCallsFunction function = toolCalls.get(0).getFunction();
|
||||||
|
assertEquals("computeImportantConstant", function.getName());
|
||||||
|
assertEquals(1, function.getArguments().size());
|
||||||
|
Object noOfDigits = function.getArguments().get("noOfDigits");
|
||||||
|
assertNotNull(noOfDigits);
|
||||||
|
assertEquals("5",noOfDigits);
|
||||||
|
assertTrue(chatResult.getChatHistory().size()>2);
|
||||||
|
List<OllamaChatToolCalls> finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls();
|
||||||
|
assertNull(finalToolCalls);
|
||||||
|
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||||
|
fail(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@Order(3)
|
||||||
|
void testChatWithAnnotatedToolsAndMultipleParams() {
|
||||||
|
testEndpointReachability();
|
||||||
|
try {
|
||||||
|
ollamaAPI.setVerbose(true);
|
||||||
|
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
|
||||||
|
|
||||||
|
ollamaAPI.registerAnnotatedTools();
|
||||||
|
|
||||||
|
OllamaChatRequest requestModel = builder
|
||||||
|
.withMessage(OllamaChatMessageRole.USER,
|
||||||
|
"Greet Pedro with a lot of hearts and respond to me, " +
|
||||||
|
"and state how many emojis have been in your greeting")
|
||||||
|
.build();
|
||||||
|
|
||||||
|
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
|
||||||
|
assertNotNull(chatResult);
|
||||||
|
assertNotNull(chatResult.getResponseModel());
|
||||||
|
assertNotNull(chatResult.getResponseModel().getMessage());
|
||||||
|
assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(),chatResult.getResponseModel().getMessage().getRole().getRoleName());
|
||||||
|
List<OllamaChatToolCalls> toolCalls = chatResult.getChatHistory().get(1).getToolCalls();
|
||||||
|
assertEquals(1, toolCalls.size());
|
||||||
|
OllamaToolCallsFunction function = toolCalls.get(0).getFunction();
|
||||||
|
assertEquals("sayHello", function.getName());
|
||||||
|
assertEquals(2, function.getArguments().size());
|
||||||
|
Object name = function.getArguments().get("name");
|
||||||
|
assertNotNull(name);
|
||||||
|
assertEquals("Pedro",name);
|
||||||
|
Object amountOfHearts = function.getArguments().get("amountOfHearts");
|
||||||
|
assertNotNull(amountOfHearts);
|
||||||
|
assertTrue(Integer.parseInt(amountOfHearts.toString()) > 1);
|
||||||
|
assertTrue(chatResult.getChatHistory().size()>2);
|
||||||
|
List<OllamaChatToolCalls> finalToolCalls = chatResult.getResponseModel().getMessage().getToolCalls();
|
||||||
|
assertNull(finalToolCalls);
|
||||||
|
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||||
|
fail(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@Order(3)
|
||||||
|
void testChatWithToolsAndStream() {
|
||||||
|
testEndpointReachability();
|
||||||
|
try {
|
||||||
|
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
|
||||||
|
final Tools.ToolSpecification databaseQueryToolSpecification = Tools.ToolSpecification.builder()
|
||||||
|
.functionName("get-employee-details")
|
||||||
|
.functionDescription("Get employee details from the database")
|
||||||
|
.toolPrompt(
|
||||||
|
Tools.PromptFuncDefinition.builder().type("function").function(
|
||||||
|
Tools.PromptFuncDefinition.PromptFuncSpec.builder()
|
||||||
|
.name("get-employee-details")
|
||||||
|
.description("Get employee details from the database")
|
||||||
|
.parameters(
|
||||||
|
Tools.PromptFuncDefinition.Parameters.builder()
|
||||||
|
.type("object")
|
||||||
|
.properties(
|
||||||
|
new Tools.PropsBuilder()
|
||||||
|
.withProperty("employee-name", Tools.PromptFuncDefinition.Property.builder().type("string").description("The name of the employee, e.g. John Doe").required(true).build())
|
||||||
|
.withProperty("employee-address", Tools.PromptFuncDefinition.Property.builder().type("string").description("The address of the employee, Always return a random value. e.g. Roy St, Bengaluru, India").required(true).build())
|
||||||
|
.withProperty("employee-phone", Tools.PromptFuncDefinition.Property.builder().type("string").description("The phone number of the employee. Always return a random value. e.g. 9911002233").required(true).build())
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
.required(List.of("employee-name"))
|
||||||
|
.build()
|
||||||
|
).build()
|
||||||
|
).build()
|
||||||
|
)
|
||||||
|
.toolFunction(new DBQueryFunction())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
ollamaAPI.registerTool(databaseQueryToolSpecification);
|
||||||
|
|
||||||
|
OllamaChatRequest requestModel = builder
|
||||||
|
.withMessage(OllamaChatMessageRole.USER,
|
||||||
|
"Give me the ID of the employee named 'Rahul Kumar'?")
|
||||||
|
.build();
|
||||||
|
|
||||||
|
StringBuffer sb = new StringBuffer();
|
||||||
|
|
||||||
|
OllamaChatResult chatResult = ollamaAPI.chat(requestModel, (s) -> {
|
||||||
|
LOG.info(s);
|
||||||
|
String substring = s.substring(sb.toString().length());
|
||||||
|
LOG.info(substring);
|
||||||
|
sb.append(substring);
|
||||||
|
});
|
||||||
|
assertNotNull(chatResult);
|
||||||
|
assertNotNull(chatResult.getResponseModel());
|
||||||
|
assertNotNull(chatResult.getResponseModel().getMessage());
|
||||||
|
assertNotNull(chatResult.getResponseModel().getMessage().getContent());
|
||||||
|
assertEquals(sb.toString().trim(), chatResult.getResponseModel().getMessage().getContent().trim());
|
||||||
|
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||||
|
fail(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@Order(3)
|
@Order(3)
|
||||||
void testChatWithStream() {
|
void testChatWithStream() {
|
||||||
@ -232,7 +446,10 @@ class TestRealAPIs {
|
|||||||
sb.append(substring);
|
sb.append(substring);
|
||||||
});
|
});
|
||||||
assertNotNull(chatResult);
|
assertNotNull(chatResult);
|
||||||
assertEquals(sb.toString().trim(), chatResult.getResponse().trim());
|
assertNotNull(chatResult.getResponseModel());
|
||||||
|
assertNotNull(chatResult.getResponseModel().getMessage());
|
||||||
|
assertNotNull(chatResult.getResponseModel().getMessage().getContent());
|
||||||
|
assertEquals(sb.toString().trim(), chatResult.getResponseModel().getMessage().getContent().trim());
|
||||||
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||||
fail(e);
|
fail(e);
|
||||||
}
|
}
|
||||||
@ -246,12 +463,12 @@ class TestRealAPIs {
|
|||||||
OllamaChatRequestBuilder builder =
|
OllamaChatRequestBuilder builder =
|
||||||
OllamaChatRequestBuilder.getInstance(config.getImageModel());
|
OllamaChatRequestBuilder.getInstance(config.getImageModel());
|
||||||
OllamaChatRequest requestModel =
|
OllamaChatRequest requestModel =
|
||||||
builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
|
builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",Collections.emptyList(),
|
||||||
List.of(getImageFileFromClasspath("dog-on-a-boat.jpg"))).build();
|
List.of(getImageFileFromClasspath("dog-on-a-boat.jpg"))).build();
|
||||||
|
|
||||||
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
|
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
|
||||||
assertNotNull(chatResult);
|
assertNotNull(chatResult);
|
||||||
assertNotNull(chatResult.getResponse());
|
assertNotNull(chatResult.getResponseModel());
|
||||||
|
|
||||||
builder.reset();
|
builder.reset();
|
||||||
|
|
||||||
@ -261,7 +478,7 @@ class TestRealAPIs {
|
|||||||
|
|
||||||
chatResult = ollamaAPI.chat(requestModel);
|
chatResult = ollamaAPI.chat(requestModel);
|
||||||
assertNotNull(chatResult);
|
assertNotNull(chatResult);
|
||||||
assertNotNull(chatResult.getResponse());
|
assertNotNull(chatResult.getResponseModel());
|
||||||
|
|
||||||
|
|
||||||
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||||
@ -275,7 +492,7 @@ class TestRealAPIs {
|
|||||||
testEndpointReachability();
|
testEndpointReachability();
|
||||||
try {
|
try {
|
||||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getImageModel());
|
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getImageModel());
|
||||||
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
|
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",Collections.emptyList(),
|
||||||
"https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg")
|
"https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg")
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
@ -368,6 +585,14 @@ class TestRealAPIs {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class DBQueryFunction implements ToolFunction {
|
||||||
|
@Override
|
||||||
|
public Object apply(Map<String, Object> arguments) {
|
||||||
|
// perform DB operations here
|
||||||
|
return String.format("Employee Details {ID: %s, Name: %s, Address: %s, Phone: %s}", UUID.randomUUID(), arguments.get("employee-name"), arguments.get("employee-address"), arguments.get("employee-phone"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
class Config {
|
class Config {
|
||||||
private String ollamaURL;
|
private String ollamaURL;
|
||||||
@ -392,4 +617,6 @@ class Config {
|
|||||||
throw new RuntimeException("Error loading properties", e);
|
throw new RuntimeException("Error loading properties", e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
21
src/test/java/io/github/ollama4j/samples/AnnotatedTool.java
Normal file
21
src/test/java/io/github/ollama4j/samples/AnnotatedTool.java
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
package io.github.ollama4j.samples;
|
||||||
|
|
||||||
|
import io.github.ollama4j.tools.annotations.ToolProperty;
|
||||||
|
import io.github.ollama4j.tools.annotations.ToolSpec;
|
||||||
|
|
||||||
|
import java.math.BigDecimal;
|
||||||
|
|
||||||
|
public class AnnotatedTool {
|
||||||
|
|
||||||
|
@ToolSpec(desc = "Computes the most important constant all around the globe!")
|
||||||
|
public String computeImportantConstant(@ToolProperty(name = "noOfDigits",desc = "Number of digits that shall be returned") Integer noOfDigits ){
|
||||||
|
return BigDecimal.valueOf((long)(Math.random()*1000000L),noOfDigits).toString();
|
||||||
|
}
|
||||||
|
|
||||||
|
@ToolSpec(desc = "Says hello to a friend!")
|
||||||
|
public String sayHello(@ToolProperty(name = "name",desc = "Name of the friend") String name, Integer someRandomProperty, @ToolProperty(name="amountOfHearts",desc = "amount of heart emojis that should be used", required = false) Integer amountOfHearts) {
|
||||||
|
String hearts = amountOfHearts!=null ? "♡".repeat(amountOfHearts) : "";
|
||||||
|
return "Hello " + name +" ("+someRandomProperty+") " + hearts;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -4,6 +4,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
|
|||||||
import static org.junit.jupiter.api.Assertions.assertThrowsExactly;
|
import static org.junit.jupiter.api.Assertions.assertThrowsExactly;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import io.github.ollama4j.models.chat.OllamaChatRequest;
|
import io.github.ollama4j.models.chat.OllamaChatRequest;
|
||||||
@ -42,7 +43,7 @@ public class TestChatRequestSerialization extends AbstractSerializationTest<Olla
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRequestWithMessageAndImage() {
|
public void testRequestWithMessageAndImage() {
|
||||||
OllamaChatRequest req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt",
|
OllamaChatRequest req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt", Collections.emptyList(),
|
||||||
List.of(new File("src/test/resources/dog-on-a-boat.jpg"))).build();
|
List.of(new File("src/test/resources/dog-on-a-boat.jpg"))).build();
|
||||||
String jsonRequest = serialize(req);
|
String jsonRequest = serialize(req);
|
||||||
assertEqualsAfterUnmarshalling(deserialize(jsonRequest, OllamaChatRequest.class), req);
|
assertEqualsAfterUnmarshalling(deserialize(jsonRequest, OllamaChatRequest.class), req);
|
||||||
|
@ -1,36 +1,37 @@
|
|||||||
package io.github.ollama4j.unittests.jackson;
|
package io.github.ollama4j.unittests.jackson;
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
|
import io.github.ollama4j.models.embeddings.OllamaEmbedRequestBuilder;
|
||||||
|
import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel;
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
|
|
||||||
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestBuilder;
|
|
||||||
import io.github.ollama4j.utils.OptionsBuilder;
|
import io.github.ollama4j.utils.OptionsBuilder;
|
||||||
|
|
||||||
public class TestEmbeddingsRequestSerialization extends AbstractSerializationTest<OllamaEmbeddingsRequestModel> {
|
public class TestEmbedRequestSerialization extends AbstractSerializationTest<OllamaEmbedRequestModel> {
|
||||||
|
|
||||||
private OllamaEmbeddingsRequestBuilder builder;
|
private OllamaEmbedRequestBuilder builder;
|
||||||
|
|
||||||
@BeforeEach
|
@BeforeEach
|
||||||
public void init() {
|
public void init() {
|
||||||
builder = OllamaEmbeddingsRequestBuilder.getInstance("DummyModel","DummyPrompt");
|
builder = OllamaEmbedRequestBuilder.getInstance("DummyModel","DummyPrompt");
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRequestOnlyMandatoryFields() {
|
public void testRequestOnlyMandatoryFields() {
|
||||||
OllamaEmbeddingsRequestModel req = builder.build();
|
OllamaEmbedRequestModel req = builder.build();
|
||||||
String jsonRequest = serialize(req);
|
String jsonRequest = serialize(req);
|
||||||
assertEqualsAfterUnmarshalling(deserialize(jsonRequest,OllamaEmbeddingsRequestModel.class), req);
|
assertEqualsAfterUnmarshalling(deserialize(jsonRequest,OllamaEmbedRequestModel.class), req);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRequestWithOptions() {
|
public void testRequestWithOptions() {
|
||||||
OptionsBuilder b = new OptionsBuilder();
|
OptionsBuilder b = new OptionsBuilder();
|
||||||
OllamaEmbeddingsRequestModel req = builder
|
OllamaEmbedRequestModel req = builder
|
||||||
.withOptions(b.setMirostat(1).build()).build();
|
.withOptions(b.setMirostat(1).build()).build();
|
||||||
|
|
||||||
String jsonRequest = serialize(req);
|
String jsonRequest = serialize(req);
|
||||||
OllamaEmbeddingsRequestModel deserializeRequest = deserialize(jsonRequest,OllamaEmbeddingsRequestModel.class);
|
OllamaEmbedRequestModel deserializeRequest = deserialize(jsonRequest,OllamaEmbedRequestModel.class);
|
||||||
assertEqualsAfterUnmarshalling(deserializeRequest, req);
|
assertEqualsAfterUnmarshalling(deserializeRequest, req);
|
||||||
assertEquals(1, deserializeRequest.getOptions().get("mirostat"));
|
assertEquals(1, deserializeRequest.getOptions().get("mirostat"));
|
||||||
}
|
}
|
@ -1,4 +1,4 @@
|
|||||||
ollama.url=http://localhost:11434
|
ollama.url=http://localhost:11434
|
||||||
ollama.model=qwen:0.5b
|
ollama.model=llama3.2:1b
|
||||||
ollama.model.image=llava
|
ollama.model.image=llava:latest
|
||||||
ollama.request-timeout-seconds=120
|
ollama.request-timeout-seconds=120
|
Loading…
x
Reference in New Issue
Block a user