Added support for tools/function calling - specifically for Mistral's latest model.

This commit is contained in:
Amith Koujalgi 2024-07-12 17:06:41 +05:30
parent 8ef6fac28e
commit 91ee6cb4c1
15 changed files with 1006 additions and 490 deletions

View File

@ -1,5 +1,5 @@
--- ---
sidebar_position: 2 sidebar_position: 3
--- ---
# Generate - Async # Generate - Async

View File

@ -1,5 +1,5 @@
--- ---
sidebar_position: 3 sidebar_position: 4
--- ---
# Generate - With Image Files # Generate - With Image Files

View File

@ -1,5 +1,5 @@
--- ---
sidebar_position: 4 sidebar_position: 5
--- ---
# Generate - With Image URLs # Generate - With Image URLs

View File

@ -0,0 +1,271 @@
---
sidebar_position: 2
---
# Generate - With Tools
This API lets you perform [function calling](https://docs.mistral.ai/capabilities/function_calling/) using LLMs in a
synchronous way.
This API correlates to
the [generate](https://github.com/ollama/ollama/blob/main/docs/api.md#request-raw-mode) API with `raw` mode.
:::note
This is an only an experimental implementation and has a very basic design.
Currently, built and tested for [Mistral's latest model](https://ollama.com/library/mistral) only. We could redesign
this
in the future if tooling is supported for more models with a generic interaction standard from Ollama.
:::
### Function Calling/Tools
Assume you want to call a method in your code based on the response generated from the model.
For instance, let's say that based on a user's question, you'd want to identify a transaction and get the details of the
transaction from your database and respond to the user with the transaction details.
You could do that with ease with the `function calling` capabilities of the models by registering your `tools`.
### Create Functions
This function takes the arguments `location` and `fuelType` and performs an operation with these arguments and returns a
value.
```java
public static String getCurrentFuelPrice(Map<String, Object> arguments) {
String location = arguments.get("location").toString();
String fuelType = arguments.get("fuelType").toString();
return "Current price of " + fuelType + " in " + location + " is Rs.103/L";
}
```
This function takes the argument `city` and performs an operation with the argument and returns a
value.
```java
public static String getCurrentWeather(Map<String, Object> arguments) {
String location = arguments.get("city").toString();
return "Currently " + location + "'s weather is nice.";
}
```
### Define Tool Specifications
Lets define a sample tool specification called **Fuel Price Tool** for getting the current fuel price.
- Specify the function `name`, `description`, and `required` properties (`location` and `fuelType`).
- Associate the `getCurrentFuelPrice` function you defined earlier with `SampleTools::getCurrentFuelPrice`.
```java
MistralTools.ToolSpecification fuelPriceToolSpecification = MistralTools.ToolSpecification.builder()
.functionName("current-fuel-price")
.functionDesc("Get current fuel price")
.props(
new MistralTools.PropsBuilder()
.withProperty("location", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The city, e.g. New Delhi, India").required(true).build())
.withProperty("fuelType", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The fuel type.").enumValues(Arrays.asList("petrol", "diesel")).required(true).build())
.build()
)
.toolDefinition(SampleTools::getCurrentFuelPrice)
.build();
```
Lets also define a sample tool specification called **Weather Tool** for getting the current weather.
- Specify the function `name`, `description`, and `required` property (`city`).
- Associate the `getCurrentWeather` function you defined earlier with `SampleTools::getCurrentWeather`.
```java
MistralTools.ToolSpecification weatherToolSpecification = MistralTools.ToolSpecification.builder()
.functionName("current-weather")
.functionDesc("Get current weather")
.props(
new MistralTools.PropsBuilder()
.withProperty("city", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The city, e.g. New Delhi, India").required(true).build())
.build()
)
.toolDefinition(SampleTools::getCurrentWeather)
.build();
```
### Register the Tools
Register the defined tools (`fuel price` and `weather`) with the OllamaAPI.
```shell
ollamaAPI.registerTool(fuelPriceToolSpecification);
ollamaAPI.registerTool(weatherToolSpecification);
```
### Create prompt with Tools
`Prompt 1`: Create a prompt asking for the petrol price in Bengaluru using the defined fuel price and weather tools.
```shell
String prompt1 = new MistralTools.PromptBuilder()
.withToolSpecification(fuelPriceToolSpecification)
.withToolSpecification(weatherToolSpecification)
.withPrompt("What is the petrol price in Bengaluru?")
.build();
OllamaToolsResult toolsResult = ollamaAPI.generateWithTools(model, prompt1, false, new OptionsBuilder().build());
for (Map.Entry<ToolDef, Object> r : toolsResult.getToolResults().entrySet()) {
System.out.printf("[Response from tool '%s']: %s%n", r.getKey().getName(), r.getValue().toString());
}
```
Now, fire away your question to the model.
You will get a response similar to:
::::tip[LLM Response]
[Response from tool 'current-fuel-price']: Current price of petrol in Bengaluru is Rs.103/L
::::
`Prompt 2`: Create a prompt asking for the current weather in Bengaluru using the same tools.
```shell
String prompt2 = new MistralTools.PromptBuilder()
.withToolSpecification(fuelPriceToolSpecification)
.withToolSpecification(weatherToolSpecification)
.withPrompt("What is the current weather in Bengaluru?")
.build();
OllamaToolsResult toolsResult = ollamaAPI.generateWithTools(model, prompt2, false, new OptionsBuilder().build());
for (Map.Entry<ToolDef, Object> r : toolsResult.getToolResults().entrySet()) {
System.out.printf("[Response from tool '%s']: %s%n", r.getKey().getName(), r.getValue().toString());
}
```
Again, fire away your question to the model.
You will get a response similar to:
::::tip[LLM Response]
[Response from tool 'current-weather']: Currently Bengaluru's weather is nice
::::
### Full Example
```java
import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.tools.ToolDef;
import io.github.amithkoujalgi.ollama4j.core.tools.MistralTools;
import io.github.amithkoujalgi.ollama4j.core.tools.OllamaToolsResult;
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
import java.io.IOException;
import java.util.Arrays;
import java.util.Map;
public class FunctionCallingWithMistral {
public static void main(String[] args) throws Exception {
String host = "http://localhost:11434/";
OllamaAPI ollamaAPI = new OllamaAPI(host);
ollamaAPI.setRequestTimeoutSeconds(60);
String model = "mistral";
MistralTools.ToolSpecification fuelPriceToolSpecification = MistralTools.ToolSpecification.builder()
.functionName("current-fuel-price")
.functionDesc("Get current fuel price")
.props(
new MistralTools.PropsBuilder()
.withProperty("location", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The city, e.g. New Delhi, India").required(true).build())
.withProperty("fuelType", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The fuel type.").enumValues(Arrays.asList("petrol", "diesel")).required(true).build())
.build()
)
.toolDefinition(SampleTools::getCurrentFuelPrice)
.build();
MistralTools.ToolSpecification weatherToolSpecification = MistralTools.ToolSpecification.builder()
.functionName("current-weather")
.functionDesc("Get current weather")
.props(
new MistralTools.PropsBuilder()
.withProperty("city", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The city, e.g. New Delhi, India").required(true).build())
.build()
)
.toolDefinition(SampleTools::getCurrentWeather)
.build();
ollamaAPI.registerTool(fuelPriceToolSpecification);
ollamaAPI.registerTool(weatherToolSpecification);
String prompt1 = new MistralTools.PromptBuilder()
.withToolSpecification(fuelPriceToolSpecification)
.withToolSpecification(weatherToolSpecification)
.withPrompt("What is the petrol price in Bengaluru?")
.build();
String prompt2 = new MistralTools.PromptBuilder()
.withToolSpecification(fuelPriceToolSpecification)
.withToolSpecification(weatherToolSpecification)
.withPrompt("What is the current weather in Bengaluru?")
.build();
ask(ollamaAPI, model, prompt1);
ask(ollamaAPI, model, prompt2);
}
public static void ask(OllamaAPI ollamaAPI, String model, String prompt) throws OllamaBaseException, IOException, InterruptedException {
OllamaToolsResult toolsResult = ollamaAPI.generateWithTools(model, prompt, false, new OptionsBuilder().build());
for (Map.Entry<ToolDef, Object> r : toolsResult.getToolResults().entrySet()) {
System.out.printf("[Response from tool '%s']: %s%n", r.getKey().getName(), r.getValue().toString());
}
}
}
class SampleTools {
public static String getCurrentFuelPrice(Map<String, Object> arguments) {
String location = arguments.get("location").toString();
String fuelType = arguments.get("fuelType").toString();
return "Current price of " + fuelType + " in " + location + " is Rs.103/L";
}
public static String getCurrentWeather(Map<String, Object> arguments) {
String location = arguments.get("city").toString();
return "Currently " + location + "'s weather is nice.";
}
}
```
Run this full example and you will get a response similar to:
::::tip[LLM Response]
[Response from tool 'current-fuel-price']: Current price of petrol in Bengaluru is Rs.103/L
[Response from tool 'current-weather']: Currently Bengaluru's weather is nice
::::
### Room for improvement
Instead of explicitly registering `ollamaAPI.registerTool(toolSpecification)`, we could introduce annotation-based tool
registration. For example:
```java
@ToolSpec(name = "current-fuel-price", desc = "Get current fuel price")
public String getCurrentFuelPrice(Map<String, Object> arguments) {
String location = arguments.get("location").toString();
String fuelType = arguments.get("fuelType").toString();
return "Current price of " + fuelType + " in " + location + " is Rs.103/L";
}
```
Instead of passing a map of args `Map<String, Object> arguments` to the tool functions, we could support passing
specific args separately with their data types. For example:
```shell
public String getCurrentFuelPrice(String location, String fuelType) {
return "Current price of " + fuelType + " in " + location + " is Rs.103/L";
}
```
Updating async/chat APIs with support for tool-based generation.

View File

@ -1,5 +1,5 @@
--- ---
sidebar_position: 5 sidebar_position: 6
--- ---
# Prompt Builder # Prompt Builder

View File

@ -1,5 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?> <?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion> <modelVersion>4.0.0</modelVersion>
<groupId>io.github.amithkoujalgi</groupId> <groupId>io.github.amithkoujalgi</groupId>

View File

@ -10,6 +10,7 @@ import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingRe
import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestModel; import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel; import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.request.*; import io.github.amithkoujalgi.ollama4j.core.models.request.*;
import io.github.amithkoujalgi.ollama4j.core.tools.*;
import io.github.amithkoujalgi.ollama4j.core.utils.Options; import io.github.amithkoujalgi.ollama4j.core.utils.Options;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils; import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
import org.slf4j.Logger; import org.slf4j.Logger;
@ -25,9 +26,7 @@ import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.nio.file.Files; import java.nio.file.Files;
import java.time.Duration; import java.time.Duration;
import java.util.ArrayList; import java.util.*;
import java.util.Base64;
import java.util.List;
/** /**
* The base Ollama API class. * The base Ollama API class.
@ -339,6 +338,7 @@ public class OllamaAPI {
} }
} }
/** /**
* Generate response for a question to a model running on Ollama server. This is a sync/blocking * Generate response for a question to a model running on Ollama server. This is a sync/blocking
* call. * call.
@ -351,9 +351,10 @@ public class OllamaAPI {
* @param streamHandler optional callback consumer that will be applied every time a streamed response is received. If not set, the stream parameter of the request is set to false. * @param streamHandler optional callback consumer that will be applied every time a streamed response is received. If not set, the stream parameter of the request is set to false.
* @return OllamaResult that includes response text and time taken for response * @return OllamaResult that includes response text and time taken for response
*/ */
public OllamaResult generate(String model, String prompt, Options options, OllamaStreamHandler streamHandler) public OllamaResult generate(String model, String prompt, boolean raw, Options options, OllamaStreamHandler streamHandler)
throws OllamaBaseException, IOException, InterruptedException { throws OllamaBaseException, IOException, InterruptedException {
OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt); OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt);
ollamaRequestModel.setRaw(raw);
ollamaRequestModel.setOptions(options.getOptionsMap()); ollamaRequestModel.setOptions(options.getOptionsMap());
return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler); return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler);
} }
@ -361,13 +362,37 @@ public class OllamaAPI {
/** /**
* Convenience method to call Ollama API without streaming responses. * Convenience method to call Ollama API without streaming responses.
* <p> * <p>
* Uses {@link #generate(String, String, Options, OllamaStreamHandler)} * Uses {@link #generate(String, String, boolean, Options, OllamaStreamHandler)}
*
* @param model Model to use
* @param prompt Prompt text
* @param raw In some cases, you may wish to bypass the templating system and provide a full prompt. In this case, you can use the raw parameter to disable templating. Also note that raw mode will not return a context.
* @param options Additional Options
* @return OllamaResult
*/ */
public OllamaResult generate(String model, String prompt, Options options) public OllamaResult generate(String model, String prompt, boolean raw, Options options)
throws OllamaBaseException, IOException, InterruptedException { throws OllamaBaseException, IOException, InterruptedException {
return generate(model, prompt, options, null); return generate(model, prompt, raw, options, null);
} }
public OllamaToolsResult generateWithTools(String model, String prompt, boolean raw, Options options)
throws OllamaBaseException, IOException, InterruptedException {
OllamaToolsResult toolResult = new OllamaToolsResult();
Map<ToolDef, Object> toolResults = new HashMap<>();
OllamaResult result = generate(model, prompt, raw, options, null);
toolResult.setModelResult(result);
List<ToolDef> toolDefs = Utils.getObjectMapper().readValue(result.getResponse(), Utils.getObjectMapper().getTypeFactory().constructCollectionType(List.class, ToolDef.class));
for (ToolDef toolDef : toolDefs) {
toolResults.put(toolDef, invokeTool(toolDef));
}
toolResult.setToolResults(toolResults);
return toolResult;
}
/** /**
* Generate response for a question to a model running on Ollama server and get a callback handle * Generate response for a question to a model running on Ollama server and get a callback handle
* that can be used to check for status and get the response from the model later. This would be * that can be used to check for status and get the response from the model later. This would be
@ -377,9 +402,9 @@ public class OllamaAPI {
* @param prompt the prompt/question text * @param prompt the prompt/question text
* @return the ollama async result callback handle * @return the ollama async result callback handle
*/ */
public OllamaAsyncResultCallback generateAsync(String model, String prompt) { public OllamaAsyncResultCallback generateAsync(String model, String prompt, boolean raw) {
OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt); OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt);
ollamaRequestModel.setRaw(raw);
URI uri = URI.create(this.host + "/api/generate"); URI uri = URI.create(this.host + "/api/generate");
OllamaAsyncResultCallback ollamaAsyncResultCallback = OllamaAsyncResultCallback ollamaAsyncResultCallback =
new OllamaAsyncResultCallback( new OllamaAsyncResultCallback(
@ -576,4 +601,24 @@ public class OllamaAPI {
private boolean isBasicAuthCredentialsSet() { private boolean isBasicAuthCredentialsSet() {
return basicAuth != null; return basicAuth != null;
} }
public void registerTool(MistralTools.ToolSpecification toolSpecification) {
ToolRegistry.addFunction(toolSpecification.getFunctionName(), toolSpecification.getToolDefinition());
}
private Object invokeTool(ToolDef toolDef) {
try {
String methodName = toolDef.getName();
Map<String, Object> arguments = toolDef.getArguments();
DynamicFunction function = ToolRegistry.getFunction(methodName);
if (function == null) {
throw new IllegalArgumentException("No such tool: " + methodName);
}
return function.apply(arguments);
} catch (Exception e) {
e.printStackTrace();
return "Error calling tool: " + e.getMessage();
}
}
} }

View File

@ -1,9 +1,5 @@
package io.github.amithkoujalgi.ollama4j.core.models.request; package io.github.amithkoujalgi.ollama4j.core.models.request;
import java.io.IOException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler; import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
@ -13,8 +9,12 @@ import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRespo
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateStreamObserver; import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateStreamObserver;
import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody; import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils; import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{ import java.io.IOException;
public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
private static final Logger LOG = LoggerFactory.getLogger(OllamaGenerateEndpointCaller.class); private static final Logger LOG = LoggerFactory.getLogger(OllamaGenerateEndpointCaller.class);
@ -31,24 +31,22 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{
@Override @Override
protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) { protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) {
try { try {
OllamaGenerateResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class); OllamaGenerateResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class);
responseBuffer.append(ollamaResponseModel.getResponse()); responseBuffer.append(ollamaResponseModel.getResponse());
if(streamObserver != null) { if (streamObserver != null) {
streamObserver.notify(ollamaResponseModel); streamObserver.notify(ollamaResponseModel);
} }
return ollamaResponseModel.isDone(); return ollamaResponseModel.isDone();
} catch (JsonProcessingException e) { } catch (JsonProcessingException e) {
LOG.error("Error parsing the Ollama chat response!",e); LOG.error("Error parsing the Ollama chat response!", e);
return true; return true;
} }
} }
public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler) public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler)
throws OllamaBaseException, IOException, InterruptedException { throws OllamaBaseException, IOException, InterruptedException {
streamObserver = new OllamaGenerateStreamObserver(streamHandler); streamObserver = new OllamaGenerateStreamObserver(streamHandler);
return super.callSync(body); return super.callSync(body);
} }
} }

View File

@ -0,0 +1,8 @@
package io.github.amithkoujalgi.ollama4j.core.tools;
import java.util.Map;
@FunctionalInterface
public interface DynamicFunction {
Object apply(Map<String, Object> arguments);
}

View File

@ -0,0 +1,139 @@
package io.github.amithkoujalgi.ollama4j.core.tools;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
import lombok.Builder;
import lombok.Data;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class MistralTools {
@Data
@Builder
public static class ToolSpecification {
private String functionName;
private String functionDesc;
private Map<String, PromptFuncDefinition.Property> props;
private DynamicFunction toolDefinition;
}
@Data
@JsonIgnoreProperties(ignoreUnknown = true)
public static class PromptFuncDefinition {
private String type;
private PromptFuncSpec function;
@Data
public static class PromptFuncSpec {
private String name;
private String description;
private Parameters parameters;
}
@Data
public static class Parameters {
private String type;
private Map<String, Property> properties;
private List<String> required;
}
@Data
@Builder
public static class Property {
private String type;
private String description;
@JsonProperty("enum")
@JsonInclude(JsonInclude.Include.NON_NULL)
private List<String> enumValues;
@JsonIgnore
private boolean required;
}
}
public static class PropsBuilder {
private final Map<String, PromptFuncDefinition.Property> props = new HashMap<>();
public PropsBuilder withProperty(String key, PromptFuncDefinition.Property property) {
props.put(key, property);
return this;
}
public Map<String, PromptFuncDefinition.Property> build() {
return props;
}
}
public static class PromptBuilder {
private final List<PromptFuncDefinition> tools = new ArrayList<>();
private String promptText;
public String build() throws JsonProcessingException {
return "[AVAILABLE_TOOLS] " + Utils.getObjectMapper().writeValueAsString(tools) + "[/AVAILABLE_TOOLS][INST] " + promptText + " [/INST]";
}
public PromptBuilder withPrompt(String prompt) throws JsonProcessingException {
promptText = prompt;
return this;
}
public PromptBuilder withToolSpecification(ToolSpecification spec) {
PromptFuncDefinition def = new PromptFuncDefinition();
def.setType("function");
PromptFuncDefinition.PromptFuncSpec functionDetail = new PromptFuncDefinition.PromptFuncSpec();
functionDetail.setName(spec.getFunctionName());
functionDetail.setDescription(spec.getFunctionDesc());
PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters();
parameters.setType("object");
parameters.setProperties(spec.getProps());
List<String> requiredValues = new ArrayList<>();
for (Map.Entry<String, PromptFuncDefinition.Property> p : spec.getProps().entrySet()) {
if (p.getValue().isRequired()) {
requiredValues.add(p.getKey());
}
}
parameters.setRequired(requiredValues);
functionDetail.setParameters(parameters);
def.setFunction(functionDetail);
tools.add(def);
return this;
}
//
// public PromptBuilder withToolSpecification(String functionName, String functionDesc, Map<String, PromptFuncDefinition.Property> props) {
// PromptFuncDefinition def = new PromptFuncDefinition();
// def.setType("function");
//
// PromptFuncDefinition.PromptFuncSpec functionDetail = new PromptFuncDefinition.PromptFuncSpec();
// functionDetail.setName(functionName);
// functionDetail.setDescription(functionDesc);
//
// PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters();
// parameters.setType("object");
// parameters.setProperties(props);
//
// List<String> requiredValues = new ArrayList<>();
// for (Map.Entry<String, PromptFuncDefinition.Property> p : props.entrySet()) {
// if (p.getValue().isRequired()) {
// requiredValues.add(p.getKey());
// }
// }
// parameters.setRequired(requiredValues);
// functionDetail.setParameters(parameters);
// def.setFunction(functionDetail);
//
// tools.add(def);
// return this;
// }
}
}

View File

@ -0,0 +1,16 @@
package io.github.amithkoujalgi.ollama4j.core.tools;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.Map;
@Data
@NoArgsConstructor
@AllArgsConstructor
public class OllamaToolsResult {
private OllamaResult modelResult;
private Map<ToolDef, Object> toolResults;
}

View File

@ -0,0 +1,18 @@
package io.github.amithkoujalgi.ollama4j.core.tools;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.Map;
@Data
@AllArgsConstructor
@NoArgsConstructor
public class ToolDef {
private String name;
private Map<String, Object> arguments;
}

View File

@ -0,0 +1,17 @@
package io.github.amithkoujalgi.ollama4j.core.tools;
import java.util.HashMap;
import java.util.Map;
public class ToolRegistry {
private static final Map<String, DynamicFunction> functionMap = new HashMap<>();
public static DynamicFunction getFunction(String name) {
return functionMap.get(name);
}
public static void addFunction(String name, DynamicFunction function) {
functionMap.put(name, function);
}
}

View File

@ -1,7 +1,5 @@
package io.github.amithkoujalgi.ollama4j.integrationtests; package io.github.amithkoujalgi.ollama4j.integrationtests;
import static org.junit.jupiter.api.Assertions.*;
import io.github.amithkoujalgi.ollama4j.core.OllamaAPI; import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.ModelDetail; import io.github.amithkoujalgi.ollama4j.core.models.ModelDetail;
@ -10,9 +8,16 @@ import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessageRole;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResult; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResult;
import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestBuilder; import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestBuilder;
import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestModel;
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder; import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
import lombok.Data;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Order;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
@ -22,372 +27,369 @@ import java.net.http.HttpConnectTimeoutException;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
import java.util.Properties; import java.util.Properties;
import lombok.Data;
import org.junit.jupiter.api.BeforeEach; import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.Order;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
class TestRealAPIs { class TestRealAPIs {
private static final Logger LOG = LoggerFactory.getLogger(TestRealAPIs.class); private static final Logger LOG = LoggerFactory.getLogger(TestRealAPIs.class);
OllamaAPI ollamaAPI; OllamaAPI ollamaAPI;
Config config; Config config;
private File getImageFileFromClasspath(String fileName) { private File getImageFileFromClasspath(String fileName) {
ClassLoader classLoader = getClass().getClassLoader(); ClassLoader classLoader = getClass().getClassLoader();
return new File(Objects.requireNonNull(classLoader.getResource(fileName)).getFile()); return new File(Objects.requireNonNull(classLoader.getResource(fileName)).getFile());
}
@BeforeEach
void setUp() {
config = new Config();
ollamaAPI = new OllamaAPI(config.getOllamaURL());
ollamaAPI.setRequestTimeoutSeconds(config.getRequestTimeoutSeconds());
}
@Test
@Order(1)
void testWrongEndpoint() {
OllamaAPI ollamaAPI = new OllamaAPI("http://wrong-host:11434");
assertThrows(ConnectException.class, ollamaAPI::listModels);
}
@Test
@Order(1)
void testEndpointReachability() {
try {
assertNotNull(ollamaAPI.listModels());
} catch (HttpConnectTimeoutException e) {
fail(e.getMessage());
} catch (Exception e) {
fail(e);
} }
}
@Test @BeforeEach
@Order(2) void setUp() {
void testListModels() { config = new Config();
testEndpointReachability(); ollamaAPI = new OllamaAPI(config.getOllamaURL());
try { ollamaAPI.setRequestTimeoutSeconds(config.getRequestTimeoutSeconds());
assertNotNull(ollamaAPI.listModels());
ollamaAPI.listModels().forEach(System.out::println);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
fail(e);
} }
}
@Test @Test
@Order(2) @Order(1)
void testPullModel() { void testWrongEndpoint() {
testEndpointReachability(); OllamaAPI ollamaAPI = new OllamaAPI("http://wrong-host:11434");
try { assertThrows(ConnectException.class, ollamaAPI::listModels);
ollamaAPI.pullModel(config.getModel());
boolean found =
ollamaAPI.listModels().stream()
.anyMatch(model -> model.getModel().equalsIgnoreCase(config.getModel()));
assertTrue(found);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
fail(e);
} }
}
@Test @Test
@Order(3) @Order(1)
void testListDtails() { void testEndpointReachability() {
testEndpointReachability(); try {
try { assertNotNull(ollamaAPI.listModels());
ModelDetail modelDetails = ollamaAPI.getModelDetails(config.getModel()); } catch (HttpConnectTimeoutException e) {
assertNotNull(modelDetails); fail(e.getMessage());
System.out.println(modelDetails); } catch (Exception e) {
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { fail(e);
fail(e); }
} }
}
@Test @Test
@Order(3) @Order(2)
void testAskModelWithDefaultOptions() { void testListModels() {
testEndpointReachability(); testEndpointReachability();
try { try {
OllamaResult result = assertNotNull(ollamaAPI.listModels());
ollamaAPI.generate( ollamaAPI.listModels().forEach(System.out::println);
config.getModel(), } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
"What is the capital of France? And what's France's connection with Mona Lisa?", fail(e);
new OptionsBuilder().build()); }
assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
} }
}
@Test @Test
@Order(3) @Order(2)
void testAskModelWithDefaultOptionsStreamed() { void testPullModel() {
testEndpointReachability(); testEndpointReachability();
try { try {
ollamaAPI.pullModel(config.getModel());
StringBuffer sb = new StringBuffer(""); boolean found =
ollamaAPI.listModels().stream()
OllamaResult result = ollamaAPI.generate(config.getModel(), .anyMatch(model -> model.getModel().equalsIgnoreCase(config.getModel()));
"What is the capital of France? And what's France's connection with Mona Lisa?", assertTrue(found);
new OptionsBuilder().build(), (s) -> { } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
LOG.info(s); fail(e);
String substring = s.substring(sb.toString().length(), s.length()); }
LOG.info(substring);
sb.append(substring);
});
assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
assertEquals(sb.toString().trim(), result.getResponse().trim());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
} }
}
@Test @Test
@Order(3) @Order(3)
void testAskModelWithOptions() { void testListDtails() {
testEndpointReachability(); testEndpointReachability();
try { try {
OllamaResult result = ModelDetail modelDetails = ollamaAPI.getModelDetails(config.getModel());
ollamaAPI.generate( assertNotNull(modelDetails);
config.getModel(), System.out.println(modelDetails);
"What is the capital of France? And what's France's connection with Mona Lisa?", } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
new OptionsBuilder().setTemperature(0.9f).build()); fail(e);
assertNotNull(result); }
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
} }
}
@Test @Test
@Order(3) @Order(3)
void testChat() { void testAskModelWithDefaultOptions() {
testEndpointReachability(); testEndpointReachability();
try { try {
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel()); OllamaResult result =
OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What is the capital of France?") ollamaAPI.generate(
.withMessage(OllamaChatMessageRole.ASSISTANT, "Should be Paris!") config.getModel(),
.withMessage(OllamaChatMessageRole.USER,"And what is the second larges city?") "What is the capital of France? And what's France's connection with Mona Lisa?",
.build(); false,
new OptionsBuilder().build());
OllamaChatResult chatResult = ollamaAPI.chat(requestModel); assertNotNull(result);
assertNotNull(chatResult); assertNotNull(result.getResponse());
assertFalse(chatResult.getResponse().isBlank()); assertFalse(result.getResponse().isEmpty());
assertEquals(4,chatResult.getChatHistory().size()); } catch (IOException | OllamaBaseException | InterruptedException e) {
} catch (IOException | OllamaBaseException | InterruptedException e) { fail(e);
fail(e); }
} }
}
@Test @Test
@Order(3) @Order(3)
void testChatWithSystemPrompt() { void testAskModelWithDefaultOptionsStreamed() {
testEndpointReachability(); testEndpointReachability();
try { try {
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel()); StringBuffer sb = new StringBuffer("");
OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, OllamaResult result = ollamaAPI.generate(config.getModel(),
"You are a silent bot that only says 'NI'. Do not say anything else under any circumstances!") "What is the capital of France? And what's France's connection with Mona Lisa?",
.withMessage(OllamaChatMessageRole.USER, false,
"What is the capital of France? And what's France's connection with Mona Lisa?") new OptionsBuilder().build(), (s) -> {
.build(); LOG.info(s);
String substring = s.substring(sb.toString().length(), s.length());
LOG.info(substring);
sb.append(substring);
});
OllamaChatResult chatResult = ollamaAPI.chat(requestModel); assertNotNull(result);
assertNotNull(chatResult); assertNotNull(result.getResponse());
assertFalse(chatResult.getResponse().isBlank()); assertFalse(result.getResponse().isEmpty());
assertTrue(chatResult.getResponse().startsWith("NI")); assertEquals(sb.toString().trim(), result.getResponse().trim());
assertEquals(3, chatResult.getChatHistory().size()); } catch (IOException | OllamaBaseException | InterruptedException e) {
} catch (IOException | OllamaBaseException | InterruptedException e) { fail(e);
fail(e); }
} }
}
@Test @Test
@Order(3) @Order(3)
void testChatWithStream() { void testAskModelWithOptions() {
testEndpointReachability(); testEndpointReachability();
try { try {
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel()); OllamaResult result =
OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER, ollamaAPI.generate(
"What is the capital of France? And what's France's connection with Mona Lisa?") config.getModel(),
.build(); "What is the capital of France? And what's France's connection with Mona Lisa?",
true,
StringBuffer sb = new StringBuffer(""); new OptionsBuilder().setTemperature(0.9f).build());
assertNotNull(result);
OllamaChatResult chatResult = ollamaAPI.chat(requestModel,(s) -> { assertNotNull(result.getResponse());
LOG.info(s); assertFalse(result.getResponse().isEmpty());
String substring = s.substring(sb.toString().length(), s.length()); } catch (IOException | OllamaBaseException | InterruptedException e) {
LOG.info(substring); fail(e);
sb.append(substring); }
});
assertNotNull(chatResult);
assertEquals(sb.toString().trim(), chatResult.getResponse().trim());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
} }
}
@Test @Test
@Order(3) @Order(3)
void testChatWithImageFromFileWithHistoryRecognition() { void testChat() {
testEndpointReachability(); testEndpointReachability();
try { try {
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
OllamaChatRequestBuilder.getInstance(config.getImageModel()); OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What is the capital of France?")
OllamaChatRequestModel requestModel = .withMessage(OllamaChatMessageRole.ASSISTANT, "Should be Paris!")
builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?", .withMessage(OllamaChatMessageRole.USER, "And what is the second larges city?")
List.of(getImageFileFromClasspath("dog-on-a-boat.jpg"))).build(); .build();
OllamaChatResult chatResult = ollamaAPI.chat(requestModel); OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
assertNotNull(chatResult); assertNotNull(chatResult);
assertNotNull(chatResult.getResponse()); assertFalse(chatResult.getResponse().isBlank());
assertEquals(4, chatResult.getChatHistory().size());
builder.reset(); } catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
requestModel = }
builder.withMessages(chatResult.getChatHistory())
.withMessage(OllamaChatMessageRole.USER, "What's the dogs breed?").build();
chatResult = ollamaAPI.chat(requestModel);
assertNotNull(chatResult);
assertNotNull(chatResult.getResponse());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
} }
}
@Test @Test
@Order(3) @Order(3)
void testChatWithImageFromURL() { void testChatWithSystemPrompt() {
testEndpointReachability(); testEndpointReachability();
try { try {
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getImageModel()); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?", OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM,
"https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg") "You are a silent bot that only says 'NI'. Do not say anything else under any circumstances!")
.build(); .withMessage(OllamaChatMessageRole.USER,
"What is the capital of France? And what's France's connection with Mona Lisa?")
.build();
OllamaChatResult chatResult = ollamaAPI.chat(requestModel); OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
assertNotNull(chatResult); assertNotNull(chatResult);
} catch (IOException | OllamaBaseException | InterruptedException e) { assertFalse(chatResult.getResponse().isBlank());
fail(e); assertTrue(chatResult.getResponse().startsWith("NI"));
assertEquals(3, chatResult.getChatHistory().size());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
}
} }
}
@Test @Test
@Order(3) @Order(3)
void testAskModelWithOptionsAndImageFiles() { void testChatWithStream() {
testEndpointReachability(); testEndpointReachability();
File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg"); try {
try { OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
OllamaResult result = OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER,
ollamaAPI.generateWithImageFiles( "What is the capital of France? And what's France's connection with Mona Lisa?")
config.getImageModel(), .build();
"What is in this image?",
List.of(imageFile), StringBuffer sb = new StringBuffer("");
new OptionsBuilder().build());
assertNotNull(result); OllamaChatResult chatResult = ollamaAPI.chat(requestModel, (s) -> {
assertNotNull(result.getResponse()); LOG.info(s);
assertFalse(result.getResponse().isEmpty()); String substring = s.substring(sb.toString().length(), s.length());
} catch (IOException | OllamaBaseException | InterruptedException e) { LOG.info(substring);
fail(e); sb.append(substring);
});
assertNotNull(chatResult);
assertEquals(sb.toString().trim(), chatResult.getResponse().trim());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
}
} }
}
@Test @Test
@Order(3) @Order(3)
void testAskModelWithOptionsAndImageFilesStreamed() { void testChatWithImageFromFileWithHistoryRecognition() {
testEndpointReachability(); testEndpointReachability();
File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg"); try {
try { OllamaChatRequestBuilder builder =
StringBuffer sb = new StringBuffer(""); OllamaChatRequestBuilder.getInstance(config.getImageModel());
OllamaChatRequestModel requestModel =
builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
List.of(getImageFileFromClasspath("dog-on-a-boat.jpg"))).build();
OllamaResult result = ollamaAPI.generateWithImageFiles(config.getImageModel(), OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
"What is in this image?", List.of(imageFile), new OptionsBuilder().build(), (s) -> { assertNotNull(chatResult);
LOG.info(s); assertNotNull(chatResult.getResponse());
String substring = s.substring(sb.toString().length(), s.length());
LOG.info(substring); builder.reset();
sb.append(substring);
}); requestModel =
assertNotNull(result); builder.withMessages(chatResult.getChatHistory())
assertNotNull(result.getResponse()); .withMessage(OllamaChatMessageRole.USER, "What's the dogs breed?").build();
assertFalse(result.getResponse().isEmpty());
assertEquals(sb.toString().trim(), result.getResponse().trim()); chatResult = ollamaAPI.chat(requestModel);
} catch (IOException | OllamaBaseException | InterruptedException e) { assertNotNull(chatResult);
fail(e); assertNotNull(chatResult.getResponse());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
}
} }
}
@Test @Test
@Order(3) @Order(3)
void testAskModelWithOptionsAndImageURLs() { void testChatWithImageFromURL() {
testEndpointReachability(); testEndpointReachability();
try { try {
OllamaResult result = OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getImageModel());
ollamaAPI.generateWithImageURLs( OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
config.getImageModel(), "https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg")
"What is in this image?", .build();
List.of(
"https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg"), OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
new OptionsBuilder().build()); assertNotNull(chatResult);
assertNotNull(result); } catch (IOException | OllamaBaseException | InterruptedException e) {
assertNotNull(result.getResponse()); fail(e);
assertFalse(result.getResponse().isEmpty()); }
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
fail(e);
} }
}
@Test @Test
@Order(3) @Order(3)
public void testEmbedding() { void testAskModelWithOptionsAndImageFiles() {
testEndpointReachability(); testEndpointReachability();
try { File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg");
OllamaEmbeddingsRequestModel request = OllamaEmbeddingsRequestBuilder try {
.getInstance(config.getModel(), "What is the capital of France?").build(); OllamaResult result =
ollamaAPI.generateWithImageFiles(
List<Double> embeddings = ollamaAPI.generateEmbeddings(request); config.getImageModel(),
"What is in this image?",
assertNotNull(embeddings); List.of(imageFile),
assertFalse(embeddings.isEmpty()); new OptionsBuilder().build());
} catch (IOException | OllamaBaseException | InterruptedException e) { assertNotNull(result);
fail(e); assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
}
}
@Test
@Order(3)
void testAskModelWithOptionsAndImageFilesStreamed() {
testEndpointReachability();
File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg");
try {
StringBuffer sb = new StringBuffer("");
OllamaResult result = ollamaAPI.generateWithImageFiles(config.getImageModel(),
"What is in this image?", List.of(imageFile), new OptionsBuilder().build(), (s) -> {
LOG.info(s);
String substring = s.substring(sb.toString().length(), s.length());
LOG.info(substring);
sb.append(substring);
});
assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
assertEquals(sb.toString().trim(), result.getResponse().trim());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
}
}
@Test
@Order(3)
void testAskModelWithOptionsAndImageURLs() {
testEndpointReachability();
try {
OllamaResult result =
ollamaAPI.generateWithImageURLs(
config.getImageModel(),
"What is in this image?",
List.of(
"https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg"),
new OptionsBuilder().build());
assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
fail(e);
}
}
@Test
@Order(3)
public void testEmbedding() {
testEndpointReachability();
try {
OllamaEmbeddingsRequestModel request = OllamaEmbeddingsRequestBuilder
.getInstance(config.getModel(), "What is the capital of France?").build();
List<Double> embeddings = ollamaAPI.generateEmbeddings(request);
assertNotNull(embeddings);
assertFalse(embeddings.isEmpty());
} catch (IOException | OllamaBaseException | InterruptedException e) {
fail(e);
}
} }
}
} }
@Data @Data
class Config { class Config {
private String ollamaURL; private String ollamaURL;
private String model; private String model;
private String imageModel; private String imageModel;
private int requestTimeoutSeconds; private int requestTimeoutSeconds;
public Config() { public Config() {
Properties properties = new Properties(); Properties properties = new Properties();
try (InputStream input = try (InputStream input =
getClass().getClassLoader().getResourceAsStream("test-config.properties")) { getClass().getClassLoader().getResourceAsStream("test-config.properties")) {
if (input == null) { if (input == null) {
throw new RuntimeException("Sorry, unable to find test-config.properties"); throw new RuntimeException("Sorry, unable to find test-config.properties");
} }
properties.load(input); properties.load(input);
this.ollamaURL = properties.getProperty("ollama.url"); this.ollamaURL = properties.getProperty("ollama.url");
this.model = properties.getProperty("ollama.model"); this.model = properties.getProperty("ollama.model");
this.imageModel = properties.getProperty("ollama.model.image"); this.imageModel = properties.getProperty("ollama.model.image");
this.requestTimeoutSeconds = this.requestTimeoutSeconds =
Integer.parseInt(properties.getProperty("ollama.request-timeout-seconds")); Integer.parseInt(properties.getProperty("ollama.request-timeout-seconds"));
} catch (IOException e) { } catch (IOException e) {
throw new RuntimeException("Error loading properties", e); throw new RuntimeException("Error loading properties", e);
}
} }
}
} }

View File

@ -1,7 +1,5 @@
package io.github.amithkoujalgi.ollama4j.unittests; package io.github.amithkoujalgi.ollama4j.unittests;
import static org.mockito.Mockito.*;
import io.github.amithkoujalgi.ollama4j.core.OllamaAPI; import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.ModelDetail; import io.github.amithkoujalgi.ollama4j.core.models.ModelDetail;
@ -9,155 +7,158 @@ import io.github.amithkoujalgi.ollama4j.core.models.OllamaAsyncResultCallback;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult; import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
import io.github.amithkoujalgi.ollama4j.core.types.OllamaModelType; import io.github.amithkoujalgi.ollama4j.core.types.OllamaModelType;
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder; import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import java.io.IOException; import java.io.IOException;
import java.net.URISyntaxException; import java.net.URISyntaxException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito; import static org.mockito.Mockito.*;
class TestMockedAPIs { class TestMockedAPIs {
@Test @Test
void testPullModel() { void testPullModel() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2; String model = OllamaModelType.LLAMA2;
try { try {
doNothing().when(ollamaAPI).pullModel(model); doNothing().when(ollamaAPI).pullModel(model);
ollamaAPI.pullModel(model); ollamaAPI.pullModel(model);
verify(ollamaAPI, times(1)).pullModel(model); verify(ollamaAPI, times(1)).pullModel(model);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
}
} }
}
@Test @Test
void testListModels() { void testListModels() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
try { try {
when(ollamaAPI.listModels()).thenReturn(new ArrayList<>()); when(ollamaAPI.listModels()).thenReturn(new ArrayList<>());
ollamaAPI.listModels(); ollamaAPI.listModels();
verify(ollamaAPI, times(1)).listModels(); verify(ollamaAPI, times(1)).listModels();
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
}
} }
}
@Test @Test
void testCreateModel() { void testCreateModel() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2; String model = OllamaModelType.LLAMA2;
String modelFilePath = "FROM llama2\nSYSTEM You are mario from Super Mario Bros."; String modelFilePath = "FROM llama2\nSYSTEM You are mario from Super Mario Bros.";
try { try {
doNothing().when(ollamaAPI).createModelWithModelFileContents(model, modelFilePath); doNothing().when(ollamaAPI).createModelWithModelFileContents(model, modelFilePath);
ollamaAPI.createModelWithModelFileContents(model, modelFilePath); ollamaAPI.createModelWithModelFileContents(model, modelFilePath);
verify(ollamaAPI, times(1)).createModelWithModelFileContents(model, modelFilePath); verify(ollamaAPI, times(1)).createModelWithModelFileContents(model, modelFilePath);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
}
} }
}
@Test @Test
void testDeleteModel() { void testDeleteModel() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2; String model = OllamaModelType.LLAMA2;
try { try {
doNothing().when(ollamaAPI).deleteModel(model, true); doNothing().when(ollamaAPI).deleteModel(model, true);
ollamaAPI.deleteModel(model, true); ollamaAPI.deleteModel(model, true);
verify(ollamaAPI, times(1)).deleteModel(model, true); verify(ollamaAPI, times(1)).deleteModel(model, true);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
}
} }
}
@Test @Test
void testGetModelDetails() { void testGetModelDetails() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2; String model = OllamaModelType.LLAMA2;
try { try {
when(ollamaAPI.getModelDetails(model)).thenReturn(new ModelDetail()); when(ollamaAPI.getModelDetails(model)).thenReturn(new ModelDetail());
ollamaAPI.getModelDetails(model); ollamaAPI.getModelDetails(model);
verify(ollamaAPI, times(1)).getModelDetails(model); verify(ollamaAPI, times(1)).getModelDetails(model);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
}
} }
}
@Test @Test
void testGenerateEmbeddings() { void testGenerateEmbeddings() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2; String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text"; String prompt = "some prompt text";
try { try {
when(ollamaAPI.generateEmbeddings(model, prompt)).thenReturn(new ArrayList<>()); when(ollamaAPI.generateEmbeddings(model, prompt)).thenReturn(new ArrayList<>());
ollamaAPI.generateEmbeddings(model, prompt); ollamaAPI.generateEmbeddings(model, prompt);
verify(ollamaAPI, times(1)).generateEmbeddings(model, prompt); verify(ollamaAPI, times(1)).generateEmbeddings(model, prompt);
} catch (IOException | OllamaBaseException | InterruptedException e) { } catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
}
} }
}
@Test @Test
void testAsk() { void testAsk() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2; String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text"; String prompt = "some prompt text";
OptionsBuilder optionsBuilder = new OptionsBuilder(); OptionsBuilder optionsBuilder = new OptionsBuilder();
try { try {
when(ollamaAPI.generate(model, prompt, optionsBuilder.build())) when(ollamaAPI.generate(model, prompt, false, optionsBuilder.build()))
.thenReturn(new OllamaResult("", 0, 200)); .thenReturn(new OllamaResult("", 0, 200));
ollamaAPI.generate(model, prompt, optionsBuilder.build()); ollamaAPI.generate(model, prompt, false, optionsBuilder.build());
verify(ollamaAPI, times(1)).generate(model, prompt, optionsBuilder.build()); verify(ollamaAPI, times(1)).generate(model, prompt, false, optionsBuilder.build());
} catch (IOException | OllamaBaseException | InterruptedException e) { } catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
}
} }
}
@Test @Test
void testAskWithImageFiles() { void testAskWithImageFiles() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2; String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text"; String prompt = "some prompt text";
try { try {
when(ollamaAPI.generateWithImageFiles( when(ollamaAPI.generateWithImageFiles(
model, prompt, Collections.emptyList(), new OptionsBuilder().build())) model, prompt, Collections.emptyList(), new OptionsBuilder().build()))
.thenReturn(new OllamaResult("", 0, 200)); .thenReturn(new OllamaResult("", 0, 200));
ollamaAPI.generateWithImageFiles( ollamaAPI.generateWithImageFiles(
model, prompt, Collections.emptyList(), new OptionsBuilder().build()); model, prompt, Collections.emptyList(), new OptionsBuilder().build());
verify(ollamaAPI, times(1)) verify(ollamaAPI, times(1))
.generateWithImageFiles( .generateWithImageFiles(
model, prompt, Collections.emptyList(), new OptionsBuilder().build()); model, prompt, Collections.emptyList(), new OptionsBuilder().build());
} catch (IOException | OllamaBaseException | InterruptedException e) { } catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
}
} }
}
@Test @Test
void testAskWithImageURLs() { void testAskWithImageURLs() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2; String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text"; String prompt = "some prompt text";
try { try {
when(ollamaAPI.generateWithImageURLs( when(ollamaAPI.generateWithImageURLs(
model, prompt, Collections.emptyList(), new OptionsBuilder().build())) model, prompt, Collections.emptyList(), new OptionsBuilder().build()))
.thenReturn(new OllamaResult("", 0, 200)); .thenReturn(new OllamaResult("", 0, 200));
ollamaAPI.generateWithImageURLs( ollamaAPI.generateWithImageURLs(
model, prompt, Collections.emptyList(), new OptionsBuilder().build()); model, prompt, Collections.emptyList(), new OptionsBuilder().build());
verify(ollamaAPI, times(1)) verify(ollamaAPI, times(1))
.generateWithImageURLs( .generateWithImageURLs(
model, prompt, Collections.emptyList(), new OptionsBuilder().build()); model, prompt, Collections.emptyList(), new OptionsBuilder().build());
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
}
} }
}
@Test @Test
void testAskAsync() { void testAskAsync() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2; String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text"; String prompt = "some prompt text";
when(ollamaAPI.generateAsync(model, prompt)) when(ollamaAPI.generateAsync(model, prompt, false))
.thenReturn(new OllamaAsyncResultCallback(null, null, 3)); .thenReturn(new OllamaAsyncResultCallback(null, null, 3));
ollamaAPI.generateAsync(model, prompt); ollamaAPI.generateAsync(model, prompt, false);
verify(ollamaAPI, times(1)).generateAsync(model, prompt); verify(ollamaAPI, times(1)).generateAsync(model, prompt, false);
} }
} }