arguments` to the tool functions, we could support passing
+specific args separately with their data types. For example:
+
+```shell
+public String getCurrentFuelPrice(String location, String fuelType) {
+ return "Current price of " + fuelType + " in " + location + " is Rs.103/L";
+}
+```
+
+Updating async/chat APIs with support for tool-based generation.
\ No newline at end of file
diff --git a/docs/docs/apis-generate/prompt-builder.md b/docs/docs/apis-generate/prompt-builder.md
index a798808..ffe57d7 100644
--- a/docs/docs/apis-generate/prompt-builder.md
+++ b/docs/docs/apis-generate/prompt-builder.md
@@ -1,5 +1,5 @@
---
-sidebar_position: 5
+sidebar_position: 6
---
# Prompt Builder
diff --git a/pom.xml b/pom.xml
index 3e73414..2892ed2 100644
--- a/pom.xml
+++ b/pom.xml
@@ -4,7 +4,7 @@
io.github.amithkoujalgi
ollama4j
- 1.0.74-SNAPSHOT
+ 1.0.78-SNAPSHOT
Ollama4j
Java library for interacting with Ollama API.
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java
index 1f22210..80654ae 100644
--- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java
@@ -10,6 +10,7 @@ import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingRe
import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.request.*;
+import io.github.amithkoujalgi.ollama4j.core.tools.*;
import io.github.amithkoujalgi.ollama4j.core.utils.Options;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
import org.slf4j.Logger;
@@ -25,9 +26,7 @@ import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.time.Duration;
-import java.util.ArrayList;
-import java.util.Base64;
-import java.util.List;
+import java.util.*;
/**
* The base Ollama API class.
@@ -339,6 +338,7 @@ public class OllamaAPI {
}
}
+
/**
* Generate response for a question to a model running on Ollama server. This is a sync/blocking
* call.
@@ -351,9 +351,10 @@ public class OllamaAPI {
* @param streamHandler optional callback consumer that will be applied every time a streamed response is received. If not set, the stream parameter of the request is set to false.
* @return OllamaResult that includes response text and time taken for response
*/
- public OllamaResult generate(String model, String prompt, Options options, OllamaStreamHandler streamHandler)
+ public OllamaResult generate(String model, String prompt, boolean raw, Options options, OllamaStreamHandler streamHandler)
throws OllamaBaseException, IOException, InterruptedException {
OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt);
+ ollamaRequestModel.setRaw(raw);
ollamaRequestModel.setOptions(options.getOptionsMap());
return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler);
}
@@ -361,13 +362,37 @@ public class OllamaAPI {
/**
* Convenience method to call Ollama API without streaming responses.
*
- * Uses {@link #generate(String, String, Options, OllamaStreamHandler)}
+ * Uses {@link #generate(String, String, boolean, Options, OllamaStreamHandler)}
+ *
+ * @param model Model to use
+ * @param prompt Prompt text
+ * @param raw In some cases, you may wish to bypass the templating system and provide a full prompt. In this case, you can use the raw parameter to disable templating. Also note that raw mode will not return a context.
+ * @param options Additional Options
+ * @return OllamaResult
*/
- public OllamaResult generate(String model, String prompt, Options options)
+ public OllamaResult generate(String model, String prompt, boolean raw, Options options)
throws OllamaBaseException, IOException, InterruptedException {
- return generate(model, prompt, options, null);
+ return generate(model, prompt, raw, options, null);
}
+
+ public OllamaToolsResult generateWithTools(String model, String prompt, boolean raw, Options options)
+ throws OllamaBaseException, IOException, InterruptedException {
+ OllamaToolsResult toolResult = new OllamaToolsResult();
+ Map toolResults = new HashMap<>();
+
+ OllamaResult result = generate(model, prompt, raw, options, null);
+ toolResult.setModelResult(result);
+
+ List toolDefs = Utils.getObjectMapper().readValue(result.getResponse(), Utils.getObjectMapper().getTypeFactory().constructCollectionType(List.class, ToolDef.class));
+ for (ToolDef toolDef : toolDefs) {
+ toolResults.put(toolDef, invokeTool(toolDef));
+ }
+ toolResult.setToolResults(toolResults);
+ return toolResult;
+ }
+
+
/**
* Generate response for a question to a model running on Ollama server and get a callback handle
* that can be used to check for status and get the response from the model later. This would be
@@ -377,9 +402,9 @@ public class OllamaAPI {
* @param prompt the prompt/question text
* @return the ollama async result callback handle
*/
- public OllamaAsyncResultCallback generateAsync(String model, String prompt) {
+ public OllamaAsyncResultCallback generateAsync(String model, String prompt, boolean raw) {
OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt);
-
+ ollamaRequestModel.setRaw(raw);
URI uri = URI.create(this.host + "/api/generate");
OllamaAsyncResultCallback ollamaAsyncResultCallback =
new OllamaAsyncResultCallback(
@@ -576,4 +601,24 @@ public class OllamaAPI {
private boolean isBasicAuthCredentialsSet() {
return basicAuth != null;
}
+
+
+ public void registerTool(MistralTools.ToolSpecification toolSpecification) {
+ ToolRegistry.addFunction(toolSpecification.getFunctionName(), toolSpecification.getToolDefinition());
+ }
+
+ private Object invokeTool(ToolDef toolDef) {
+ try {
+ String methodName = toolDef.getName();
+ Map arguments = toolDef.getArguments();
+ DynamicFunction function = ToolRegistry.getFunction(methodName);
+ if (function == null) {
+ throw new IllegalArgumentException("No such tool: " + methodName);
+ }
+ return function.apply(arguments);
+ } catch (Exception e) {
+ e.printStackTrace();
+ return "Error calling tool: " + e.getMessage();
+ }
+ }
}
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java
index fe7fbec..d3d71e4 100644
--- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java
@@ -1,9 +1,5 @@
package io.github.amithkoujalgi.ollama4j.core.models.request;
-import java.io.IOException;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
import com.fasterxml.jackson.core.JsonProcessingException;
import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
@@ -13,15 +9,19 @@ import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRespo
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateStreamObserver;
import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
-public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{
+import java.io.IOException;
+
+public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
private static final Logger LOG = LoggerFactory.getLogger(OllamaGenerateEndpointCaller.class);
private OllamaGenerateStreamObserver streamObserver;
public OllamaGenerateEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) {
- super(host, basicAuth, requestTimeoutSeconds, verbose);
+ super(host, basicAuth, requestTimeoutSeconds, verbose);
}
@Override
@@ -31,24 +31,22 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{
@Override
protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) {
- try {
- OllamaGenerateResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class);
- responseBuffer.append(ollamaResponseModel.getResponse());
- if(streamObserver != null) {
- streamObserver.notify(ollamaResponseModel);
- }
- return ollamaResponseModel.isDone();
- } catch (JsonProcessingException e) {
- LOG.error("Error parsing the Ollama chat response!",e);
- return true;
- }
+ try {
+ OllamaGenerateResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class);
+ responseBuffer.append(ollamaResponseModel.getResponse());
+ if (streamObserver != null) {
+ streamObserver.notify(ollamaResponseModel);
+ }
+ return ollamaResponseModel.isDone();
+ } catch (JsonProcessingException e) {
+ LOG.error("Error parsing the Ollama chat response!", e);
+ return true;
+ }
}
public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler)
- throws OllamaBaseException, IOException, InterruptedException {
- streamObserver = new OllamaGenerateStreamObserver(streamHandler);
- return super.callSync(body);
+ throws OllamaBaseException, IOException, InterruptedException {
+ streamObserver = new OllamaGenerateStreamObserver(streamHandler);
+ return super.callSync(body);
}
-
-
}
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/DynamicFunction.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/DynamicFunction.java
new file mode 100644
index 0000000..5b8f5e6
--- /dev/null
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/DynamicFunction.java
@@ -0,0 +1,8 @@
+package io.github.amithkoujalgi.ollama4j.core.tools;
+
+import java.util.Map;
+
+@FunctionalInterface
+public interface DynamicFunction {
+ Object apply(Map arguments);
+}
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/MistralTools.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/MistralTools.java
new file mode 100644
index 0000000..fff8071
--- /dev/null
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/MistralTools.java
@@ -0,0 +1,139 @@
+package io.github.amithkoujalgi.ollama4j.core.tools;
+
+import com.fasterxml.jackson.annotation.JsonIgnore;
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.core.JsonProcessingException;
+import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
+import lombok.Builder;
+import lombok.Data;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+public class MistralTools {
+ @Data
+ @Builder
+ public static class ToolSpecification {
+ private String functionName;
+ private String functionDesc;
+ private Map props;
+ private DynamicFunction toolDefinition;
+ }
+
+ @Data
+ @JsonIgnoreProperties(ignoreUnknown = true)
+ public static class PromptFuncDefinition {
+ private String type;
+ private PromptFuncSpec function;
+
+ @Data
+ public static class PromptFuncSpec {
+ private String name;
+ private String description;
+ private Parameters parameters;
+ }
+
+ @Data
+ public static class Parameters {
+ private String type;
+ private Map properties;
+ private List required;
+ }
+
+ @Data
+ @Builder
+ public static class Property {
+ private String type;
+ private String description;
+ @JsonProperty("enum")
+ @JsonInclude(JsonInclude.Include.NON_NULL)
+ private List enumValues;
+ @JsonIgnore
+ private boolean required;
+ }
+ }
+
+ public static class PropsBuilder {
+ private final Map props = new HashMap<>();
+
+ public PropsBuilder withProperty(String key, PromptFuncDefinition.Property property) {
+ props.put(key, property);
+ return this;
+ }
+
+ public Map build() {
+ return props;
+ }
+ }
+
+ public static class PromptBuilder {
+ private final List tools = new ArrayList<>();
+
+ private String promptText;
+
+ public String build() throws JsonProcessingException {
+ return "[AVAILABLE_TOOLS] " + Utils.getObjectMapper().writeValueAsString(tools) + "[/AVAILABLE_TOOLS][INST] " + promptText + " [/INST]";
+ }
+
+ public PromptBuilder withPrompt(String prompt) throws JsonProcessingException {
+ promptText = prompt;
+ return this;
+ }
+
+ public PromptBuilder withToolSpecification(ToolSpecification spec) {
+ PromptFuncDefinition def = new PromptFuncDefinition();
+ def.setType("function");
+
+ PromptFuncDefinition.PromptFuncSpec functionDetail = new PromptFuncDefinition.PromptFuncSpec();
+ functionDetail.setName(spec.getFunctionName());
+ functionDetail.setDescription(spec.getFunctionDesc());
+
+ PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters();
+ parameters.setType("object");
+ parameters.setProperties(spec.getProps());
+
+ List requiredValues = new ArrayList<>();
+ for (Map.Entry p : spec.getProps().entrySet()) {
+ if (p.getValue().isRequired()) {
+ requiredValues.add(p.getKey());
+ }
+ }
+ parameters.setRequired(requiredValues);
+ functionDetail.setParameters(parameters);
+ def.setFunction(functionDetail);
+
+ tools.add(def);
+ return this;
+ }
+//
+// public PromptBuilder withToolSpecification(String functionName, String functionDesc, Map props) {
+// PromptFuncDefinition def = new PromptFuncDefinition();
+// def.setType("function");
+//
+// PromptFuncDefinition.PromptFuncSpec functionDetail = new PromptFuncDefinition.PromptFuncSpec();
+// functionDetail.setName(functionName);
+// functionDetail.setDescription(functionDesc);
+//
+// PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters();
+// parameters.setType("object");
+// parameters.setProperties(props);
+//
+// List requiredValues = new ArrayList<>();
+// for (Map.Entry p : props.entrySet()) {
+// if (p.getValue().isRequired()) {
+// requiredValues.add(p.getKey());
+// }
+// }
+// parameters.setRequired(requiredValues);
+// functionDetail.setParameters(parameters);
+// def.setFunction(functionDetail);
+//
+// tools.add(def);
+// return this;
+// }
+ }
+}
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/OllamaToolsResult.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/OllamaToolsResult.java
new file mode 100644
index 0000000..65ef3ac
--- /dev/null
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/OllamaToolsResult.java
@@ -0,0 +1,16 @@
+package io.github.amithkoujalgi.ollama4j.core.tools;
+
+import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
+import lombok.AllArgsConstructor;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+
+import java.util.Map;
+
+@Data
+@NoArgsConstructor
+@AllArgsConstructor
+public class OllamaToolsResult {
+ private OllamaResult modelResult;
+ private Map toolResults;
+}
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/ToolDef.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/ToolDef.java
new file mode 100644
index 0000000..751d186
--- /dev/null
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/ToolDef.java
@@ -0,0 +1,18 @@
+package io.github.amithkoujalgi.ollama4j.core.tools;
+
+import lombok.AllArgsConstructor;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+
+import java.util.Map;
+
+@Data
+@AllArgsConstructor
+@NoArgsConstructor
+public class ToolDef {
+
+ private String name;
+ private Map arguments;
+
+}
+
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/ToolRegistry.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/ToolRegistry.java
new file mode 100644
index 0000000..0004c7f
--- /dev/null
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/ToolRegistry.java
@@ -0,0 +1,17 @@
+package io.github.amithkoujalgi.ollama4j.core.tools;
+
+import java.util.HashMap;
+import java.util.Map;
+
+public class ToolRegistry {
+ private static final Map functionMap = new HashMap<>();
+
+
+ public static DynamicFunction getFunction(String name) {
+ return functionMap.get(name);
+ }
+
+ public static void addFunction(String name, DynamicFunction function) {
+ functionMap.put(name, function);
+ }
+}
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/types/OllamaModelType.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/types/OllamaModelType.java
index 5733979..2613d86 100644
--- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/types/OllamaModelType.java
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/types/OllamaModelType.java
@@ -9,6 +9,9 @@ package io.github.amithkoujalgi.ollama4j.core.types;
@SuppressWarnings("ALL")
public class OllamaModelType {
public static final String GEMMA = "gemma";
+ public static final String GEMMA2 = "gemma2";
+
+
public static final String LLAMA2 = "llama2";
public static final String LLAMA3 = "llama3";
public static final String MISTRAL = "mistral";
@@ -30,6 +33,8 @@ public class OllamaModelType {
public static final String ZEPHYR = "zephyr";
public static final String OPENHERMES = "openhermes";
public static final String QWEN = "qwen";
+
+ public static final String QWEN2 = "qwen2";
public static final String WIZARDCODER = "wizardcoder";
public static final String LLAMA2_CHINESE = "llama2-chinese";
public static final String TINYLLAMA = "tinyllama";
diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java b/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java
index d822077..58e55a1 100644
--- a/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java
+++ b/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java
@@ -1,7 +1,5 @@
package io.github.amithkoujalgi.ollama4j.integrationtests;
-import static org.junit.jupiter.api.Assertions.*;
-
import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.ModelDetail;
@@ -10,9 +8,16 @@ import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessageRole;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResult;
-import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestBuilder;
+import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestModel;
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
+import lombok.Data;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Order;
+import org.junit.jupiter.api.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
@@ -22,372 +27,369 @@ import java.net.http.HttpConnectTimeoutException;
import java.util.List;
import java.util.Objects;
import java.util.Properties;
-import lombok.Data;
-import org.junit.jupiter.api.BeforeEach;
-import org.junit.jupiter.api.Order;
-import org.junit.jupiter.api.Test;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
+
+import static org.junit.jupiter.api.Assertions.*;
class TestRealAPIs {
- private static final Logger LOG = LoggerFactory.getLogger(TestRealAPIs.class);
+ private static final Logger LOG = LoggerFactory.getLogger(TestRealAPIs.class);
- OllamaAPI ollamaAPI;
- Config config;
+ OllamaAPI ollamaAPI;
+ Config config;
- private File getImageFileFromClasspath(String fileName) {
- ClassLoader classLoader = getClass().getClassLoader();
- return new File(Objects.requireNonNull(classLoader.getResource(fileName)).getFile());
- }
-
- @BeforeEach
- void setUp() {
- config = new Config();
- ollamaAPI = new OllamaAPI(config.getOllamaURL());
- ollamaAPI.setRequestTimeoutSeconds(config.getRequestTimeoutSeconds());
- }
-
- @Test
- @Order(1)
- void testWrongEndpoint() {
- OllamaAPI ollamaAPI = new OllamaAPI("http://wrong-host:11434");
- assertThrows(ConnectException.class, ollamaAPI::listModels);
- }
-
- @Test
- @Order(1)
- void testEndpointReachability() {
- try {
- assertNotNull(ollamaAPI.listModels());
- } catch (HttpConnectTimeoutException e) {
- fail(e.getMessage());
- } catch (Exception e) {
- fail(e);
+ private File getImageFileFromClasspath(String fileName) {
+ ClassLoader classLoader = getClass().getClassLoader();
+ return new File(Objects.requireNonNull(classLoader.getResource(fileName)).getFile());
}
- }
- @Test
- @Order(2)
- void testListModels() {
- testEndpointReachability();
- try {
- assertNotNull(ollamaAPI.listModels());
- ollamaAPI.listModels().forEach(System.out::println);
- } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
- fail(e);
+ @BeforeEach
+ void setUp() {
+ config = new Config();
+ ollamaAPI = new OllamaAPI(config.getOllamaURL());
+ ollamaAPI.setRequestTimeoutSeconds(config.getRequestTimeoutSeconds());
}
- }
- @Test
- @Order(2)
- void testPullModel() {
- testEndpointReachability();
- try {
- ollamaAPI.pullModel(config.getModel());
- boolean found =
- ollamaAPI.listModels().stream()
- .anyMatch(model -> model.getModel().equalsIgnoreCase(config.getModel()));
- assertTrue(found);
- } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
- fail(e);
+ @Test
+ @Order(1)
+ void testWrongEndpoint() {
+ OllamaAPI ollamaAPI = new OllamaAPI("http://wrong-host:11434");
+ assertThrows(ConnectException.class, ollamaAPI::listModels);
}
- }
- @Test
- @Order(3)
- void testListDtails() {
- testEndpointReachability();
- try {
- ModelDetail modelDetails = ollamaAPI.getModelDetails(config.getModel());
- assertNotNull(modelDetails);
- System.out.println(modelDetails);
- } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
- fail(e);
+ @Test
+ @Order(1)
+ void testEndpointReachability() {
+ try {
+ assertNotNull(ollamaAPI.listModels());
+ } catch (HttpConnectTimeoutException e) {
+ fail(e.getMessage());
+ } catch (Exception e) {
+ fail(e);
+ }
}
- }
- @Test
- @Order(3)
- void testAskModelWithDefaultOptions() {
- testEndpointReachability();
- try {
- OllamaResult result =
- ollamaAPI.generate(
- config.getModel(),
- "What is the capital of France? And what's France's connection with Mona Lisa?",
- new OptionsBuilder().build());
- assertNotNull(result);
- assertNotNull(result.getResponse());
- assertFalse(result.getResponse().isEmpty());
- } catch (IOException | OllamaBaseException | InterruptedException e) {
- fail(e);
+ @Test
+ @Order(2)
+ void testListModels() {
+ testEndpointReachability();
+ try {
+ assertNotNull(ollamaAPI.listModels());
+ ollamaAPI.listModels().forEach(System.out::println);
+ } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
+ fail(e);
+ }
}
- }
- @Test
- @Order(3)
- void testAskModelWithDefaultOptionsStreamed() {
- testEndpointReachability();
- try {
-
- StringBuffer sb = new StringBuffer("");
-
- OllamaResult result = ollamaAPI.generate(config.getModel(),
- "What is the capital of France? And what's France's connection with Mona Lisa?",
- new OptionsBuilder().build(), (s) -> {
- LOG.info(s);
- String substring = s.substring(sb.toString().length(), s.length());
- LOG.info(substring);
- sb.append(substring);
- });
-
- assertNotNull(result);
- assertNotNull(result.getResponse());
- assertFalse(result.getResponse().isEmpty());
- assertEquals(sb.toString().trim(), result.getResponse().trim());
- } catch (IOException | OllamaBaseException | InterruptedException e) {
- fail(e);
+ @Test
+ @Order(2)
+ void testPullModel() {
+ testEndpointReachability();
+ try {
+ ollamaAPI.pullModel(config.getModel());
+ boolean found =
+ ollamaAPI.listModels().stream()
+ .anyMatch(model -> model.getModel().equalsIgnoreCase(config.getModel()));
+ assertTrue(found);
+ } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
+ fail(e);
+ }
}
- }
- @Test
- @Order(3)
- void testAskModelWithOptions() {
- testEndpointReachability();
- try {
- OllamaResult result =
- ollamaAPI.generate(
- config.getModel(),
- "What is the capital of France? And what's France's connection with Mona Lisa?",
- new OptionsBuilder().setTemperature(0.9f).build());
- assertNotNull(result);
- assertNotNull(result.getResponse());
- assertFalse(result.getResponse().isEmpty());
- } catch (IOException | OllamaBaseException | InterruptedException e) {
- fail(e);
+ @Test
+ @Order(3)
+ void testListDtails() {
+ testEndpointReachability();
+ try {
+ ModelDetail modelDetails = ollamaAPI.getModelDetails(config.getModel());
+ assertNotNull(modelDetails);
+ System.out.println(modelDetails);
+ } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
+ fail(e);
+ }
}
- }
- @Test
- @Order(3)
- void testChat() {
- testEndpointReachability();
- try {
- OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
- OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What is the capital of France?")
- .withMessage(OllamaChatMessageRole.ASSISTANT, "Should be Paris!")
- .withMessage(OllamaChatMessageRole.USER,"And what is the second larges city?")
- .build();
-
- OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
- assertNotNull(chatResult);
- assertFalse(chatResult.getResponse().isBlank());
- assertEquals(4,chatResult.getChatHistory().size());
- } catch (IOException | OllamaBaseException | InterruptedException e) {
- fail(e);
+ @Test
+ @Order(3)
+ void testAskModelWithDefaultOptions() {
+ testEndpointReachability();
+ try {
+ OllamaResult result =
+ ollamaAPI.generate(
+ config.getModel(),
+ "What is the capital of France? And what's France's connection with Mona Lisa?",
+ false,
+ new OptionsBuilder().build());
+ assertNotNull(result);
+ assertNotNull(result.getResponse());
+ assertFalse(result.getResponse().isEmpty());
+ } catch (IOException | OllamaBaseException | InterruptedException e) {
+ fail(e);
+ }
}
- }
- @Test
- @Order(3)
- void testChatWithSystemPrompt() {
- testEndpointReachability();
- try {
- OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
- OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM,
- "You are a silent bot that only says 'NI'. Do not say anything else under any circumstances!")
- .withMessage(OllamaChatMessageRole.USER,
- "What is the capital of France? And what's France's connection with Mona Lisa?")
- .build();
+ @Test
+ @Order(3)
+ void testAskModelWithDefaultOptionsStreamed() {
+ testEndpointReachability();
+ try {
+ StringBuffer sb = new StringBuffer("");
+ OllamaResult result = ollamaAPI.generate(config.getModel(),
+ "What is the capital of France? And what's France's connection with Mona Lisa?",
+ false,
+ new OptionsBuilder().build(), (s) -> {
+ LOG.info(s);
+ String substring = s.substring(sb.toString().length(), s.length());
+ LOG.info(substring);
+ sb.append(substring);
+ });
- OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
- assertNotNull(chatResult);
- assertFalse(chatResult.getResponse().isBlank());
- assertTrue(chatResult.getResponse().startsWith("NI"));
- assertEquals(3, chatResult.getChatHistory().size());
- } catch (IOException | OllamaBaseException | InterruptedException e) {
- fail(e);
+ assertNotNull(result);
+ assertNotNull(result.getResponse());
+ assertFalse(result.getResponse().isEmpty());
+ assertEquals(sb.toString().trim(), result.getResponse().trim());
+ } catch (IOException | OllamaBaseException | InterruptedException e) {
+ fail(e);
+ }
}
- }
- @Test
- @Order(3)
- void testChatWithStream() {
- testEndpointReachability();
- try {
- OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
- OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER,
- "What is the capital of France? And what's France's connection with Mona Lisa?")
- .build();
-
- StringBuffer sb = new StringBuffer("");
-
- OllamaChatResult chatResult = ollamaAPI.chat(requestModel,(s) -> {
- LOG.info(s);
- String substring = s.substring(sb.toString().length(), s.length());
- LOG.info(substring);
- sb.append(substring);
- });
- assertNotNull(chatResult);
- assertEquals(sb.toString().trim(), chatResult.getResponse().trim());
- } catch (IOException | OllamaBaseException | InterruptedException e) {
- fail(e);
+ @Test
+ @Order(3)
+ void testAskModelWithOptions() {
+ testEndpointReachability();
+ try {
+ OllamaResult result =
+ ollamaAPI.generate(
+ config.getModel(),
+ "What is the capital of France? And what's France's connection with Mona Lisa?",
+ true,
+ new OptionsBuilder().setTemperature(0.9f).build());
+ assertNotNull(result);
+ assertNotNull(result.getResponse());
+ assertFalse(result.getResponse().isEmpty());
+ } catch (IOException | OllamaBaseException | InterruptedException e) {
+ fail(e);
+ }
}
- }
- @Test
- @Order(3)
- void testChatWithImageFromFileWithHistoryRecognition() {
- testEndpointReachability();
- try {
- OllamaChatRequestBuilder builder =
- OllamaChatRequestBuilder.getInstance(config.getImageModel());
- OllamaChatRequestModel requestModel =
- builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
- List.of(getImageFileFromClasspath("dog-on-a-boat.jpg"))).build();
+ @Test
+ @Order(3)
+ void testChat() {
+ testEndpointReachability();
+ try {
+ OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
+ OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What is the capital of France?")
+ .withMessage(OllamaChatMessageRole.ASSISTANT, "Should be Paris!")
+ .withMessage(OllamaChatMessageRole.USER, "And what is the second larges city?")
+ .build();
- OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
- assertNotNull(chatResult);
- assertNotNull(chatResult.getResponse());
-
- builder.reset();
-
- requestModel =
- builder.withMessages(chatResult.getChatHistory())
- .withMessage(OllamaChatMessageRole.USER, "What's the dogs breed?").build();
-
- chatResult = ollamaAPI.chat(requestModel);
- assertNotNull(chatResult);
- assertNotNull(chatResult.getResponse());
-
-
- } catch (IOException | OllamaBaseException | InterruptedException e) {
- fail(e);
+ OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
+ assertNotNull(chatResult);
+ assertFalse(chatResult.getResponse().isBlank());
+ assertEquals(4, chatResult.getChatHistory().size());
+ } catch (IOException | OllamaBaseException | InterruptedException e) {
+ fail(e);
+ }
}
- }
- @Test
- @Order(3)
- void testChatWithImageFromURL() {
- testEndpointReachability();
- try {
- OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getImageModel());
- OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
- "https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg")
- .build();
+ @Test
+ @Order(3)
+ void testChatWithSystemPrompt() {
+ testEndpointReachability();
+ try {
+ OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
+ OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM,
+ "You are a silent bot that only says 'NI'. Do not say anything else under any circumstances!")
+ .withMessage(OllamaChatMessageRole.USER,
+ "What is the capital of France? And what's France's connection with Mona Lisa?")
+ .build();
- OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
- assertNotNull(chatResult);
- } catch (IOException | OllamaBaseException | InterruptedException e) {
- fail(e);
+ OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
+ assertNotNull(chatResult);
+ assertFalse(chatResult.getResponse().isBlank());
+ assertTrue(chatResult.getResponse().startsWith("NI"));
+ assertEquals(3, chatResult.getChatHistory().size());
+ } catch (IOException | OllamaBaseException | InterruptedException e) {
+ fail(e);
+ }
}
- }
- @Test
- @Order(3)
- void testAskModelWithOptionsAndImageFiles() {
- testEndpointReachability();
- File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg");
- try {
- OllamaResult result =
- ollamaAPI.generateWithImageFiles(
- config.getImageModel(),
- "What is in this image?",
- List.of(imageFile),
- new OptionsBuilder().build());
- assertNotNull(result);
- assertNotNull(result.getResponse());
- assertFalse(result.getResponse().isEmpty());
- } catch (IOException | OllamaBaseException | InterruptedException e) {
- fail(e);
+ @Test
+ @Order(3)
+ void testChatWithStream() {
+ testEndpointReachability();
+ try {
+ OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
+ OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER,
+ "What is the capital of France? And what's France's connection with Mona Lisa?")
+ .build();
+
+ StringBuffer sb = new StringBuffer("");
+
+ OllamaChatResult chatResult = ollamaAPI.chat(requestModel, (s) -> {
+ LOG.info(s);
+ String substring = s.substring(sb.toString().length(), s.length());
+ LOG.info(substring);
+ sb.append(substring);
+ });
+ assertNotNull(chatResult);
+ assertEquals(sb.toString().trim(), chatResult.getResponse().trim());
+ } catch (IOException | OllamaBaseException | InterruptedException e) {
+ fail(e);
+ }
}
- }
- @Test
- @Order(3)
- void testAskModelWithOptionsAndImageFilesStreamed() {
- testEndpointReachability();
- File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg");
- try {
- StringBuffer sb = new StringBuffer("");
+ @Test
+ @Order(3)
+ void testChatWithImageFromFileWithHistoryRecognition() {
+ testEndpointReachability();
+ try {
+ OllamaChatRequestBuilder builder =
+ OllamaChatRequestBuilder.getInstance(config.getImageModel());
+ OllamaChatRequestModel requestModel =
+ builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
+ List.of(getImageFileFromClasspath("dog-on-a-boat.jpg"))).build();
- OllamaResult result = ollamaAPI.generateWithImageFiles(config.getImageModel(),
- "What is in this image?", List.of(imageFile), new OptionsBuilder().build(), (s) -> {
- LOG.info(s);
- String substring = s.substring(sb.toString().length(), s.length());
- LOG.info(substring);
- sb.append(substring);
- });
- assertNotNull(result);
- assertNotNull(result.getResponse());
- assertFalse(result.getResponse().isEmpty());
- assertEquals(sb.toString().trim(), result.getResponse().trim());
- } catch (IOException | OllamaBaseException | InterruptedException e) {
- fail(e);
+ OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
+ assertNotNull(chatResult);
+ assertNotNull(chatResult.getResponse());
+
+ builder.reset();
+
+ requestModel =
+ builder.withMessages(chatResult.getChatHistory())
+ .withMessage(OllamaChatMessageRole.USER, "What's the dogs breed?").build();
+
+ chatResult = ollamaAPI.chat(requestModel);
+ assertNotNull(chatResult);
+ assertNotNull(chatResult.getResponse());
+
+
+ } catch (IOException | OllamaBaseException | InterruptedException e) {
+ fail(e);
+ }
}
- }
- @Test
- @Order(3)
- void testAskModelWithOptionsAndImageURLs() {
- testEndpointReachability();
- try {
- OllamaResult result =
- ollamaAPI.generateWithImageURLs(
- config.getImageModel(),
- "What is in this image?",
- List.of(
- "https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg"),
- new OptionsBuilder().build());
- assertNotNull(result);
- assertNotNull(result.getResponse());
- assertFalse(result.getResponse().isEmpty());
- } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
- fail(e);
+ @Test
+ @Order(3)
+ void testChatWithImageFromURL() {
+ testEndpointReachability();
+ try {
+ OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getImageModel());
+ OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
+ "https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg")
+ .build();
+
+ OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
+ assertNotNull(chatResult);
+ } catch (IOException | OllamaBaseException | InterruptedException e) {
+ fail(e);
+ }
}
- }
- @Test
- @Order(3)
- public void testEmbedding() {
- testEndpointReachability();
- try {
- OllamaEmbeddingsRequestModel request = OllamaEmbeddingsRequestBuilder
- .getInstance(config.getModel(), "What is the capital of France?").build();
-
- List embeddings = ollamaAPI.generateEmbeddings(request);
-
- assertNotNull(embeddings);
- assertFalse(embeddings.isEmpty());
- } catch (IOException | OllamaBaseException | InterruptedException e) {
- fail(e);
+ @Test
+ @Order(3)
+ void testAskModelWithOptionsAndImageFiles() {
+ testEndpointReachability();
+ File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg");
+ try {
+ OllamaResult result =
+ ollamaAPI.generateWithImageFiles(
+ config.getImageModel(),
+ "What is in this image?",
+ List.of(imageFile),
+ new OptionsBuilder().build());
+ assertNotNull(result);
+ assertNotNull(result.getResponse());
+ assertFalse(result.getResponse().isEmpty());
+ } catch (IOException | OllamaBaseException | InterruptedException e) {
+ fail(e);
+ }
+ }
+
+ @Test
+ @Order(3)
+ void testAskModelWithOptionsAndImageFilesStreamed() {
+ testEndpointReachability();
+ File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg");
+ try {
+ StringBuffer sb = new StringBuffer("");
+
+ OllamaResult result = ollamaAPI.generateWithImageFiles(config.getImageModel(),
+ "What is in this image?", List.of(imageFile), new OptionsBuilder().build(), (s) -> {
+ LOG.info(s);
+ String substring = s.substring(sb.toString().length(), s.length());
+ LOG.info(substring);
+ sb.append(substring);
+ });
+ assertNotNull(result);
+ assertNotNull(result.getResponse());
+ assertFalse(result.getResponse().isEmpty());
+ assertEquals(sb.toString().trim(), result.getResponse().trim());
+ } catch (IOException | OllamaBaseException | InterruptedException e) {
+ fail(e);
+ }
+ }
+
+ @Test
+ @Order(3)
+ void testAskModelWithOptionsAndImageURLs() {
+ testEndpointReachability();
+ try {
+ OllamaResult result =
+ ollamaAPI.generateWithImageURLs(
+ config.getImageModel(),
+ "What is in this image?",
+ List.of(
+ "https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg"),
+ new OptionsBuilder().build());
+ assertNotNull(result);
+ assertNotNull(result.getResponse());
+ assertFalse(result.getResponse().isEmpty());
+ } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
+ fail(e);
+ }
+ }
+
+ @Test
+ @Order(3)
+ public void testEmbedding() {
+ testEndpointReachability();
+ try {
+ OllamaEmbeddingsRequestModel request = OllamaEmbeddingsRequestBuilder
+ .getInstance(config.getModel(), "What is the capital of France?").build();
+
+ List embeddings = ollamaAPI.generateEmbeddings(request);
+
+ assertNotNull(embeddings);
+ assertFalse(embeddings.isEmpty());
+ } catch (IOException | OllamaBaseException | InterruptedException e) {
+ fail(e);
+ }
}
- }
}
@Data
class Config {
- private String ollamaURL;
- private String model;
- private String imageModel;
- private int requestTimeoutSeconds;
+ private String ollamaURL;
+ private String model;
+ private String imageModel;
+ private int requestTimeoutSeconds;
- public Config() {
- Properties properties = new Properties();
- try (InputStream input =
- getClass().getClassLoader().getResourceAsStream("test-config.properties")) {
- if (input == null) {
- throw new RuntimeException("Sorry, unable to find test-config.properties");
- }
- properties.load(input);
- this.ollamaURL = properties.getProperty("ollama.url");
- this.model = properties.getProperty("ollama.model");
- this.imageModel = properties.getProperty("ollama.model.image");
- this.requestTimeoutSeconds =
- Integer.parseInt(properties.getProperty("ollama.request-timeout-seconds"));
- } catch (IOException e) {
- throw new RuntimeException("Error loading properties", e);
+ public Config() {
+ Properties properties = new Properties();
+ try (InputStream input =
+ getClass().getClassLoader().getResourceAsStream("test-config.properties")) {
+ if (input == null) {
+ throw new RuntimeException("Sorry, unable to find test-config.properties");
+ }
+ properties.load(input);
+ this.ollamaURL = properties.getProperty("ollama.url");
+ this.model = properties.getProperty("ollama.model");
+ this.imageModel = properties.getProperty("ollama.model.image");
+ this.requestTimeoutSeconds =
+ Integer.parseInt(properties.getProperty("ollama.request-timeout-seconds"));
+ } catch (IOException e) {
+ throw new RuntimeException("Error loading properties", e);
+ }
}
- }
}
diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/TestMockedAPIs.java b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/TestMockedAPIs.java
index 879c67c..c5d60e1 100644
--- a/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/TestMockedAPIs.java
+++ b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/TestMockedAPIs.java
@@ -1,7 +1,5 @@
package io.github.amithkoujalgi.ollama4j.unittests;
-import static org.mockito.Mockito.*;
-
import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.ModelDetail;
@@ -9,155 +7,158 @@ import io.github.amithkoujalgi.ollama4j.core.models.OllamaAsyncResultCallback;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
import io.github.amithkoujalgi.ollama4j.core.types.OllamaModelType;
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
+import org.junit.jupiter.api.Test;
+import org.mockito.Mockito;
+
import java.io.IOException;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.Collections;
-import org.junit.jupiter.api.Test;
-import org.mockito.Mockito;
+
+import static org.mockito.Mockito.*;
class TestMockedAPIs {
- @Test
- void testPullModel() {
- OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
- String model = OllamaModelType.LLAMA2;
- try {
- doNothing().when(ollamaAPI).pullModel(model);
- ollamaAPI.pullModel(model);
- verify(ollamaAPI, times(1)).pullModel(model);
- } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
- throw new RuntimeException(e);
+ @Test
+ void testPullModel() {
+ OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
+ String model = OllamaModelType.LLAMA2;
+ try {
+ doNothing().when(ollamaAPI).pullModel(model);
+ ollamaAPI.pullModel(model);
+ verify(ollamaAPI, times(1)).pullModel(model);
+ } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
+ throw new RuntimeException(e);
+ }
}
- }
- @Test
- void testListModels() {
- OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
- try {
- when(ollamaAPI.listModels()).thenReturn(new ArrayList<>());
- ollamaAPI.listModels();
- verify(ollamaAPI, times(1)).listModels();
- } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
- throw new RuntimeException(e);
+ @Test
+ void testListModels() {
+ OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
+ try {
+ when(ollamaAPI.listModels()).thenReturn(new ArrayList<>());
+ ollamaAPI.listModels();
+ verify(ollamaAPI, times(1)).listModels();
+ } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
+ throw new RuntimeException(e);
+ }
}
- }
- @Test
- void testCreateModel() {
- OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
- String model = OllamaModelType.LLAMA2;
- String modelFilePath = "FROM llama2\nSYSTEM You are mario from Super Mario Bros.";
- try {
- doNothing().when(ollamaAPI).createModelWithModelFileContents(model, modelFilePath);
- ollamaAPI.createModelWithModelFileContents(model, modelFilePath);
- verify(ollamaAPI, times(1)).createModelWithModelFileContents(model, modelFilePath);
- } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
- throw new RuntimeException(e);
+ @Test
+ void testCreateModel() {
+ OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
+ String model = OllamaModelType.LLAMA2;
+ String modelFilePath = "FROM llama2\nSYSTEM You are mario from Super Mario Bros.";
+ try {
+ doNothing().when(ollamaAPI).createModelWithModelFileContents(model, modelFilePath);
+ ollamaAPI.createModelWithModelFileContents(model, modelFilePath);
+ verify(ollamaAPI, times(1)).createModelWithModelFileContents(model, modelFilePath);
+ } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
+ throw new RuntimeException(e);
+ }
}
- }
- @Test
- void testDeleteModel() {
- OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
- String model = OllamaModelType.LLAMA2;
- try {
- doNothing().when(ollamaAPI).deleteModel(model, true);
- ollamaAPI.deleteModel(model, true);
- verify(ollamaAPI, times(1)).deleteModel(model, true);
- } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
- throw new RuntimeException(e);
+ @Test
+ void testDeleteModel() {
+ OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
+ String model = OllamaModelType.LLAMA2;
+ try {
+ doNothing().when(ollamaAPI).deleteModel(model, true);
+ ollamaAPI.deleteModel(model, true);
+ verify(ollamaAPI, times(1)).deleteModel(model, true);
+ } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
+ throw new RuntimeException(e);
+ }
}
- }
- @Test
- void testGetModelDetails() {
- OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
- String model = OllamaModelType.LLAMA2;
- try {
- when(ollamaAPI.getModelDetails(model)).thenReturn(new ModelDetail());
- ollamaAPI.getModelDetails(model);
- verify(ollamaAPI, times(1)).getModelDetails(model);
- } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
- throw new RuntimeException(e);
+ @Test
+ void testGetModelDetails() {
+ OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
+ String model = OllamaModelType.LLAMA2;
+ try {
+ when(ollamaAPI.getModelDetails(model)).thenReturn(new ModelDetail());
+ ollamaAPI.getModelDetails(model);
+ verify(ollamaAPI, times(1)).getModelDetails(model);
+ } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
+ throw new RuntimeException(e);
+ }
}
- }
- @Test
- void testGenerateEmbeddings() {
- OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
- String model = OllamaModelType.LLAMA2;
- String prompt = "some prompt text";
- try {
- when(ollamaAPI.generateEmbeddings(model, prompt)).thenReturn(new ArrayList<>());
- ollamaAPI.generateEmbeddings(model, prompt);
- verify(ollamaAPI, times(1)).generateEmbeddings(model, prompt);
- } catch (IOException | OllamaBaseException | InterruptedException e) {
- throw new RuntimeException(e);
+ @Test
+ void testGenerateEmbeddings() {
+ OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
+ String model = OllamaModelType.LLAMA2;
+ String prompt = "some prompt text";
+ try {
+ when(ollamaAPI.generateEmbeddings(model, prompt)).thenReturn(new ArrayList<>());
+ ollamaAPI.generateEmbeddings(model, prompt);
+ verify(ollamaAPI, times(1)).generateEmbeddings(model, prompt);
+ } catch (IOException | OllamaBaseException | InterruptedException e) {
+ throw new RuntimeException(e);
+ }
}
- }
- @Test
- void testAsk() {
- OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
- String model = OllamaModelType.LLAMA2;
- String prompt = "some prompt text";
- OptionsBuilder optionsBuilder = new OptionsBuilder();
- try {
- when(ollamaAPI.generate(model, prompt, optionsBuilder.build()))
- .thenReturn(new OllamaResult("", 0, 200));
- ollamaAPI.generate(model, prompt, optionsBuilder.build());
- verify(ollamaAPI, times(1)).generate(model, prompt, optionsBuilder.build());
- } catch (IOException | OllamaBaseException | InterruptedException e) {
- throw new RuntimeException(e);
+ @Test
+ void testAsk() {
+ OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
+ String model = OllamaModelType.LLAMA2;
+ String prompt = "some prompt text";
+ OptionsBuilder optionsBuilder = new OptionsBuilder();
+ try {
+ when(ollamaAPI.generate(model, prompt, false, optionsBuilder.build()))
+ .thenReturn(new OllamaResult("", 0, 200));
+ ollamaAPI.generate(model, prompt, false, optionsBuilder.build());
+ verify(ollamaAPI, times(1)).generate(model, prompt, false, optionsBuilder.build());
+ } catch (IOException | OllamaBaseException | InterruptedException e) {
+ throw new RuntimeException(e);
+ }
}
- }
- @Test
- void testAskWithImageFiles() {
- OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
- String model = OllamaModelType.LLAMA2;
- String prompt = "some prompt text";
- try {
- when(ollamaAPI.generateWithImageFiles(
- model, prompt, Collections.emptyList(), new OptionsBuilder().build()))
- .thenReturn(new OllamaResult("", 0, 200));
- ollamaAPI.generateWithImageFiles(
- model, prompt, Collections.emptyList(), new OptionsBuilder().build());
- verify(ollamaAPI, times(1))
- .generateWithImageFiles(
- model, prompt, Collections.emptyList(), new OptionsBuilder().build());
- } catch (IOException | OllamaBaseException | InterruptedException e) {
- throw new RuntimeException(e);
+ @Test
+ void testAskWithImageFiles() {
+ OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
+ String model = OllamaModelType.LLAMA2;
+ String prompt = "some prompt text";
+ try {
+ when(ollamaAPI.generateWithImageFiles(
+ model, prompt, Collections.emptyList(), new OptionsBuilder().build()))
+ .thenReturn(new OllamaResult("", 0, 200));
+ ollamaAPI.generateWithImageFiles(
+ model, prompt, Collections.emptyList(), new OptionsBuilder().build());
+ verify(ollamaAPI, times(1))
+ .generateWithImageFiles(
+ model, prompt, Collections.emptyList(), new OptionsBuilder().build());
+ } catch (IOException | OllamaBaseException | InterruptedException e) {
+ throw new RuntimeException(e);
+ }
}
- }
- @Test
- void testAskWithImageURLs() {
- OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
- String model = OllamaModelType.LLAMA2;
- String prompt = "some prompt text";
- try {
- when(ollamaAPI.generateWithImageURLs(
- model, prompt, Collections.emptyList(), new OptionsBuilder().build()))
- .thenReturn(new OllamaResult("", 0, 200));
- ollamaAPI.generateWithImageURLs(
- model, prompt, Collections.emptyList(), new OptionsBuilder().build());
- verify(ollamaAPI, times(1))
- .generateWithImageURLs(
- model, prompt, Collections.emptyList(), new OptionsBuilder().build());
- } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
- throw new RuntimeException(e);
+ @Test
+ void testAskWithImageURLs() {
+ OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
+ String model = OllamaModelType.LLAMA2;
+ String prompt = "some prompt text";
+ try {
+ when(ollamaAPI.generateWithImageURLs(
+ model, prompt, Collections.emptyList(), new OptionsBuilder().build()))
+ .thenReturn(new OllamaResult("", 0, 200));
+ ollamaAPI.generateWithImageURLs(
+ model, prompt, Collections.emptyList(), new OptionsBuilder().build());
+ verify(ollamaAPI, times(1))
+ .generateWithImageURLs(
+ model, prompt, Collections.emptyList(), new OptionsBuilder().build());
+ } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
+ throw new RuntimeException(e);
+ }
}
- }
- @Test
- void testAskAsync() {
- OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
- String model = OllamaModelType.LLAMA2;
- String prompt = "some prompt text";
- when(ollamaAPI.generateAsync(model, prompt))
- .thenReturn(new OllamaAsyncResultCallback(null, null, 3));
- ollamaAPI.generateAsync(model, prompt);
- verify(ollamaAPI, times(1)).generateAsync(model, prompt);
- }
+ @Test
+ void testAskAsync() {
+ OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
+ String model = OllamaModelType.LLAMA2;
+ String prompt = "some prompt text";
+ when(ollamaAPI.generateAsync(model, prompt, false))
+ .thenReturn(new OllamaAsyncResultCallback(null, null, 3));
+ ollamaAPI.generateAsync(model, prompt, false);
+ verify(ollamaAPI, times(1)).generateAsync(model, prompt, false);
+ }
}