diff --git a/docs/docs/apis-generate/generate-async.md b/docs/docs/apis-generate/generate-async.md index 7d8cc54..49f556f 100644 --- a/docs/docs/apis-generate/generate-async.md +++ b/docs/docs/apis-generate/generate-async.md @@ -1,5 +1,5 @@ --- -sidebar_position: 2 +sidebar_position: 3 --- # Generate - Async diff --git a/docs/docs/apis-generate/generate-with-image-files.md b/docs/docs/apis-generate/generate-with-image-files.md index 37f4f03..4406981 100644 --- a/docs/docs/apis-generate/generate-with-image-files.md +++ b/docs/docs/apis-generate/generate-with-image-files.md @@ -1,5 +1,5 @@ --- -sidebar_position: 3 +sidebar_position: 4 --- # Generate - With Image Files diff --git a/docs/docs/apis-generate/generate-with-image-urls.md b/docs/docs/apis-generate/generate-with-image-urls.md index 19d6cf1..587e8f0 100644 --- a/docs/docs/apis-generate/generate-with-image-urls.md +++ b/docs/docs/apis-generate/generate-with-image-urls.md @@ -1,5 +1,5 @@ --- -sidebar_position: 4 +sidebar_position: 5 --- # Generate - With Image URLs diff --git a/docs/docs/apis-generate/generate-with-tools.md b/docs/docs/apis-generate/generate-with-tools.md new file mode 100644 index 0000000..0ca142a --- /dev/null +++ b/docs/docs/apis-generate/generate-with-tools.md @@ -0,0 +1,271 @@ +--- +sidebar_position: 2 +--- + +# Generate - With Tools + +This API lets you perform [function calling](https://docs.mistral.ai/capabilities/function_calling/) using LLMs in a +synchronous way. +This API correlates to +the [generate](https://github.com/ollama/ollama/blob/main/docs/api.md#request-raw-mode) API with `raw` mode. + +:::note + +This is an only an experimental implementation and has a very basic design. + +Currently, built and tested for [Mistral's latest model](https://ollama.com/library/mistral) only. We could redesign +this +in the future if tooling is supported for more models with a generic interaction standard from Ollama. + +::: + +### Function Calling/Tools + +Assume you want to call a method in your code based on the response generated from the model. +For instance, let's say that based on a user's question, you'd want to identify a transaction and get the details of the +transaction from your database and respond to the user with the transaction details. + +You could do that with ease with the `function calling` capabilities of the models by registering your `tools`. + +### Create Functions + +This function takes the arguments `location` and `fuelType` and performs an operation with these arguments and returns a +value. + +```java +public static String getCurrentFuelPrice(Map arguments) { + String location = arguments.get("location").toString(); + String fuelType = arguments.get("fuelType").toString(); + return "Current price of " + fuelType + " in " + location + " is Rs.103/L"; +} +``` + +This function takes the argument `city` and performs an operation with the argument and returns a +value. + +```java +public static String getCurrentWeather(Map arguments) { + String location = arguments.get("city").toString(); + return "Currently " + location + "'s weather is nice."; +} +``` + +### Define Tool Specifications + +Lets define a sample tool specification called **Fuel Price Tool** for getting the current fuel price. + +- Specify the function `name`, `description`, and `required` properties (`location` and `fuelType`). +- Associate the `getCurrentFuelPrice` function you defined earlier with `SampleTools::getCurrentFuelPrice`. + +```java +MistralTools.ToolSpecification fuelPriceToolSpecification = MistralTools.ToolSpecification.builder() + .functionName("current-fuel-price") + .functionDesc("Get current fuel price") + .props( + new MistralTools.PropsBuilder() + .withProperty("location", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The city, e.g. New Delhi, India").required(true).build()) + .withProperty("fuelType", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The fuel type.").enumValues(Arrays.asList("petrol", "diesel")).required(true).build()) + .build() + ) + .toolDefinition(SampleTools::getCurrentFuelPrice) + .build(); +``` + +Lets also define a sample tool specification called **Weather Tool** for getting the current weather. + +- Specify the function `name`, `description`, and `required` property (`city`). +- Associate the `getCurrentWeather` function you defined earlier with `SampleTools::getCurrentWeather`. + +```java +MistralTools.ToolSpecification weatherToolSpecification = MistralTools.ToolSpecification.builder() + .functionName("current-weather") + .functionDesc("Get current weather") + .props( + new MistralTools.PropsBuilder() + .withProperty("city", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The city, e.g. New Delhi, India").required(true).build()) + .build() + ) + .toolDefinition(SampleTools::getCurrentWeather) + .build(); +``` + +### Register the Tools + +Register the defined tools (`fuel price` and `weather`) with the OllamaAPI. + +```shell +ollamaAPI.registerTool(fuelPriceToolSpecification); +ollamaAPI.registerTool(weatherToolSpecification); +``` + +### Create prompt with Tools + +`Prompt 1`: Create a prompt asking for the petrol price in Bengaluru using the defined fuel price and weather tools. + +```shell +String prompt1 = new MistralTools.PromptBuilder() + .withToolSpecification(fuelPriceToolSpecification) + .withToolSpecification(weatherToolSpecification) + .withPrompt("What is the petrol price in Bengaluru?") + .build(); +OllamaToolsResult toolsResult = ollamaAPI.generateWithTools(model, prompt1, false, new OptionsBuilder().build()); +for (Map.Entry r : toolsResult.getToolResults().entrySet()) { + System.out.printf("[Response from tool '%s']: %s%n", r.getKey().getName(), r.getValue().toString()); +} +``` + +Now, fire away your question to the model. + +You will get a response similar to: + +::::tip[LLM Response] + +[Response from tool 'current-fuel-price']: Current price of petrol in Bengaluru is Rs.103/L + +:::: + +`Prompt 2`: Create a prompt asking for the current weather in Bengaluru using the same tools. + +```shell +String prompt2 = new MistralTools.PromptBuilder() + .withToolSpecification(fuelPriceToolSpecification) + .withToolSpecification(weatherToolSpecification) + .withPrompt("What is the current weather in Bengaluru?") + .build(); +OllamaToolsResult toolsResult = ollamaAPI.generateWithTools(model, prompt2, false, new OptionsBuilder().build()); +for (Map.Entry r : toolsResult.getToolResults().entrySet()) { + System.out.printf("[Response from tool '%s']: %s%n", r.getKey().getName(), r.getValue().toString()); +} +``` + +Again, fire away your question to the model. + +You will get a response similar to: + +::::tip[LLM Response] + +[Response from tool 'current-weather']: Currently Bengaluru's weather is nice +:::: + +### Full Example + +```java + +import io.github.amithkoujalgi.ollama4j.core.OllamaAPI; +import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; +import io.github.amithkoujalgi.ollama4j.core.tools.ToolDef; +import io.github.amithkoujalgi.ollama4j.core.tools.MistralTools; +import io.github.amithkoujalgi.ollama4j.core.tools.OllamaToolsResult; +import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Map; + +public class FunctionCallingWithMistral { + public static void main(String[] args) throws Exception { + String host = "http://localhost:11434/"; + OllamaAPI ollamaAPI = new OllamaAPI(host); + ollamaAPI.setRequestTimeoutSeconds(60); + + String model = "mistral"; + + + MistralTools.ToolSpecification fuelPriceToolSpecification = MistralTools.ToolSpecification.builder() + .functionName("current-fuel-price") + .functionDesc("Get current fuel price") + .props( + new MistralTools.PropsBuilder() + .withProperty("location", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The city, e.g. New Delhi, India").required(true).build()) + .withProperty("fuelType", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The fuel type.").enumValues(Arrays.asList("petrol", "diesel")).required(true).build()) + .build() + ) + .toolDefinition(SampleTools::getCurrentFuelPrice) + .build(); + + MistralTools.ToolSpecification weatherToolSpecification = MistralTools.ToolSpecification.builder() + .functionName("current-weather") + .functionDesc("Get current weather") + .props( + new MistralTools.PropsBuilder() + .withProperty("city", MistralTools.PromptFuncDefinition.Property.builder().type("string").description("The city, e.g. New Delhi, India").required(true).build()) + .build() + ) + .toolDefinition(SampleTools::getCurrentWeather) + .build(); + + ollamaAPI.registerTool(fuelPriceToolSpecification); + ollamaAPI.registerTool(weatherToolSpecification); + + String prompt1 = new MistralTools.PromptBuilder() + .withToolSpecification(fuelPriceToolSpecification) + .withToolSpecification(weatherToolSpecification) + .withPrompt("What is the petrol price in Bengaluru?") + .build(); + String prompt2 = new MistralTools.PromptBuilder() + .withToolSpecification(fuelPriceToolSpecification) + .withToolSpecification(weatherToolSpecification) + .withPrompt("What is the current weather in Bengaluru?") + .build(); + + ask(ollamaAPI, model, prompt1); + ask(ollamaAPI, model, prompt2); + } + + public static void ask(OllamaAPI ollamaAPI, String model, String prompt) throws OllamaBaseException, IOException, InterruptedException { + OllamaToolsResult toolsResult = ollamaAPI.generateWithTools(model, prompt, false, new OptionsBuilder().build()); + for (Map.Entry r : toolsResult.getToolResults().entrySet()) { + System.out.printf("[Response from tool '%s']: %s%n", r.getKey().getName(), r.getValue().toString()); + } + } +} + +class SampleTools { + public static String getCurrentFuelPrice(Map arguments) { + String location = arguments.get("location").toString(); + String fuelType = arguments.get("fuelType").toString(); + return "Current price of " + fuelType + " in " + location + " is Rs.103/L"; + } + + public static String getCurrentWeather(Map arguments) { + String location = arguments.get("city").toString(); + return "Currently " + location + "'s weather is nice."; + } +} + +``` + +Run this full example and you will get a response similar to: + +::::tip[LLM Response] + +[Response from tool 'current-fuel-price']: Current price of petrol in Bengaluru is Rs.103/L + +[Response from tool 'current-weather']: Currently Bengaluru's weather is nice +:::: + +### Room for improvement + +Instead of explicitly registering `ollamaAPI.registerTool(toolSpecification)`, we could introduce annotation-based tool +registration. For example: + +```java + +@ToolSpec(name = "current-fuel-price", desc = "Get current fuel price") +public String getCurrentFuelPrice(Map arguments) { + String location = arguments.get("location").toString(); + String fuelType = arguments.get("fuelType").toString(); + return "Current price of " + fuelType + " in " + location + " is Rs.103/L"; +} +``` + +Instead of passing a map of args `Map arguments` to the tool functions, we could support passing +specific args separately with their data types. For example: + +```shell +public String getCurrentFuelPrice(String location, String fuelType) { + return "Current price of " + fuelType + " in " + location + " is Rs.103/L"; +} +``` + +Updating async/chat APIs with support for tool-based generation. \ No newline at end of file diff --git a/docs/docs/apis-generate/prompt-builder.md b/docs/docs/apis-generate/prompt-builder.md index a798808..ffe57d7 100644 --- a/docs/docs/apis-generate/prompt-builder.md +++ b/docs/docs/apis-generate/prompt-builder.md @@ -1,5 +1,5 @@ --- -sidebar_position: 5 +sidebar_position: 6 --- # Prompt Builder diff --git a/pom.xml b/pom.xml index 800ba57..c2f3754 100644 --- a/pom.xml +++ b/pom.xml @@ -1,5 +1,6 @@ - + 4.0.0 io.github.amithkoujalgi diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java index 1f22210..80654ae 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java @@ -10,6 +10,7 @@ import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingRe import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestModel; import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel; import io.github.amithkoujalgi.ollama4j.core.models.request.*; +import io.github.amithkoujalgi.ollama4j.core.tools.*; import io.github.amithkoujalgi.ollama4j.core.utils.Options; import io.github.amithkoujalgi.ollama4j.core.utils.Utils; import org.slf4j.Logger; @@ -25,9 +26,7 @@ import java.net.http.HttpResponse; import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.time.Duration; -import java.util.ArrayList; -import java.util.Base64; -import java.util.List; +import java.util.*; /** * The base Ollama API class. @@ -339,6 +338,7 @@ public class OllamaAPI { } } + /** * Generate response for a question to a model running on Ollama server. This is a sync/blocking * call. @@ -351,9 +351,10 @@ public class OllamaAPI { * @param streamHandler optional callback consumer that will be applied every time a streamed response is received. If not set, the stream parameter of the request is set to false. * @return OllamaResult that includes response text and time taken for response */ - public OllamaResult generate(String model, String prompt, Options options, OllamaStreamHandler streamHandler) + public OllamaResult generate(String model, String prompt, boolean raw, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt); + ollamaRequestModel.setRaw(raw); ollamaRequestModel.setOptions(options.getOptionsMap()); return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler); } @@ -361,13 +362,37 @@ public class OllamaAPI { /** * Convenience method to call Ollama API without streaming responses. *

- * Uses {@link #generate(String, String, Options, OllamaStreamHandler)} + * Uses {@link #generate(String, String, boolean, Options, OllamaStreamHandler)} + * + * @param model Model to use + * @param prompt Prompt text + * @param raw In some cases, you may wish to bypass the templating system and provide a full prompt. In this case, you can use the raw parameter to disable templating. Also note that raw mode will not return a context. + * @param options Additional Options + * @return OllamaResult */ - public OllamaResult generate(String model, String prompt, Options options) + public OllamaResult generate(String model, String prompt, boolean raw, Options options) throws OllamaBaseException, IOException, InterruptedException { - return generate(model, prompt, options, null); + return generate(model, prompt, raw, options, null); } + + public OllamaToolsResult generateWithTools(String model, String prompt, boolean raw, Options options) + throws OllamaBaseException, IOException, InterruptedException { + OllamaToolsResult toolResult = new OllamaToolsResult(); + Map toolResults = new HashMap<>(); + + OllamaResult result = generate(model, prompt, raw, options, null); + toolResult.setModelResult(result); + + List toolDefs = Utils.getObjectMapper().readValue(result.getResponse(), Utils.getObjectMapper().getTypeFactory().constructCollectionType(List.class, ToolDef.class)); + for (ToolDef toolDef : toolDefs) { + toolResults.put(toolDef, invokeTool(toolDef)); + } + toolResult.setToolResults(toolResults); + return toolResult; + } + + /** * Generate response for a question to a model running on Ollama server and get a callback handle * that can be used to check for status and get the response from the model later. This would be @@ -377,9 +402,9 @@ public class OllamaAPI { * @param prompt the prompt/question text * @return the ollama async result callback handle */ - public OllamaAsyncResultCallback generateAsync(String model, String prompt) { + public OllamaAsyncResultCallback generateAsync(String model, String prompt, boolean raw) { OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt); - + ollamaRequestModel.setRaw(raw); URI uri = URI.create(this.host + "/api/generate"); OllamaAsyncResultCallback ollamaAsyncResultCallback = new OllamaAsyncResultCallback( @@ -576,4 +601,24 @@ public class OllamaAPI { private boolean isBasicAuthCredentialsSet() { return basicAuth != null; } + + + public void registerTool(MistralTools.ToolSpecification toolSpecification) { + ToolRegistry.addFunction(toolSpecification.getFunctionName(), toolSpecification.getToolDefinition()); + } + + private Object invokeTool(ToolDef toolDef) { + try { + String methodName = toolDef.getName(); + Map arguments = toolDef.getArguments(); + DynamicFunction function = ToolRegistry.getFunction(methodName); + if (function == null) { + throw new IllegalArgumentException("No such tool: " + methodName); + } + return function.apply(arguments); + } catch (Exception e) { + e.printStackTrace(); + return "Error calling tool: " + e.getMessage(); + } + } } diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java index fe7fbec..d3d71e4 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java @@ -1,9 +1,5 @@ package io.github.amithkoujalgi.ollama4j.core.models.request; -import java.io.IOException; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import com.fasterxml.jackson.core.JsonProcessingException; import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler; import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; @@ -13,15 +9,19 @@ import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRespo import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateStreamObserver; import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody; import io.github.amithkoujalgi.ollama4j.core.utils.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; -public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{ +import java.io.IOException; + +public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller { private static final Logger LOG = LoggerFactory.getLogger(OllamaGenerateEndpointCaller.class); private OllamaGenerateStreamObserver streamObserver; public OllamaGenerateEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) { - super(host, basicAuth, requestTimeoutSeconds, verbose); + super(host, basicAuth, requestTimeoutSeconds, verbose); } @Override @@ -31,24 +31,22 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{ @Override protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) { - try { - OllamaGenerateResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class); - responseBuffer.append(ollamaResponseModel.getResponse()); - if(streamObserver != null) { - streamObserver.notify(ollamaResponseModel); - } - return ollamaResponseModel.isDone(); - } catch (JsonProcessingException e) { - LOG.error("Error parsing the Ollama chat response!",e); - return true; - } + try { + OllamaGenerateResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class); + responseBuffer.append(ollamaResponseModel.getResponse()); + if (streamObserver != null) { + streamObserver.notify(ollamaResponseModel); + } + return ollamaResponseModel.isDone(); + } catch (JsonProcessingException e) { + LOG.error("Error parsing the Ollama chat response!", e); + return true; + } } public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler) - throws OllamaBaseException, IOException, InterruptedException { - streamObserver = new OllamaGenerateStreamObserver(streamHandler); - return super.callSync(body); + throws OllamaBaseException, IOException, InterruptedException { + streamObserver = new OllamaGenerateStreamObserver(streamHandler); + return super.callSync(body); } - - } diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/DynamicFunction.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/DynamicFunction.java new file mode 100644 index 0000000..5b8f5e6 --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/DynamicFunction.java @@ -0,0 +1,8 @@ +package io.github.amithkoujalgi.ollama4j.core.tools; + +import java.util.Map; + +@FunctionalInterface +public interface DynamicFunction { + Object apply(Map arguments); +} diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/MistralTools.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/MistralTools.java new file mode 100644 index 0000000..fff8071 --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/MistralTools.java @@ -0,0 +1,139 @@ +package io.github.amithkoujalgi.ollama4j.core.tools; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.core.JsonProcessingException; +import io.github.amithkoujalgi.ollama4j.core.utils.Utils; +import lombok.Builder; +import lombok.Data; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class MistralTools { + @Data + @Builder + public static class ToolSpecification { + private String functionName; + private String functionDesc; + private Map props; + private DynamicFunction toolDefinition; + } + + @Data + @JsonIgnoreProperties(ignoreUnknown = true) + public static class PromptFuncDefinition { + private String type; + private PromptFuncSpec function; + + @Data + public static class PromptFuncSpec { + private String name; + private String description; + private Parameters parameters; + } + + @Data + public static class Parameters { + private String type; + private Map properties; + private List required; + } + + @Data + @Builder + public static class Property { + private String type; + private String description; + @JsonProperty("enum") + @JsonInclude(JsonInclude.Include.NON_NULL) + private List enumValues; + @JsonIgnore + private boolean required; + } + } + + public static class PropsBuilder { + private final Map props = new HashMap<>(); + + public PropsBuilder withProperty(String key, PromptFuncDefinition.Property property) { + props.put(key, property); + return this; + } + + public Map build() { + return props; + } + } + + public static class PromptBuilder { + private final List tools = new ArrayList<>(); + + private String promptText; + + public String build() throws JsonProcessingException { + return "[AVAILABLE_TOOLS] " + Utils.getObjectMapper().writeValueAsString(tools) + "[/AVAILABLE_TOOLS][INST] " + promptText + " [/INST]"; + } + + public PromptBuilder withPrompt(String prompt) throws JsonProcessingException { + promptText = prompt; + return this; + } + + public PromptBuilder withToolSpecification(ToolSpecification spec) { + PromptFuncDefinition def = new PromptFuncDefinition(); + def.setType("function"); + + PromptFuncDefinition.PromptFuncSpec functionDetail = new PromptFuncDefinition.PromptFuncSpec(); + functionDetail.setName(spec.getFunctionName()); + functionDetail.setDescription(spec.getFunctionDesc()); + + PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters(); + parameters.setType("object"); + parameters.setProperties(spec.getProps()); + + List requiredValues = new ArrayList<>(); + for (Map.Entry p : spec.getProps().entrySet()) { + if (p.getValue().isRequired()) { + requiredValues.add(p.getKey()); + } + } + parameters.setRequired(requiredValues); + functionDetail.setParameters(parameters); + def.setFunction(functionDetail); + + tools.add(def); + return this; + } +// +// public PromptBuilder withToolSpecification(String functionName, String functionDesc, Map props) { +// PromptFuncDefinition def = new PromptFuncDefinition(); +// def.setType("function"); +// +// PromptFuncDefinition.PromptFuncSpec functionDetail = new PromptFuncDefinition.PromptFuncSpec(); +// functionDetail.setName(functionName); +// functionDetail.setDescription(functionDesc); +// +// PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters(); +// parameters.setType("object"); +// parameters.setProperties(props); +// +// List requiredValues = new ArrayList<>(); +// for (Map.Entry p : props.entrySet()) { +// if (p.getValue().isRequired()) { +// requiredValues.add(p.getKey()); +// } +// } +// parameters.setRequired(requiredValues); +// functionDetail.setParameters(parameters); +// def.setFunction(functionDetail); +// +// tools.add(def); +// return this; +// } + } +} diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/OllamaToolsResult.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/OllamaToolsResult.java new file mode 100644 index 0000000..65ef3ac --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/OllamaToolsResult.java @@ -0,0 +1,16 @@ +package io.github.amithkoujalgi.ollama4j.core.tools; + +import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.util.Map; + +@Data +@NoArgsConstructor +@AllArgsConstructor +public class OllamaToolsResult { + private OllamaResult modelResult; + private Map toolResults; +} diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/ToolDef.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/ToolDef.java new file mode 100644 index 0000000..751d186 --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/ToolDef.java @@ -0,0 +1,18 @@ +package io.github.amithkoujalgi.ollama4j.core.tools; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.util.Map; + +@Data +@AllArgsConstructor +@NoArgsConstructor +public class ToolDef { + + private String name; + private Map arguments; + +} + diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/ToolRegistry.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/ToolRegistry.java new file mode 100644 index 0000000..0004c7f --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/ToolRegistry.java @@ -0,0 +1,17 @@ +package io.github.amithkoujalgi.ollama4j.core.tools; + +import java.util.HashMap; +import java.util.Map; + +public class ToolRegistry { + private static final Map functionMap = new HashMap<>(); + + + public static DynamicFunction getFunction(String name) { + return functionMap.get(name); + } + + public static void addFunction(String name, DynamicFunction function) { + functionMap.put(name, function); + } +} diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java b/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java index d822077..58e55a1 100644 --- a/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java +++ b/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java @@ -1,7 +1,5 @@ package io.github.amithkoujalgi.ollama4j.integrationtests; -import static org.junit.jupiter.api.Assertions.*; - import io.github.amithkoujalgi.ollama4j.core.OllamaAPI; import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; import io.github.amithkoujalgi.ollama4j.core.models.ModelDetail; @@ -10,9 +8,16 @@ import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessageRole; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResult; -import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestModel; import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestBuilder; +import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestModel; import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder; +import lombok.Data; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Order; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import java.io.File; import java.io.IOException; import java.io.InputStream; @@ -22,372 +27,369 @@ import java.net.http.HttpConnectTimeoutException; import java.util.List; import java.util.Objects; import java.util.Properties; -import lombok.Data; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Order; -import org.junit.jupiter.api.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; + +import static org.junit.jupiter.api.Assertions.*; class TestRealAPIs { - private static final Logger LOG = LoggerFactory.getLogger(TestRealAPIs.class); + private static final Logger LOG = LoggerFactory.getLogger(TestRealAPIs.class); - OllamaAPI ollamaAPI; - Config config; + OllamaAPI ollamaAPI; + Config config; - private File getImageFileFromClasspath(String fileName) { - ClassLoader classLoader = getClass().getClassLoader(); - return new File(Objects.requireNonNull(classLoader.getResource(fileName)).getFile()); - } - - @BeforeEach - void setUp() { - config = new Config(); - ollamaAPI = new OllamaAPI(config.getOllamaURL()); - ollamaAPI.setRequestTimeoutSeconds(config.getRequestTimeoutSeconds()); - } - - @Test - @Order(1) - void testWrongEndpoint() { - OllamaAPI ollamaAPI = new OllamaAPI("http://wrong-host:11434"); - assertThrows(ConnectException.class, ollamaAPI::listModels); - } - - @Test - @Order(1) - void testEndpointReachability() { - try { - assertNotNull(ollamaAPI.listModels()); - } catch (HttpConnectTimeoutException e) { - fail(e.getMessage()); - } catch (Exception e) { - fail(e); + private File getImageFileFromClasspath(String fileName) { + ClassLoader classLoader = getClass().getClassLoader(); + return new File(Objects.requireNonNull(classLoader.getResource(fileName)).getFile()); } - } - @Test - @Order(2) - void testListModels() { - testEndpointReachability(); - try { - assertNotNull(ollamaAPI.listModels()); - ollamaAPI.listModels().forEach(System.out::println); - } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { - fail(e); + @BeforeEach + void setUp() { + config = new Config(); + ollamaAPI = new OllamaAPI(config.getOllamaURL()); + ollamaAPI.setRequestTimeoutSeconds(config.getRequestTimeoutSeconds()); } - } - @Test - @Order(2) - void testPullModel() { - testEndpointReachability(); - try { - ollamaAPI.pullModel(config.getModel()); - boolean found = - ollamaAPI.listModels().stream() - .anyMatch(model -> model.getModel().equalsIgnoreCase(config.getModel())); - assertTrue(found); - } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { - fail(e); + @Test + @Order(1) + void testWrongEndpoint() { + OllamaAPI ollamaAPI = new OllamaAPI("http://wrong-host:11434"); + assertThrows(ConnectException.class, ollamaAPI::listModels); } - } - @Test - @Order(3) - void testListDtails() { - testEndpointReachability(); - try { - ModelDetail modelDetails = ollamaAPI.getModelDetails(config.getModel()); - assertNotNull(modelDetails); - System.out.println(modelDetails); - } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { - fail(e); + @Test + @Order(1) + void testEndpointReachability() { + try { + assertNotNull(ollamaAPI.listModels()); + } catch (HttpConnectTimeoutException e) { + fail(e.getMessage()); + } catch (Exception e) { + fail(e); + } } - } - @Test - @Order(3) - void testAskModelWithDefaultOptions() { - testEndpointReachability(); - try { - OllamaResult result = - ollamaAPI.generate( - config.getModel(), - "What is the capital of France? And what's France's connection with Mona Lisa?", - new OptionsBuilder().build()); - assertNotNull(result); - assertNotNull(result.getResponse()); - assertFalse(result.getResponse().isEmpty()); - } catch (IOException | OllamaBaseException | InterruptedException e) { - fail(e); + @Test + @Order(2) + void testListModels() { + testEndpointReachability(); + try { + assertNotNull(ollamaAPI.listModels()); + ollamaAPI.listModels().forEach(System.out::println); + } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { + fail(e); + } } - } - @Test - @Order(3) - void testAskModelWithDefaultOptionsStreamed() { - testEndpointReachability(); - try { - - StringBuffer sb = new StringBuffer(""); - - OllamaResult result = ollamaAPI.generate(config.getModel(), - "What is the capital of France? And what's France's connection with Mona Lisa?", - new OptionsBuilder().build(), (s) -> { - LOG.info(s); - String substring = s.substring(sb.toString().length(), s.length()); - LOG.info(substring); - sb.append(substring); - }); - - assertNotNull(result); - assertNotNull(result.getResponse()); - assertFalse(result.getResponse().isEmpty()); - assertEquals(sb.toString().trim(), result.getResponse().trim()); - } catch (IOException | OllamaBaseException | InterruptedException e) { - fail(e); + @Test + @Order(2) + void testPullModel() { + testEndpointReachability(); + try { + ollamaAPI.pullModel(config.getModel()); + boolean found = + ollamaAPI.listModels().stream() + .anyMatch(model -> model.getModel().equalsIgnoreCase(config.getModel())); + assertTrue(found); + } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { + fail(e); + } } - } - @Test - @Order(3) - void testAskModelWithOptions() { - testEndpointReachability(); - try { - OllamaResult result = - ollamaAPI.generate( - config.getModel(), - "What is the capital of France? And what's France's connection with Mona Lisa?", - new OptionsBuilder().setTemperature(0.9f).build()); - assertNotNull(result); - assertNotNull(result.getResponse()); - assertFalse(result.getResponse().isEmpty()); - } catch (IOException | OllamaBaseException | InterruptedException e) { - fail(e); + @Test + @Order(3) + void testListDtails() { + testEndpointReachability(); + try { + ModelDetail modelDetails = ollamaAPI.getModelDetails(config.getModel()); + assertNotNull(modelDetails); + System.out.println(modelDetails); + } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { + fail(e); + } } - } - @Test - @Order(3) - void testChat() { - testEndpointReachability(); - try { - OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel()); - OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What is the capital of France?") - .withMessage(OllamaChatMessageRole.ASSISTANT, "Should be Paris!") - .withMessage(OllamaChatMessageRole.USER,"And what is the second larges city?") - .build(); - - OllamaChatResult chatResult = ollamaAPI.chat(requestModel); - assertNotNull(chatResult); - assertFalse(chatResult.getResponse().isBlank()); - assertEquals(4,chatResult.getChatHistory().size()); - } catch (IOException | OllamaBaseException | InterruptedException e) { - fail(e); + @Test + @Order(3) + void testAskModelWithDefaultOptions() { + testEndpointReachability(); + try { + OllamaResult result = + ollamaAPI.generate( + config.getModel(), + "What is the capital of France? And what's France's connection with Mona Lisa?", + false, + new OptionsBuilder().build()); + assertNotNull(result); + assertNotNull(result.getResponse()); + assertFalse(result.getResponse().isEmpty()); + } catch (IOException | OllamaBaseException | InterruptedException e) { + fail(e); + } } - } - @Test - @Order(3) - void testChatWithSystemPrompt() { - testEndpointReachability(); - try { - OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel()); - OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, - "You are a silent bot that only says 'NI'. Do not say anything else under any circumstances!") - .withMessage(OllamaChatMessageRole.USER, - "What is the capital of France? And what's France's connection with Mona Lisa?") - .build(); + @Test + @Order(3) + void testAskModelWithDefaultOptionsStreamed() { + testEndpointReachability(); + try { + StringBuffer sb = new StringBuffer(""); + OllamaResult result = ollamaAPI.generate(config.getModel(), + "What is the capital of France? And what's France's connection with Mona Lisa?", + false, + new OptionsBuilder().build(), (s) -> { + LOG.info(s); + String substring = s.substring(sb.toString().length(), s.length()); + LOG.info(substring); + sb.append(substring); + }); - OllamaChatResult chatResult = ollamaAPI.chat(requestModel); - assertNotNull(chatResult); - assertFalse(chatResult.getResponse().isBlank()); - assertTrue(chatResult.getResponse().startsWith("NI")); - assertEquals(3, chatResult.getChatHistory().size()); - } catch (IOException | OllamaBaseException | InterruptedException e) { - fail(e); + assertNotNull(result); + assertNotNull(result.getResponse()); + assertFalse(result.getResponse().isEmpty()); + assertEquals(sb.toString().trim(), result.getResponse().trim()); + } catch (IOException | OllamaBaseException | InterruptedException e) { + fail(e); + } } - } - @Test - @Order(3) - void testChatWithStream() { - testEndpointReachability(); - try { - OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel()); - OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER, - "What is the capital of France? And what's France's connection with Mona Lisa?") - .build(); - - StringBuffer sb = new StringBuffer(""); - - OllamaChatResult chatResult = ollamaAPI.chat(requestModel,(s) -> { - LOG.info(s); - String substring = s.substring(sb.toString().length(), s.length()); - LOG.info(substring); - sb.append(substring); - }); - assertNotNull(chatResult); - assertEquals(sb.toString().trim(), chatResult.getResponse().trim()); - } catch (IOException | OllamaBaseException | InterruptedException e) { - fail(e); + @Test + @Order(3) + void testAskModelWithOptions() { + testEndpointReachability(); + try { + OllamaResult result = + ollamaAPI.generate( + config.getModel(), + "What is the capital of France? And what's France's connection with Mona Lisa?", + true, + new OptionsBuilder().setTemperature(0.9f).build()); + assertNotNull(result); + assertNotNull(result.getResponse()); + assertFalse(result.getResponse().isEmpty()); + } catch (IOException | OllamaBaseException | InterruptedException e) { + fail(e); + } } - } - @Test - @Order(3) - void testChatWithImageFromFileWithHistoryRecognition() { - testEndpointReachability(); - try { - OllamaChatRequestBuilder builder = - OllamaChatRequestBuilder.getInstance(config.getImageModel()); - OllamaChatRequestModel requestModel = - builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?", - List.of(getImageFileFromClasspath("dog-on-a-boat.jpg"))).build(); + @Test + @Order(3) + void testChat() { + testEndpointReachability(); + try { + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel()); + OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What is the capital of France?") + .withMessage(OllamaChatMessageRole.ASSISTANT, "Should be Paris!") + .withMessage(OllamaChatMessageRole.USER, "And what is the second larges city?") + .build(); - OllamaChatResult chatResult = ollamaAPI.chat(requestModel); - assertNotNull(chatResult); - assertNotNull(chatResult.getResponse()); - - builder.reset(); - - requestModel = - builder.withMessages(chatResult.getChatHistory()) - .withMessage(OllamaChatMessageRole.USER, "What's the dogs breed?").build(); - - chatResult = ollamaAPI.chat(requestModel); - assertNotNull(chatResult); - assertNotNull(chatResult.getResponse()); - - - } catch (IOException | OllamaBaseException | InterruptedException e) { - fail(e); + OllamaChatResult chatResult = ollamaAPI.chat(requestModel); + assertNotNull(chatResult); + assertFalse(chatResult.getResponse().isBlank()); + assertEquals(4, chatResult.getChatHistory().size()); + } catch (IOException | OllamaBaseException | InterruptedException e) { + fail(e); + } } - } - @Test - @Order(3) - void testChatWithImageFromURL() { - testEndpointReachability(); - try { - OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getImageModel()); - OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?", - "https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg") - .build(); + @Test + @Order(3) + void testChatWithSystemPrompt() { + testEndpointReachability(); + try { + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel()); + OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM, + "You are a silent bot that only says 'NI'. Do not say anything else under any circumstances!") + .withMessage(OllamaChatMessageRole.USER, + "What is the capital of France? And what's France's connection with Mona Lisa?") + .build(); - OllamaChatResult chatResult = ollamaAPI.chat(requestModel); - assertNotNull(chatResult); - } catch (IOException | OllamaBaseException | InterruptedException e) { - fail(e); + OllamaChatResult chatResult = ollamaAPI.chat(requestModel); + assertNotNull(chatResult); + assertFalse(chatResult.getResponse().isBlank()); + assertTrue(chatResult.getResponse().startsWith("NI")); + assertEquals(3, chatResult.getChatHistory().size()); + } catch (IOException | OllamaBaseException | InterruptedException e) { + fail(e); + } } - } - @Test - @Order(3) - void testAskModelWithOptionsAndImageFiles() { - testEndpointReachability(); - File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg"); - try { - OllamaResult result = - ollamaAPI.generateWithImageFiles( - config.getImageModel(), - "What is in this image?", - List.of(imageFile), - new OptionsBuilder().build()); - assertNotNull(result); - assertNotNull(result.getResponse()); - assertFalse(result.getResponse().isEmpty()); - } catch (IOException | OllamaBaseException | InterruptedException e) { - fail(e); + @Test + @Order(3) + void testChatWithStream() { + testEndpointReachability(); + try { + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel()); + OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER, + "What is the capital of France? And what's France's connection with Mona Lisa?") + .build(); + + StringBuffer sb = new StringBuffer(""); + + OllamaChatResult chatResult = ollamaAPI.chat(requestModel, (s) -> { + LOG.info(s); + String substring = s.substring(sb.toString().length(), s.length()); + LOG.info(substring); + sb.append(substring); + }); + assertNotNull(chatResult); + assertEquals(sb.toString().trim(), chatResult.getResponse().trim()); + } catch (IOException | OllamaBaseException | InterruptedException e) { + fail(e); + } } - } - @Test - @Order(3) - void testAskModelWithOptionsAndImageFilesStreamed() { - testEndpointReachability(); - File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg"); - try { - StringBuffer sb = new StringBuffer(""); + @Test + @Order(3) + void testChatWithImageFromFileWithHistoryRecognition() { + testEndpointReachability(); + try { + OllamaChatRequestBuilder builder = + OllamaChatRequestBuilder.getInstance(config.getImageModel()); + OllamaChatRequestModel requestModel = + builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?", + List.of(getImageFileFromClasspath("dog-on-a-boat.jpg"))).build(); - OllamaResult result = ollamaAPI.generateWithImageFiles(config.getImageModel(), - "What is in this image?", List.of(imageFile), new OptionsBuilder().build(), (s) -> { - LOG.info(s); - String substring = s.substring(sb.toString().length(), s.length()); - LOG.info(substring); - sb.append(substring); - }); - assertNotNull(result); - assertNotNull(result.getResponse()); - assertFalse(result.getResponse().isEmpty()); - assertEquals(sb.toString().trim(), result.getResponse().trim()); - } catch (IOException | OllamaBaseException | InterruptedException e) { - fail(e); + OllamaChatResult chatResult = ollamaAPI.chat(requestModel); + assertNotNull(chatResult); + assertNotNull(chatResult.getResponse()); + + builder.reset(); + + requestModel = + builder.withMessages(chatResult.getChatHistory()) + .withMessage(OllamaChatMessageRole.USER, "What's the dogs breed?").build(); + + chatResult = ollamaAPI.chat(requestModel); + assertNotNull(chatResult); + assertNotNull(chatResult.getResponse()); + + + } catch (IOException | OllamaBaseException | InterruptedException e) { + fail(e); + } } - } - @Test - @Order(3) - void testAskModelWithOptionsAndImageURLs() { - testEndpointReachability(); - try { - OllamaResult result = - ollamaAPI.generateWithImageURLs( - config.getImageModel(), - "What is in this image?", - List.of( - "https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg"), - new OptionsBuilder().build()); - assertNotNull(result); - assertNotNull(result.getResponse()); - assertFalse(result.getResponse().isEmpty()); - } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { - fail(e); + @Test + @Order(3) + void testChatWithImageFromURL() { + testEndpointReachability(); + try { + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getImageModel()); + OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?", + "https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg") + .build(); + + OllamaChatResult chatResult = ollamaAPI.chat(requestModel); + assertNotNull(chatResult); + } catch (IOException | OllamaBaseException | InterruptedException e) { + fail(e); + } } - } - @Test - @Order(3) - public void testEmbedding() { - testEndpointReachability(); - try { - OllamaEmbeddingsRequestModel request = OllamaEmbeddingsRequestBuilder - .getInstance(config.getModel(), "What is the capital of France?").build(); - - List embeddings = ollamaAPI.generateEmbeddings(request); - - assertNotNull(embeddings); - assertFalse(embeddings.isEmpty()); - } catch (IOException | OllamaBaseException | InterruptedException e) { - fail(e); + @Test + @Order(3) + void testAskModelWithOptionsAndImageFiles() { + testEndpointReachability(); + File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg"); + try { + OllamaResult result = + ollamaAPI.generateWithImageFiles( + config.getImageModel(), + "What is in this image?", + List.of(imageFile), + new OptionsBuilder().build()); + assertNotNull(result); + assertNotNull(result.getResponse()); + assertFalse(result.getResponse().isEmpty()); + } catch (IOException | OllamaBaseException | InterruptedException e) { + fail(e); + } + } + + @Test + @Order(3) + void testAskModelWithOptionsAndImageFilesStreamed() { + testEndpointReachability(); + File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg"); + try { + StringBuffer sb = new StringBuffer(""); + + OllamaResult result = ollamaAPI.generateWithImageFiles(config.getImageModel(), + "What is in this image?", List.of(imageFile), new OptionsBuilder().build(), (s) -> { + LOG.info(s); + String substring = s.substring(sb.toString().length(), s.length()); + LOG.info(substring); + sb.append(substring); + }); + assertNotNull(result); + assertNotNull(result.getResponse()); + assertFalse(result.getResponse().isEmpty()); + assertEquals(sb.toString().trim(), result.getResponse().trim()); + } catch (IOException | OllamaBaseException | InterruptedException e) { + fail(e); + } + } + + @Test + @Order(3) + void testAskModelWithOptionsAndImageURLs() { + testEndpointReachability(); + try { + OllamaResult result = + ollamaAPI.generateWithImageURLs( + config.getImageModel(), + "What is in this image?", + List.of( + "https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg"), + new OptionsBuilder().build()); + assertNotNull(result); + assertNotNull(result.getResponse()); + assertFalse(result.getResponse().isEmpty()); + } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { + fail(e); + } + } + + @Test + @Order(3) + public void testEmbedding() { + testEndpointReachability(); + try { + OllamaEmbeddingsRequestModel request = OllamaEmbeddingsRequestBuilder + .getInstance(config.getModel(), "What is the capital of France?").build(); + + List embeddings = ollamaAPI.generateEmbeddings(request); + + assertNotNull(embeddings); + assertFalse(embeddings.isEmpty()); + } catch (IOException | OllamaBaseException | InterruptedException e) { + fail(e); + } } - } } @Data class Config { - private String ollamaURL; - private String model; - private String imageModel; - private int requestTimeoutSeconds; + private String ollamaURL; + private String model; + private String imageModel; + private int requestTimeoutSeconds; - public Config() { - Properties properties = new Properties(); - try (InputStream input = - getClass().getClassLoader().getResourceAsStream("test-config.properties")) { - if (input == null) { - throw new RuntimeException("Sorry, unable to find test-config.properties"); - } - properties.load(input); - this.ollamaURL = properties.getProperty("ollama.url"); - this.model = properties.getProperty("ollama.model"); - this.imageModel = properties.getProperty("ollama.model.image"); - this.requestTimeoutSeconds = - Integer.parseInt(properties.getProperty("ollama.request-timeout-seconds")); - } catch (IOException e) { - throw new RuntimeException("Error loading properties", e); + public Config() { + Properties properties = new Properties(); + try (InputStream input = + getClass().getClassLoader().getResourceAsStream("test-config.properties")) { + if (input == null) { + throw new RuntimeException("Sorry, unable to find test-config.properties"); + } + properties.load(input); + this.ollamaURL = properties.getProperty("ollama.url"); + this.model = properties.getProperty("ollama.model"); + this.imageModel = properties.getProperty("ollama.model.image"); + this.requestTimeoutSeconds = + Integer.parseInt(properties.getProperty("ollama.request-timeout-seconds")); + } catch (IOException e) { + throw new RuntimeException("Error loading properties", e); + } } - } } diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/TestMockedAPIs.java b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/TestMockedAPIs.java index 879c67c..c5d60e1 100644 --- a/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/TestMockedAPIs.java +++ b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/TestMockedAPIs.java @@ -1,7 +1,5 @@ package io.github.amithkoujalgi.ollama4j.unittests; -import static org.mockito.Mockito.*; - import io.github.amithkoujalgi.ollama4j.core.OllamaAPI; import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; import io.github.amithkoujalgi.ollama4j.core.models.ModelDetail; @@ -9,155 +7,158 @@ import io.github.amithkoujalgi.ollama4j.core.models.OllamaAsyncResultCallback; import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult; import io.github.amithkoujalgi.ollama4j.core.types.OllamaModelType; import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + import java.io.IOException; import java.net.URISyntaxException; import java.util.ArrayList; import java.util.Collections; -import org.junit.jupiter.api.Test; -import org.mockito.Mockito; + +import static org.mockito.Mockito.*; class TestMockedAPIs { - @Test - void testPullModel() { - OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); - String model = OllamaModelType.LLAMA2; - try { - doNothing().when(ollamaAPI).pullModel(model); - ollamaAPI.pullModel(model); - verify(ollamaAPI, times(1)).pullModel(model); - } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { - throw new RuntimeException(e); + @Test + void testPullModel() { + OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); + String model = OllamaModelType.LLAMA2; + try { + doNothing().when(ollamaAPI).pullModel(model); + ollamaAPI.pullModel(model); + verify(ollamaAPI, times(1)).pullModel(model); + } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { + throw new RuntimeException(e); + } } - } - @Test - void testListModels() { - OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); - try { - when(ollamaAPI.listModels()).thenReturn(new ArrayList<>()); - ollamaAPI.listModels(); - verify(ollamaAPI, times(1)).listModels(); - } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { - throw new RuntimeException(e); + @Test + void testListModels() { + OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); + try { + when(ollamaAPI.listModels()).thenReturn(new ArrayList<>()); + ollamaAPI.listModels(); + verify(ollamaAPI, times(1)).listModels(); + } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { + throw new RuntimeException(e); + } } - } - @Test - void testCreateModel() { - OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); - String model = OllamaModelType.LLAMA2; - String modelFilePath = "FROM llama2\nSYSTEM You are mario from Super Mario Bros."; - try { - doNothing().when(ollamaAPI).createModelWithModelFileContents(model, modelFilePath); - ollamaAPI.createModelWithModelFileContents(model, modelFilePath); - verify(ollamaAPI, times(1)).createModelWithModelFileContents(model, modelFilePath); - } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { - throw new RuntimeException(e); + @Test + void testCreateModel() { + OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); + String model = OllamaModelType.LLAMA2; + String modelFilePath = "FROM llama2\nSYSTEM You are mario from Super Mario Bros."; + try { + doNothing().when(ollamaAPI).createModelWithModelFileContents(model, modelFilePath); + ollamaAPI.createModelWithModelFileContents(model, modelFilePath); + verify(ollamaAPI, times(1)).createModelWithModelFileContents(model, modelFilePath); + } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { + throw new RuntimeException(e); + } } - } - @Test - void testDeleteModel() { - OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); - String model = OllamaModelType.LLAMA2; - try { - doNothing().when(ollamaAPI).deleteModel(model, true); - ollamaAPI.deleteModel(model, true); - verify(ollamaAPI, times(1)).deleteModel(model, true); - } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { - throw new RuntimeException(e); + @Test + void testDeleteModel() { + OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); + String model = OllamaModelType.LLAMA2; + try { + doNothing().when(ollamaAPI).deleteModel(model, true); + ollamaAPI.deleteModel(model, true); + verify(ollamaAPI, times(1)).deleteModel(model, true); + } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { + throw new RuntimeException(e); + } } - } - @Test - void testGetModelDetails() { - OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); - String model = OllamaModelType.LLAMA2; - try { - when(ollamaAPI.getModelDetails(model)).thenReturn(new ModelDetail()); - ollamaAPI.getModelDetails(model); - verify(ollamaAPI, times(1)).getModelDetails(model); - } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { - throw new RuntimeException(e); + @Test + void testGetModelDetails() { + OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); + String model = OllamaModelType.LLAMA2; + try { + when(ollamaAPI.getModelDetails(model)).thenReturn(new ModelDetail()); + ollamaAPI.getModelDetails(model); + verify(ollamaAPI, times(1)).getModelDetails(model); + } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { + throw new RuntimeException(e); + } } - } - @Test - void testGenerateEmbeddings() { - OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); - String model = OllamaModelType.LLAMA2; - String prompt = "some prompt text"; - try { - when(ollamaAPI.generateEmbeddings(model, prompt)).thenReturn(new ArrayList<>()); - ollamaAPI.generateEmbeddings(model, prompt); - verify(ollamaAPI, times(1)).generateEmbeddings(model, prompt); - } catch (IOException | OllamaBaseException | InterruptedException e) { - throw new RuntimeException(e); + @Test + void testGenerateEmbeddings() { + OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); + String model = OllamaModelType.LLAMA2; + String prompt = "some prompt text"; + try { + when(ollamaAPI.generateEmbeddings(model, prompt)).thenReturn(new ArrayList<>()); + ollamaAPI.generateEmbeddings(model, prompt); + verify(ollamaAPI, times(1)).generateEmbeddings(model, prompt); + } catch (IOException | OllamaBaseException | InterruptedException e) { + throw new RuntimeException(e); + } } - } - @Test - void testAsk() { - OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); - String model = OllamaModelType.LLAMA2; - String prompt = "some prompt text"; - OptionsBuilder optionsBuilder = new OptionsBuilder(); - try { - when(ollamaAPI.generate(model, prompt, optionsBuilder.build())) - .thenReturn(new OllamaResult("", 0, 200)); - ollamaAPI.generate(model, prompt, optionsBuilder.build()); - verify(ollamaAPI, times(1)).generate(model, prompt, optionsBuilder.build()); - } catch (IOException | OllamaBaseException | InterruptedException e) { - throw new RuntimeException(e); + @Test + void testAsk() { + OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); + String model = OllamaModelType.LLAMA2; + String prompt = "some prompt text"; + OptionsBuilder optionsBuilder = new OptionsBuilder(); + try { + when(ollamaAPI.generate(model, prompt, false, optionsBuilder.build())) + .thenReturn(new OllamaResult("", 0, 200)); + ollamaAPI.generate(model, prompt, false, optionsBuilder.build()); + verify(ollamaAPI, times(1)).generate(model, prompt, false, optionsBuilder.build()); + } catch (IOException | OllamaBaseException | InterruptedException e) { + throw new RuntimeException(e); + } } - } - @Test - void testAskWithImageFiles() { - OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); - String model = OllamaModelType.LLAMA2; - String prompt = "some prompt text"; - try { - when(ollamaAPI.generateWithImageFiles( - model, prompt, Collections.emptyList(), new OptionsBuilder().build())) - .thenReturn(new OllamaResult("", 0, 200)); - ollamaAPI.generateWithImageFiles( - model, prompt, Collections.emptyList(), new OptionsBuilder().build()); - verify(ollamaAPI, times(1)) - .generateWithImageFiles( - model, prompt, Collections.emptyList(), new OptionsBuilder().build()); - } catch (IOException | OllamaBaseException | InterruptedException e) { - throw new RuntimeException(e); + @Test + void testAskWithImageFiles() { + OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); + String model = OllamaModelType.LLAMA2; + String prompt = "some prompt text"; + try { + when(ollamaAPI.generateWithImageFiles( + model, prompt, Collections.emptyList(), new OptionsBuilder().build())) + .thenReturn(new OllamaResult("", 0, 200)); + ollamaAPI.generateWithImageFiles( + model, prompt, Collections.emptyList(), new OptionsBuilder().build()); + verify(ollamaAPI, times(1)) + .generateWithImageFiles( + model, prompt, Collections.emptyList(), new OptionsBuilder().build()); + } catch (IOException | OllamaBaseException | InterruptedException e) { + throw new RuntimeException(e); + } } - } - @Test - void testAskWithImageURLs() { - OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); - String model = OllamaModelType.LLAMA2; - String prompt = "some prompt text"; - try { - when(ollamaAPI.generateWithImageURLs( - model, prompt, Collections.emptyList(), new OptionsBuilder().build())) - .thenReturn(new OllamaResult("", 0, 200)); - ollamaAPI.generateWithImageURLs( - model, prompt, Collections.emptyList(), new OptionsBuilder().build()); - verify(ollamaAPI, times(1)) - .generateWithImageURLs( - model, prompt, Collections.emptyList(), new OptionsBuilder().build()); - } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { - throw new RuntimeException(e); + @Test + void testAskWithImageURLs() { + OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); + String model = OllamaModelType.LLAMA2; + String prompt = "some prompt text"; + try { + when(ollamaAPI.generateWithImageURLs( + model, prompt, Collections.emptyList(), new OptionsBuilder().build())) + .thenReturn(new OllamaResult("", 0, 200)); + ollamaAPI.generateWithImageURLs( + model, prompt, Collections.emptyList(), new OptionsBuilder().build()); + verify(ollamaAPI, times(1)) + .generateWithImageURLs( + model, prompt, Collections.emptyList(), new OptionsBuilder().build()); + } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { + throw new RuntimeException(e); + } } - } - @Test - void testAskAsync() { - OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); - String model = OllamaModelType.LLAMA2; - String prompt = "some prompt text"; - when(ollamaAPI.generateAsync(model, prompt)) - .thenReturn(new OllamaAsyncResultCallback(null, null, 3)); - ollamaAPI.generateAsync(model, prompt); - verify(ollamaAPI, times(1)).generateAsync(model, prompt); - } + @Test + void testAskAsync() { + OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); + String model = OllamaModelType.LLAMA2; + String prompt = "some prompt text"; + when(ollamaAPI.generateAsync(model, prompt, false)) + .thenReturn(new OllamaAsyncResultCallback(null, null, 3)); + ollamaAPI.generateAsync(model, prompt, false); + verify(ollamaAPI, times(1)).generateAsync(model, prompt, false); + } }