imageURLs,
- Options options)
- throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
- return generateWithImageURLs(model, prompt, imageURLs, options, null);
- }
+ /**
+ * Convenience method to call Ollama API without streaming responses.
+ *
+ * Uses {@link #generate(String, String, boolean, Options, OllamaStreamHandler)}
+ *
+ * @param model Model to use
+ * @param prompt Prompt text
+ * @param raw In some cases, you may wish to bypass the templating system and provide a full prompt. In this case, you can use the raw parameter to disable templating. Also note that raw mode will not return a context.
+ * @param options Additional Options
+ * @return OllamaResult
+ */
+ public OllamaResult generate(String model, String prompt, boolean raw, Options options)
+ throws OllamaBaseException, IOException, InterruptedException {
+ return generate(model, prompt, raw, options, null);
+ }
-
- /**
- * Ask a question to a model based on a given message stack (i.e. a chat history). Creates a synchronous call to the api
- * 'api/chat'.
- *
- * @param model the ollama model to ask the question to
- * @param messages chat history / message stack to send to the model
- * @return {@link OllamaChatResult} containing the api response and the message history including the newly aqcuired assistant response.
- * @throws OllamaBaseException any response code than 200 has been returned
- * @throws IOException in case the responseStream can not be read
+ public OllamaToolsResult generateWithTools(String model, String prompt, boolean raw, Options options)
+ throws OllamaBaseException, IOException, InterruptedException {
+ OllamaToolsResult toolResult = new OllamaToolsResult();
+ Map toolResults = new HashMap<>();
+
+ OllamaResult result = generate(model, prompt, raw, options, null);
+ toolResult.setModelResult(result);
+
+ List toolDefs = Utils.getObjectMapper().readValue(result.getResponse(), Utils.getObjectMapper().getTypeFactory().constructCollectionType(List.class, ToolDef.class));
+ for (ToolDef toolDef : toolDefs) {
+ toolResults.put(toolDef, invokeTool(toolDef));
+ }
+ toolResult.setToolResults(toolResults);
+ return toolResult;
+ }
+
+
+ /**
+ * Generate response for a question to a model running on Ollama server and get a callback handle
+ * that can be used to check for status and get the response from the model later. This would be
+ * an async/non-blocking call.
+ *
+ * @param model the ollama model to ask the question to
+ * @param prompt the prompt/question text
+ * @return the ollama async result callback handle
+ */
+ public OllamaAsyncResultCallback generateAsync(String model, String prompt, boolean raw) {
+ OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt);
+ ollamaRequestModel.setRaw(raw);
+ URI uri = URI.create(this.host + "/api/generate");
+ OllamaAsyncResultCallback ollamaAsyncResultCallback =
+ new OllamaAsyncResultCallback(
+ getRequestBuilderDefault(uri), ollamaRequestModel, requestTimeoutSeconds);
+ ollamaAsyncResultCallback.start();
+ return ollamaAsyncResultCallback;
+ }
+
+ /**
+ * With one or more image files, ask a question to a model running on Ollama server. This is a
+ * sync/blocking call.
+ *
+ * @param model the ollama model to ask the question to
+ * @param prompt the prompt/question text
+ * @param imageFiles the list of image files to use for the question
+ * @param options the Options object - More
+ * details on the options
+ * @param streamHandler optional callback consumer that will be applied every time a streamed response is received. If not set, the stream parameter of the request is set to false.
+ * @return OllamaResult that includes response text and time taken for response
+ */
+ public OllamaResult generateWithImageFiles(
+ String model, String prompt, List imageFiles, Options options, OllamaStreamHandler streamHandler)
+ throws OllamaBaseException, IOException, InterruptedException {
+ List images = new ArrayList<>();
+ for (File imageFile : imageFiles) {
+ images.add(encodeFileToBase64(imageFile));
+ }
+ OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt, images);
+ ollamaRequestModel.setOptions(options.getOptionsMap());
+ return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler);
+ }
+
+ /**
+ * Convenience method to call Ollama API without streaming responses.
+ *
+ * Uses {@link #generateWithImageFiles(String, String, List, Options, OllamaStreamHandler)}
+ */
+ public OllamaResult generateWithImageFiles(
+ String model, String prompt, List imageFiles, Options options)
+ throws OllamaBaseException, IOException, InterruptedException {
+ return generateWithImageFiles(model, prompt, imageFiles, options, null);
+ }
+
+ /**
+ * With one or more image URLs, ask a question to a model running on Ollama server. This is a
+ * sync/blocking call.
+ *
+ * @param model the ollama model to ask the question to
+ * @param prompt the prompt/question text
+ * @param imageURLs the list of image URLs to use for the question
+ * @param options the Options object - More
+ * details on the options
+ * @param streamHandler optional callback consumer that will be applied every time a streamed response is received. If not set, the stream parameter of the request is set to false.
+ * @return OllamaResult that includes response text and time taken for response
+ */
+ public OllamaResult generateWithImageURLs(
+ String model, String prompt, List imageURLs, Options options, OllamaStreamHandler streamHandler)
+ throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
+ List images = new ArrayList<>();
+ for (String imageURL : imageURLs) {
+ images.add(encodeByteArrayToBase64(Utils.loadImageBytesFromUrl(imageURL)));
+ }
+ OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt, images);
+ ollamaRequestModel.setOptions(options.getOptionsMap());
+ return generateSyncForOllamaRequestModel(ollamaRequestModel, streamHandler);
+ }
+
+ /**
+ * Convenience method to call Ollama API without streaming responses.
+ *
+ * Uses {@link #generateWithImageURLs(String, String, List, Options, OllamaStreamHandler)}
+ */
+ public OllamaResult generateWithImageURLs(String model, String prompt, List imageURLs,
+ Options options)
+ throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
+ return generateWithImageURLs(model, prompt, imageURLs, options, null);
+ }
+
+
+ /**
+ * Ask a question to a model based on a given message stack (i.e. a chat history). Creates a synchronous call to the api
+ * 'api/chat'.
+ *
+ * @param model the ollama model to ask the question to
+ * @param messages chat history / message stack to send to the model
+ * @return {@link OllamaChatResult} containing the api response and the message history including the newly aqcuired assistant response.
+ * @throws OllamaBaseException any response code than 200 has been returned
+ * @throws IOException in case the responseStream can not be read
* @throws InterruptedException in case the server is not reachable or network issues happen
- */
- public OllamaChatResult chat(String model, List messages) throws OllamaBaseException, IOException, InterruptedException{
- OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(model);
- return chat(builder.withMessages(messages).build());
- }
-
- /**
- * Ask a question to a model using an {@link OllamaChatRequestModel}. This can be constructed using an {@link OllamaChatRequestBuilder}.
- *
- * Hint: the OllamaChatRequestModel#getStream() property is not implemented.
- *
- * @param request request object to be sent to the server
- * @return
- * @throws OllamaBaseException any response code than 200 has been returned
- * @throws IOException in case the responseStream can not be read
- * @throws InterruptedException in case the server is not reachable or network issues happen
- */
- public OllamaChatResult chat(OllamaChatRequestModel request) throws OllamaBaseException, IOException, InterruptedException{
- return chat(request,null);
- }
-
- /**
- * Ask a question to a model using an {@link OllamaChatRequestModel}. This can be constructed using an {@link OllamaChatRequestBuilder}.
- *
- * Hint: the OllamaChatRequestModel#getStream() property is not implemented.
- *
- * @param request request object to be sent to the server
- * @param streamHandler callback handler to handle the last message from stream (caution: all previous messages from stream will be concatenated)
- * @return
- * @throws OllamaBaseException any response code than 200 has been returned
- * @throws IOException in case the responseStream can not be read
- * @throws InterruptedException in case the server is not reachable or network issues happen
- */
- public OllamaChatResult chat(OllamaChatRequestModel request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException{
- OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
- OllamaResult result;
- if(streamHandler != null){
- request.setStream(true);
- result = requestCaller.call(request, streamHandler);
+ */
+ public OllamaChatResult chat(String model, List messages) throws OllamaBaseException, IOException, InterruptedException {
+ OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(model);
+ return chat(builder.withMessages(messages).build());
}
- else {
- result = requestCaller.callSync(request);
+
+ /**
+ * Ask a question to a model using an {@link OllamaChatRequestModel}. This can be constructed using an {@link OllamaChatRequestBuilder}.
+ *
+ * Hint: the OllamaChatRequestModel#getStream() property is not implemented.
+ *
+ * @param request request object to be sent to the server
+ * @return
+ * @throws OllamaBaseException any response code than 200 has been returned
+ * @throws IOException in case the responseStream can not be read
+ * @throws InterruptedException in case the server is not reachable or network issues happen
+ */
+ public OllamaChatResult chat(OllamaChatRequestModel request) throws OllamaBaseException, IOException, InterruptedException {
+ return chat(request, null);
}
- return new OllamaChatResult(result.getResponse(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages());
- }
- // technical private methods //
-
- private static String encodeFileToBase64(File file) throws IOException {
- return Base64.getEncoder().encodeToString(Files.readAllBytes(file.toPath()));
- }
-
- private static String encodeByteArrayToBase64(byte[] bytes) {
- return Base64.getEncoder().encodeToString(bytes);
- }
-
- private OllamaResult generateSyncForOllamaRequestModel(
- OllamaGenerateRequestModel ollamaRequestModel, OllamaStreamHandler streamHandler)
- throws OllamaBaseException, IOException, InterruptedException {
- OllamaGenerateEndpointCaller requestCaller =
- new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
- OllamaResult result;
- if (streamHandler != null) {
- ollamaRequestModel.setStream(true);
- result = requestCaller.call(ollamaRequestModel, streamHandler);
- } else {
- result = requestCaller.callSync(ollamaRequestModel);
+ /**
+ * Ask a question to a model using an {@link OllamaChatRequestModel}. This can be constructed using an {@link OllamaChatRequestBuilder}.
+ *
+ * Hint: the OllamaChatRequestModel#getStream() property is not implemented.
+ *
+ * @param request request object to be sent to the server
+ * @param streamHandler callback handler to handle the last message from stream (caution: all previous messages from stream will be concatenated)
+ * @return
+ * @throws OllamaBaseException any response code than 200 has been returned
+ * @throws IOException in case the responseStream can not be read
+ * @throws InterruptedException in case the server is not reachable or network issues happen
+ */
+ public OllamaChatResult chat(OllamaChatRequestModel request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
+ OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
+ OllamaResult result;
+ if (streamHandler != null) {
+ request.setStream(true);
+ result = requestCaller.call(request, streamHandler);
+ } else {
+ result = requestCaller.callSync(request);
+ }
+ return new OllamaChatResult(result.getResponse(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages());
}
- return result;
- }
- /**
- * Get default request builder.
- *
- * @param uri URI to get a HttpRequest.Builder
- * @return HttpRequest.Builder
- */
- private HttpRequest.Builder getRequestBuilderDefault(URI uri) {
- HttpRequest.Builder requestBuilder =
- HttpRequest.newBuilder(uri)
- .header("Content-Type", "application/json")
- .timeout(Duration.ofSeconds(requestTimeoutSeconds));
- if (isBasicAuthCredentialsSet()) {
- requestBuilder.header("Authorization", getBasicAuthHeaderValue());
+ // technical private methods //
+
+ private static String encodeFileToBase64(File file) throws IOException {
+ return Base64.getEncoder().encodeToString(Files.readAllBytes(file.toPath()));
}
- return requestBuilder;
- }
- /**
- * Get basic authentication header value.
- *
- * @return basic authentication header value (encoded credentials)
- */
- private String getBasicAuthHeaderValue() {
- String credentialsToEncode = basicAuth.getUsername() + ":" + basicAuth.getPassword();
- return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes());
- }
+ private static String encodeByteArrayToBase64(byte[] bytes) {
+ return Base64.getEncoder().encodeToString(bytes);
+ }
- /**
- * Check if Basic Auth credentials set.
- *
- * @return true when Basic Auth credentials set
- */
- private boolean isBasicAuthCredentialsSet() {
- return basicAuth != null;
- }
+ private OllamaResult generateSyncForOllamaRequestModel(
+ OllamaGenerateRequestModel ollamaRequestModel, OllamaStreamHandler streamHandler)
+ throws OllamaBaseException, IOException, InterruptedException {
+ OllamaGenerateEndpointCaller requestCaller =
+ new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
+ OllamaResult result;
+ if (streamHandler != null) {
+ ollamaRequestModel.setStream(true);
+ result = requestCaller.call(ollamaRequestModel, streamHandler);
+ } else {
+ result = requestCaller.callSync(ollamaRequestModel);
+ }
+ return result;
+ }
+
+ /**
+ * Get default request builder.
+ *
+ * @param uri URI to get a HttpRequest.Builder
+ * @return HttpRequest.Builder
+ */
+ private HttpRequest.Builder getRequestBuilderDefault(URI uri) {
+ HttpRequest.Builder requestBuilder =
+ HttpRequest.newBuilder(uri)
+ .header("Content-Type", "application/json")
+ .timeout(Duration.ofSeconds(requestTimeoutSeconds));
+ if (isBasicAuthCredentialsSet()) {
+ requestBuilder.header("Authorization", getBasicAuthHeaderValue());
+ }
+ return requestBuilder;
+ }
+
+ /**
+ * Get basic authentication header value.
+ *
+ * @return basic authentication header value (encoded credentials)
+ */
+ private String getBasicAuthHeaderValue() {
+ String credentialsToEncode = basicAuth.getUsername() + ":" + basicAuth.getPassword();
+ return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes());
+ }
+
+ /**
+ * Check if Basic Auth credentials set.
+ *
+ * @return true when Basic Auth credentials set
+ */
+ private boolean isBasicAuthCredentialsSet() {
+ return basicAuth != null;
+ }
+
+
+ public void registerTool(MistralTools.ToolSpecification toolSpecification) {
+ ToolRegistry.addFunction(toolSpecification.getFunctionName(), toolSpecification.getToolDefinition());
+ }
+
+ private Object invokeTool(ToolDef toolDef) {
+ try {
+ String methodName = toolDef.getName();
+ Map arguments = toolDef.getArguments();
+ DynamicFunction function = ToolRegistry.getFunction(methodName);
+ if (function == null) {
+ throw new IllegalArgumentException("No such tool: " + methodName);
+ }
+ return function.apply(arguments);
+ } catch (Exception e) {
+ e.printStackTrace();
+ return "Error calling tool: " + e.getMessage();
+ }
+ }
}
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/impl/ConsoleOutputStreamHandler.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/impl/ConsoleOutputStreamHandler.java
new file mode 100644
index 0000000..6807019
--- /dev/null
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/impl/ConsoleOutputStreamHandler.java
@@ -0,0 +1,14 @@
+package io.github.amithkoujalgi.ollama4j.core.impl;
+
+import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler;
+
+public class ConsoleOutputStreamHandler implements OllamaStreamHandler {
+ private final StringBuffer response = new StringBuffer();
+
+ @Override
+ public void accept(String message) {
+ String substr = message.substring(response.length());
+ response.append(substr);
+ System.out.print(substr);
+ }
+}
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/Model.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/Model.java
index 27fd3e5..15efd70 100644
--- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/Model.java
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/Model.java
@@ -1,5 +1,8 @@
package io.github.amithkoujalgi.ollama4j.core.models;
+import java.time.LocalDateTime;
+import java.time.OffsetDateTime;
+
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
@@ -11,7 +14,9 @@ public class Model {
private String name;
private String model;
@JsonProperty("modified_at")
- private String modifiedAt;
+ private OffsetDateTime modifiedAt;
+ @JsonProperty("expires_at")
+ private OffsetDateTime expiresAt;
private String digest;
private long size;
@JsonProperty("details")
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatResponseModel.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatResponseModel.java
index 4d0b027..418338f 100644
--- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatResponseModel.java
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatResponseModel.java
@@ -1,14 +1,15 @@
package io.github.amithkoujalgi.ollama4j.core.models.chat;
import com.fasterxml.jackson.annotation.JsonProperty;
+import lombok.Data;
import java.util.List;
-import lombok.Data;
@Data
public class OllamaChatResponseModel {
private String model;
private @JsonProperty("created_at") String createdAt;
+ private @JsonProperty("done_reason") String doneReason;
private OllamaChatMessage message;
private boolean done;
private String error;
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/EmbeddingResponse.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/embeddings/OllamaEmbeddingResponseModel.java
similarity index 65%
rename from src/main/java/io/github/amithkoujalgi/ollama4j/core/models/EmbeddingResponse.java
rename to src/main/java/io/github/amithkoujalgi/ollama4j/core/models/embeddings/OllamaEmbeddingResponseModel.java
index e3040a2..85dba31 100644
--- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/EmbeddingResponse.java
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/embeddings/OllamaEmbeddingResponseModel.java
@@ -1,4 +1,4 @@
-package io.github.amithkoujalgi.ollama4j.core.models;
+package io.github.amithkoujalgi.ollama4j.core.models.embeddings;
import com.fasterxml.jackson.annotation.JsonProperty;
@@ -7,7 +7,7 @@ import lombok.Data;
@SuppressWarnings("unused")
@Data
-public class EmbeddingResponse {
+public class OllamaEmbeddingResponseModel {
@JsonProperty("embedding")
private List embedding;
}
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/embeddings/OllamaEmbeddingsRequestBuilder.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/embeddings/OllamaEmbeddingsRequestBuilder.java
new file mode 100644
index 0000000..ef7a84e
--- /dev/null
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/embeddings/OllamaEmbeddingsRequestBuilder.java
@@ -0,0 +1,31 @@
+package io.github.amithkoujalgi.ollama4j.core.models.embeddings;
+
+import io.github.amithkoujalgi.ollama4j.core.utils.Options;
+
+public class OllamaEmbeddingsRequestBuilder {
+
+ private OllamaEmbeddingsRequestBuilder(String model, String prompt){
+ request = new OllamaEmbeddingsRequestModel(model, prompt);
+ }
+
+ private OllamaEmbeddingsRequestModel request;
+
+ public static OllamaEmbeddingsRequestBuilder getInstance(String model, String prompt){
+ return new OllamaEmbeddingsRequestBuilder(model, prompt);
+ }
+
+ public OllamaEmbeddingsRequestModel build(){
+ return request;
+ }
+
+ public OllamaEmbeddingsRequestBuilder withOptions(Options options){
+ this.request.setOptions(options.getOptionsMap());
+ return this;
+ }
+
+ public OllamaEmbeddingsRequestBuilder withKeepAlive(String keepAlive){
+ this.request.setKeepAlive(keepAlive);
+ return this;
+ }
+
+}
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/embeddings/OllamaEmbeddingsRequestModel.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/embeddings/OllamaEmbeddingsRequestModel.java
new file mode 100644
index 0000000..a369124
--- /dev/null
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/embeddings/OllamaEmbeddingsRequestModel.java
@@ -0,0 +1,33 @@
+package io.github.amithkoujalgi.ollama4j.core.models.embeddings;
+
+import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper;
+import java.util.Map;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.core.JsonProcessingException;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+import lombok.NonNull;
+import lombok.RequiredArgsConstructor;
+
+@Data
+@RequiredArgsConstructor
+@NoArgsConstructor
+public class OllamaEmbeddingsRequestModel {
+ @NonNull
+ private String model;
+ @NonNull
+ private String prompt;
+
+ protected Map options;
+ @JsonProperty(value = "keep_alive")
+ private String keepAlive;
+
+ @Override
+ public String toString() {
+ try {
+ return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
+ } catch (JsonProcessingException e) {
+ throw new RuntimeException(e);
+ }
+ }
+}
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/ModelEmbeddingsRequest.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/ModelEmbeddingsRequest.java
deleted file mode 100644
index 1455a94..0000000
--- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/ModelEmbeddingsRequest.java
+++ /dev/null
@@ -1,23 +0,0 @@
-package io.github.amithkoujalgi.ollama4j.core.models.request;
-
-import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper;
-
-import com.fasterxml.jackson.core.JsonProcessingException;
-import lombok.AllArgsConstructor;
-import lombok.Data;
-
-@Data
-@AllArgsConstructor
-public class ModelEmbeddingsRequest {
- private String model;
- private String prompt;
-
- @Override
- public String toString() {
- try {
- return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
- } catch (JsonProcessingException e) {
- throw new RuntimeException(e);
- }
- }
-}
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaChatEndpointCaller.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaChatEndpointCaller.java
index 811ef11..cc6c7f8 100644
--- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaChatEndpointCaller.java
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaChatEndpointCaller.java
@@ -1,12 +1,6 @@
package io.github.amithkoujalgi.ollama4j.core.models.request;
-import java.io.IOException;
-
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
import com.fasterxml.jackson.core.JsonProcessingException;
-
import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth;
@@ -15,11 +9,15 @@ import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResponseModel
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatStreamObserver;
import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
/**
* Specialization class for requests
*/
-public class OllamaChatEndpointCaller extends OllamaEndpointCaller{
+public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
private static final Logger LOG = LoggerFactory.getLogger(OllamaChatEndpointCaller.class);
@@ -39,14 +37,14 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller{
try {
OllamaChatResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class);
responseBuffer.append(ollamaResponseModel.getMessage().getContent());
- if(streamObserver != null) {
+ if (streamObserver != null) {
streamObserver.notify(ollamaResponseModel);
}
return ollamaResponseModel.isDone();
} catch (JsonProcessingException e) {
- LOG.error("Error parsing the Ollama chat response!",e);
+ LOG.error("Error parsing the Ollama chat response!", e);
return true;
- }
+ }
}
public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler)
@@ -54,7 +52,4 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller{
streamObserver = new OllamaChatStreamObserver(streamHandler);
return super.callSync(body);
}
-
-
-
}
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaEndpointCaller.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaEndpointCaller.java
index ad8d5bb..350200a 100644
--- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaEndpointCaller.java
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaEndpointCaller.java
@@ -1,5 +1,15 @@
package io.github.amithkoujalgi.ollama4j.core.models.request;
+import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
+import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
+import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth;
+import io.github.amithkoujalgi.ollama4j.core.models.OllamaErrorResponseModel;
+import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
+import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
+import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
@@ -12,22 +22,11 @@ import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Base64;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
-import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
-import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth;
-import io.github.amithkoujalgi.ollama4j.core.models.OllamaErrorResponseModel;
-import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
-import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
-import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
-
/**
* Abstract helperclass to call the ollama api server.
*/
public abstract class OllamaEndpointCaller {
-
+
private static final Logger LOG = LoggerFactory.getLogger(OllamaAPI.class);
private String host;
@@ -49,107 +48,105 @@ public abstract class OllamaEndpointCaller {
/**
* Calls the api server on the given host and endpoint suffix asynchronously, aka waiting for the response.
- *
+ *
* @param body POST body payload
* @return result answer given by the assistant
- * @throws OllamaBaseException any response code than 200 has been returned
- * @throws IOException in case the responseStream can not be read
+ * @throws OllamaBaseException any response code than 200 has been returned
+ * @throws IOException in case the responseStream can not be read
* @throws InterruptedException in case the server is not reachable or network issues happen
*/
- public OllamaResult callSync(OllamaRequestBody body) throws OllamaBaseException, IOException, InterruptedException{
-
+ public OllamaResult callSync(OllamaRequestBody body) throws OllamaBaseException, IOException, InterruptedException {
// Create Request
- long startTime = System.currentTimeMillis();
- HttpClient httpClient = HttpClient.newHttpClient();
- URI uri = URI.create(this.host + getEndpointSuffix());
- HttpRequest.Builder requestBuilder =
- getRequestBuilderDefault(uri)
- .POST(
- body.getBodyPublisher());
- HttpRequest request = requestBuilder.build();
- if (this.verbose) LOG.info("Asking model: " + body.toString());
- HttpResponse response =
- httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
-
-
+ long startTime = System.currentTimeMillis();
+ HttpClient httpClient = HttpClient.newHttpClient();
+ URI uri = URI.create(this.host + getEndpointSuffix());
+ HttpRequest.Builder requestBuilder =
+ getRequestBuilderDefault(uri)
+ .POST(
+ body.getBodyPublisher());
+ HttpRequest request = requestBuilder.build();
+ if (this.verbose) LOG.info("Asking model: " + body.toString());
+ HttpResponse response =
+ httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
+
int statusCode = response.statusCode();
- InputStream responseBodyStream = response.body();
- StringBuilder responseBuffer = new StringBuilder();
- try (BufferedReader reader =
- new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
- String line;
- while ((line = reader.readLine()) != null) {
- if (statusCode == 404) {
- LOG.warn("Status code: 404 (Not Found)");
- OllamaErrorResponseModel ollamaResponseModel =
- Utils.getObjectMapper().readValue(line, OllamaErrorResponseModel.class);
- responseBuffer.append(ollamaResponseModel.getError());
- } else if (statusCode == 401) {
- LOG.warn("Status code: 401 (Unauthorized)");
- OllamaErrorResponseModel ollamaResponseModel =
- Utils.getObjectMapper()
- .readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponseModel.class);
- responseBuffer.append(ollamaResponseModel.getError());
- } else if (statusCode == 400) {
- LOG.warn("Status code: 400 (Bad Request)");
- OllamaErrorResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line,
- OllamaErrorResponseModel.class);
- responseBuffer.append(ollamaResponseModel.getError());
- } else {
- boolean finished = parseResponseAndAddToBuffer(line,responseBuffer);
- if (finished) {
- break;
+ InputStream responseBodyStream = response.body();
+ StringBuilder responseBuffer = new StringBuilder();
+ try (BufferedReader reader =
+ new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
+ String line;
+ while ((line = reader.readLine()) != null) {
+ if (statusCode == 404) {
+ LOG.warn("Status code: 404 (Not Found)");
+ OllamaErrorResponseModel ollamaResponseModel =
+ Utils.getObjectMapper().readValue(line, OllamaErrorResponseModel.class);
+ responseBuffer.append(ollamaResponseModel.getError());
+ } else if (statusCode == 401) {
+ LOG.warn("Status code: 401 (Unauthorized)");
+ OllamaErrorResponseModel ollamaResponseModel =
+ Utils.getObjectMapper()
+ .readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponseModel.class);
+ responseBuffer.append(ollamaResponseModel.getError());
+ } else if (statusCode == 400) {
+ LOG.warn("Status code: 400 (Bad Request)");
+ OllamaErrorResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line,
+ OllamaErrorResponseModel.class);
+ responseBuffer.append(ollamaResponseModel.getError());
+ } else {
+ boolean finished = parseResponseAndAddToBuffer(line, responseBuffer);
+ if (finished) {
+ break;
+ }
+ }
}
}
- }
- }
- if (statusCode != 200) {
- LOG.error("Status code " + statusCode);
- throw new OllamaBaseException(responseBuffer.toString());
- } else {
- long endTime = System.currentTimeMillis();
- OllamaResult ollamaResult =
- new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode);
- if (verbose) LOG.info("Model response: " + ollamaResult);
- return ollamaResult;
+ if (statusCode != 200) {
+ LOG.error("Status code " + statusCode);
+ throw new OllamaBaseException(responseBuffer.toString());
+ } else {
+ long endTime = System.currentTimeMillis();
+ OllamaResult ollamaResult =
+ new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode);
+ if (verbose) LOG.info("Model response: " + ollamaResult);
+ return ollamaResult;
+ }
}
- }
/**
- * Get default request builder.
- *
- * @param uri URI to get a HttpRequest.Builder
- * @return HttpRequest.Builder
- */
- private HttpRequest.Builder getRequestBuilderDefault(URI uri) {
- HttpRequest.Builder requestBuilder =
- HttpRequest.newBuilder(uri)
- .header("Content-Type", "application/json")
- .timeout(Duration.ofSeconds(this.requestTimeoutSeconds));
- if (isBasicAuthCredentialsSet()) {
- requestBuilder.header("Authorization", getBasicAuthHeaderValue());
+ * Get default request builder.
+ *
+ * @param uri URI to get a HttpRequest.Builder
+ * @return HttpRequest.Builder
+ */
+ private HttpRequest.Builder getRequestBuilderDefault(URI uri) {
+ HttpRequest.Builder requestBuilder =
+ HttpRequest.newBuilder(uri)
+ .header("Content-Type", "application/json")
+ .timeout(Duration.ofSeconds(this.requestTimeoutSeconds));
+ if (isBasicAuthCredentialsSet()) {
+ requestBuilder.header("Authorization", getBasicAuthHeaderValue());
+ }
+ return requestBuilder;
}
- return requestBuilder;
- }
- /**
- * Get basic authentication header value.
- *
- * @return basic authentication header value (encoded credentials)
- */
- private String getBasicAuthHeaderValue() {
- String credentialsToEncode = this.basicAuth.getUsername() + ":" + this.basicAuth.getPassword();
- return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes());
- }
+ /**
+ * Get basic authentication header value.
+ *
+ * @return basic authentication header value (encoded credentials)
+ */
+ private String getBasicAuthHeaderValue() {
+ String credentialsToEncode = this.basicAuth.getUsername() + ":" + this.basicAuth.getPassword();
+ return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes());
+ }
+
+ /**
+ * Check if Basic Auth credentials set.
+ *
+ * @return true when Basic Auth credentials set
+ */
+ private boolean isBasicAuthCredentialsSet() {
+ return this.basicAuth != null;
+ }
- /**
- * Check if Basic Auth credentials set.
- *
- * @return true when Basic Auth credentials set
- */
- private boolean isBasicAuthCredentialsSet() {
- return this.basicAuth != null;
- }
-
}
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java
index fe7fbec..d3d71e4 100644
--- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java
@@ -1,9 +1,5 @@
package io.github.amithkoujalgi.ollama4j.core.models.request;
-import java.io.IOException;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
import com.fasterxml.jackson.core.JsonProcessingException;
import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
@@ -13,15 +9,19 @@ import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRespo
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateStreamObserver;
import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
-public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{
+import java.io.IOException;
+
+public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
private static final Logger LOG = LoggerFactory.getLogger(OllamaGenerateEndpointCaller.class);
private OllamaGenerateStreamObserver streamObserver;
public OllamaGenerateEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) {
- super(host, basicAuth, requestTimeoutSeconds, verbose);
+ super(host, basicAuth, requestTimeoutSeconds, verbose);
}
@Override
@@ -31,24 +31,22 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{
@Override
protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) {
- try {
- OllamaGenerateResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class);
- responseBuffer.append(ollamaResponseModel.getResponse());
- if(streamObserver != null) {
- streamObserver.notify(ollamaResponseModel);
- }
- return ollamaResponseModel.isDone();
- } catch (JsonProcessingException e) {
- LOG.error("Error parsing the Ollama chat response!",e);
- return true;
- }
+ try {
+ OllamaGenerateResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class);
+ responseBuffer.append(ollamaResponseModel.getResponse());
+ if (streamObserver != null) {
+ streamObserver.notify(ollamaResponseModel);
+ }
+ return ollamaResponseModel.isDone();
+ } catch (JsonProcessingException e) {
+ LOG.error("Error parsing the Ollama chat response!", e);
+ return true;
+ }
}
public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler)
- throws OllamaBaseException, IOException, InterruptedException {
- streamObserver = new OllamaGenerateStreamObserver(streamHandler);
- return super.callSync(body);
+ throws OllamaBaseException, IOException, InterruptedException {
+ streamObserver = new OllamaGenerateStreamObserver(streamHandler);
+ return super.callSync(body);
}
-
-
}
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/DynamicFunction.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/DynamicFunction.java
new file mode 100644
index 0000000..5b8f5e6
--- /dev/null
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/DynamicFunction.java
@@ -0,0 +1,8 @@
+package io.github.amithkoujalgi.ollama4j.core.tools;
+
+import java.util.Map;
+
+@FunctionalInterface
+public interface DynamicFunction {
+ Object apply(Map arguments);
+}
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/MistralTools.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/MistralTools.java
new file mode 100644
index 0000000..fff8071
--- /dev/null
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/MistralTools.java
@@ -0,0 +1,139 @@
+package io.github.amithkoujalgi.ollama4j.core.tools;
+
+import com.fasterxml.jackson.annotation.JsonIgnore;
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.core.JsonProcessingException;
+import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
+import lombok.Builder;
+import lombok.Data;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+public class MistralTools {
+ @Data
+ @Builder
+ public static class ToolSpecification {
+ private String functionName;
+ private String functionDesc;
+ private Map props;
+ private DynamicFunction toolDefinition;
+ }
+
+ @Data
+ @JsonIgnoreProperties(ignoreUnknown = true)
+ public static class PromptFuncDefinition {
+ private String type;
+ private PromptFuncSpec function;
+
+ @Data
+ public static class PromptFuncSpec {
+ private String name;
+ private String description;
+ private Parameters parameters;
+ }
+
+ @Data
+ public static class Parameters {
+ private String type;
+ private Map properties;
+ private List required;
+ }
+
+ @Data
+ @Builder
+ public static class Property {
+ private String type;
+ private String description;
+ @JsonProperty("enum")
+ @JsonInclude(JsonInclude.Include.NON_NULL)
+ private List enumValues;
+ @JsonIgnore
+ private boolean required;
+ }
+ }
+
+ public static class PropsBuilder {
+ private final Map props = new HashMap<>();
+
+ public PropsBuilder withProperty(String key, PromptFuncDefinition.Property property) {
+ props.put(key, property);
+ return this;
+ }
+
+ public Map build() {
+ return props;
+ }
+ }
+
+ public static class PromptBuilder {
+ private final List tools = new ArrayList<>();
+
+ private String promptText;
+
+ public String build() throws JsonProcessingException {
+ return "[AVAILABLE_TOOLS] " + Utils.getObjectMapper().writeValueAsString(tools) + "[/AVAILABLE_TOOLS][INST] " + promptText + " [/INST]";
+ }
+
+ public PromptBuilder withPrompt(String prompt) throws JsonProcessingException {
+ promptText = prompt;
+ return this;
+ }
+
+ public PromptBuilder withToolSpecification(ToolSpecification spec) {
+ PromptFuncDefinition def = new PromptFuncDefinition();
+ def.setType("function");
+
+ PromptFuncDefinition.PromptFuncSpec functionDetail = new PromptFuncDefinition.PromptFuncSpec();
+ functionDetail.setName(spec.getFunctionName());
+ functionDetail.setDescription(spec.getFunctionDesc());
+
+ PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters();
+ parameters.setType("object");
+ parameters.setProperties(spec.getProps());
+
+ List requiredValues = new ArrayList<>();
+ for (Map.Entry p : spec.getProps().entrySet()) {
+ if (p.getValue().isRequired()) {
+ requiredValues.add(p.getKey());
+ }
+ }
+ parameters.setRequired(requiredValues);
+ functionDetail.setParameters(parameters);
+ def.setFunction(functionDetail);
+
+ tools.add(def);
+ return this;
+ }
+//
+// public PromptBuilder withToolSpecification(String functionName, String functionDesc, Map props) {
+// PromptFuncDefinition def = new PromptFuncDefinition();
+// def.setType("function");
+//
+// PromptFuncDefinition.PromptFuncSpec functionDetail = new PromptFuncDefinition.PromptFuncSpec();
+// functionDetail.setName(functionName);
+// functionDetail.setDescription(functionDesc);
+//
+// PromptFuncDefinition.Parameters parameters = new PromptFuncDefinition.Parameters();
+// parameters.setType("object");
+// parameters.setProperties(props);
+//
+// List requiredValues = new ArrayList<>();
+// for (Map.Entry p : props.entrySet()) {
+// if (p.getValue().isRequired()) {
+// requiredValues.add(p.getKey());
+// }
+// }
+// parameters.setRequired(requiredValues);
+// functionDetail.setParameters(parameters);
+// def.setFunction(functionDetail);
+//
+// tools.add(def);
+// return this;
+// }
+ }
+}
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/OllamaToolsResult.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/OllamaToolsResult.java
new file mode 100644
index 0000000..65ef3ac
--- /dev/null
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/OllamaToolsResult.java
@@ -0,0 +1,16 @@
+package io.github.amithkoujalgi.ollama4j.core.tools;
+
+import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
+import lombok.AllArgsConstructor;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+
+import java.util.Map;
+
+@Data
+@NoArgsConstructor
+@AllArgsConstructor
+public class OllamaToolsResult {
+ private OllamaResult modelResult;
+ private Map toolResults;
+}
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/ToolDef.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/ToolDef.java
new file mode 100644
index 0000000..751d186
--- /dev/null
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/ToolDef.java
@@ -0,0 +1,18 @@
+package io.github.amithkoujalgi.ollama4j.core.tools;
+
+import lombok.AllArgsConstructor;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+
+import java.util.Map;
+
+@Data
+@AllArgsConstructor
+@NoArgsConstructor
+public class ToolDef {
+
+ private String name;
+ private Map arguments;
+
+}
+
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/ToolRegistry.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/ToolRegistry.java
new file mode 100644
index 0000000..0004c7f
--- /dev/null
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/tools/ToolRegistry.java
@@ -0,0 +1,17 @@
+package io.github.amithkoujalgi.ollama4j.core.tools;
+
+import java.util.HashMap;
+import java.util.Map;
+
+public class ToolRegistry {
+ private static final Map functionMap = new HashMap<>();
+
+
+ public static DynamicFunction getFunction(String name) {
+ return functionMap.get(name);
+ }
+
+ public static void addFunction(String name, DynamicFunction function) {
+ functionMap.put(name, function);
+ }
+}
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/types/OllamaModelType.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/types/OllamaModelType.java
index 96bcc43..2613d86 100644
--- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/types/OllamaModelType.java
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/types/OllamaModelType.java
@@ -8,57 +8,81 @@ package io.github.amithkoujalgi.ollama4j.core.types;
*/
@SuppressWarnings("ALL")
public class OllamaModelType {
- public static final String LLAMA2 = "llama2";
- public static final String MISTRAL = "mistral";
- public static final String LLAVA = "llava";
- public static final String MIXTRAL = "mixtral";
- public static final String STARLING_LM = "starling-lm";
- public static final String NEURAL_CHAT = "neural-chat";
- public static final String CODELLAMA = "codellama";
- public static final String LLAMA2_UNCENSORED = "llama2-uncensored";
- public static final String DOLPHIN_MIXTRAL = "dolphin-mixtral";
- public static final String ORCA_MINI = "orca-mini";
- public static final String VICUNA = "vicuna";
- public static final String WIZARD_VICUNA_UNCENSORED = "wizard-vicuna-uncensored";
- public static final String PHIND_CODELLAMA = "phind-codellama";
- public static final String PHI = "phi";
- public static final String ZEPHYR = "zephyr";
- public static final String WIZARDCODER = "wizardcoder";
- public static final String MISTRAL_OPENORCA = "mistral-openorca";
- public static final String NOUS_HERMES = "nous-hermes";
- public static final String DEEPSEEK_CODER = "deepseek-coder";
- public static final String WIZARD_MATH = "wizard-math";
- public static final String LLAMA2_CHINESE = "llama2-chinese";
- public static final String FALCON = "falcon";
- public static final String ORCA2 = "orca2";
- public static final String STABLE_BELUGA = "stable-beluga";
- public static final String CODEUP = "codeup";
- public static final String EVERYTHINGLM = "everythinglm";
- public static final String MEDLLAMA2 = "medllama2";
- public static final String WIZARDLM_UNCENSORED = "wizardlm-uncensored";
- public static final String STARCODER = "starcoder";
- public static final String DOLPHIN22_MISTRAL = "dolphin2.2-mistral";
- public static final String OPENCHAT = "openchat";
- public static final String WIZARD_VICUNA = "wizard-vicuna";
- public static final String OPENHERMES25_MISTRAL = "openhermes2.5-mistral";
- public static final String OPEN_ORCA_PLATYPUS2 = "open-orca-platypus2";
- public static final String YI = "yi";
- public static final String YARN_MISTRAL = "yarn-mistral";
- public static final String SAMANTHA_MISTRAL = "samantha-mistral";
- public static final String SQLCODER = "sqlcoder";
- public static final String YARN_LLAMA2 = "yarn-llama2";
- public static final String MEDITRON = "meditron";
- public static final String STABLELM_ZEPHYR = "stablelm-zephyr";
- public static final String OPENHERMES2_MISTRAL = "openhermes2-mistral";
- public static final String DEEPSEEK_LLM = "deepseek-llm";
- public static final String MISTRALLITE = "mistrallite";
- public static final String DOLPHIN21_MISTRAL = "dolphin2.1-mistral";
- public static final String WIZARDLM = "wizardlm";
- public static final String CODEBOOGA = "codebooga";
- public static final String MAGICODER = "magicoder";
- public static final String GOLIATH = "goliath";
- public static final String NEXUSRAVEN = "nexusraven";
- public static final String ALFRED = "alfred";
- public static final String XWINLM = "xwinlm";
- public static final String BAKLLAVA = "bakllava";
+ public static final String GEMMA = "gemma";
+ public static final String GEMMA2 = "gemma2";
+
+
+ public static final String LLAMA2 = "llama2";
+ public static final String LLAMA3 = "llama3";
+ public static final String MISTRAL = "mistral";
+ public static final String MIXTRAL = "mixtral";
+ public static final String LLAVA = "llava";
+ public static final String LLAVA_PHI3 = "llava-phi3";
+ public static final String NEURAL_CHAT = "neural-chat";
+ public static final String CODELLAMA = "codellama";
+ public static final String DOLPHIN_MIXTRAL = "dolphin-mixtral";
+ public static final String MISTRAL_OPENORCA = "mistral-openorca";
+ public static final String LLAMA2_UNCENSORED = "llama2-uncensored";
+ public static final String PHI = "phi";
+ public static final String PHI3 = "phi3";
+ public static final String ORCA_MINI = "orca-mini";
+ public static final String DEEPSEEK_CODER = "deepseek-coder";
+ public static final String DOLPHIN_MISTRAL = "dolphin-mistral";
+ public static final String VICUNA = "vicuna";
+ public static final String WIZARD_VICUNA_UNCENSORED = "wizard-vicuna-uncensored";
+ public static final String ZEPHYR = "zephyr";
+ public static final String OPENHERMES = "openhermes";
+ public static final String QWEN = "qwen";
+
+ public static final String QWEN2 = "qwen2";
+ public static final String WIZARDCODER = "wizardcoder";
+ public static final String LLAMA2_CHINESE = "llama2-chinese";
+ public static final String TINYLLAMA = "tinyllama";
+ public static final String PHIND_CODELLAMA = "phind-codellama";
+ public static final String OPENCHAT = "openchat";
+ public static final String ORCA2 = "orca2";
+ public static final String FALCON = "falcon";
+ public static final String WIZARD_MATH = "wizard-math";
+ public static final String TINYDOLPHIN = "tinydolphin";
+ public static final String NOUS_HERMES = "nous-hermes";
+ public static final String YI = "yi";
+ public static final String DOLPHIN_PHI = "dolphin-phi";
+ public static final String STARLING_LM = "starling-lm";
+ public static final String STARCODER = "starcoder";
+ public static final String CODEUP = "codeup";
+ public static final String MEDLLAMA2 = "medllama2";
+ public static final String STABLE_CODE = "stable-code";
+ public static final String WIZARDLM_UNCENSORED = "wizardlm-uncensored";
+ public static final String BAKLLAVA = "bakllava";
+ public static final String EVERYTHINGLM = "everythinglm";
+ public static final String SOLAR = "solar";
+ public static final String STABLE_BELUGA = "stable-beluga";
+ public static final String SQLCODER = "sqlcoder";
+ public static final String YARN_MISTRAL = "yarn-mistral";
+ public static final String NOUS_HERMES2_MIXTRAL = "nous-hermes2-mixtral";
+ public static final String SAMANTHA_MISTRAL = "samantha-mistral";
+ public static final String STABLELM_ZEPHYR = "stablelm-zephyr";
+ public static final String MEDITRON = "meditron";
+ public static final String WIZARD_VICUNA = "wizard-vicuna";
+ public static final String STABLELM2 = "stablelm2";
+ public static final String MAGICODER = "magicoder";
+ public static final String YARN_LLAMA2 = "yarn-llama2";
+ public static final String NOUS_HERMES2 = "nous-hermes2";
+ public static final String DEEPSEEK_LLM = "deepseek-llm";
+ public static final String LLAMA_PRO = "llama-pro";
+ public static final String OPEN_ORCA_PLATYPUS2 = "open-orca-platypus2";
+ public static final String CODEBOOGA = "codebooga";
+ public static final String MISTRALLITE = "mistrallite";
+ public static final String NEXUSRAVEN = "nexusraven";
+ public static final String GOLIATH = "goliath";
+ public static final String NOMIC_EMBED_TEXT = "nomic-embed-text";
+ public static final String NOTUX = "notux";
+ public static final String ALFRED = "alfred";
+ public static final String MEGADOLPHIN = "megadolphin";
+ public static final String WIZARDLM = "wizardlm";
+ public static final String XWINLM = "xwinlm";
+ public static final String NOTUS = "notus";
+ public static final String DUCKDB_NSQL = "duckdb-nsql";
+ public static final String ALL_MINILM = "all-minilm";
+ public static final String CODESTRAL = "codestral";
}
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/Utils.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/Utils.java
index 1504c1d..96b07ae 100644
--- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/Utils.java
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/Utils.java
@@ -8,10 +8,18 @@ import java.net.URISyntaxException;
import java.net.URL;
import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
public class Utils {
+
+ private static ObjectMapper objectMapper;
+
public static ObjectMapper getObjectMapper() {
- return new ObjectMapper();
+ if(objectMapper == null) {
+ objectMapper = new ObjectMapper();
+ objectMapper.registerModule(new JavaTimeModule());
+ }
+ return objectMapper;
}
public static byte[] loadImageBytesFromUrl(String imageUrl)
diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java b/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java
index dc91287..58e55a1 100644
--- a/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java
+++ b/src/test/java/io/github/amithkoujalgi/ollama4j/integrationtests/TestRealAPIs.java
@@ -1,7 +1,5 @@
package io.github.amithkoujalgi.ollama4j.integrationtests;
-import static org.junit.jupiter.api.Assertions.*;
-
import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.ModelDetail;
@@ -10,7 +8,16 @@ import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessageRole;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResult;
+import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestBuilder;
+import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestModel;
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
+import lombok.Data;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Order;
+import org.junit.jupiter.api.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
@@ -20,355 +27,369 @@ import java.net.http.HttpConnectTimeoutException;
import java.util.List;
import java.util.Objects;
import java.util.Properties;
-import lombok.Data;
-import org.junit.jupiter.api.BeforeEach;
-import org.junit.jupiter.api.Order;
-import org.junit.jupiter.api.Test;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
+
+import static org.junit.jupiter.api.Assertions.*;
class TestRealAPIs {
- private static final Logger LOG = LoggerFactory.getLogger(TestRealAPIs.class);
+ private static final Logger LOG = LoggerFactory.getLogger(TestRealAPIs.class);
- OllamaAPI ollamaAPI;
- Config config;
+ OllamaAPI ollamaAPI;
+ Config config;
- private File getImageFileFromClasspath(String fileName) {
- ClassLoader classLoader = getClass().getClassLoader();
- return new File(Objects.requireNonNull(classLoader.getResource(fileName)).getFile());
- }
-
- @BeforeEach
- void setUp() {
- config = new Config();
- ollamaAPI = new OllamaAPI(config.getOllamaURL());
- ollamaAPI.setRequestTimeoutSeconds(config.getRequestTimeoutSeconds());
- }
-
- @Test
- @Order(1)
- void testWrongEndpoint() {
- OllamaAPI ollamaAPI = new OllamaAPI("http://wrong-host:11434");
- assertThrows(ConnectException.class, ollamaAPI::listModels);
- }
-
- @Test
- @Order(1)
- void testEndpointReachability() {
- try {
- assertNotNull(ollamaAPI.listModels());
- } catch (HttpConnectTimeoutException e) {
- fail(e.getMessage());
- } catch (Exception e) {
- throw new RuntimeException(e);
+ private File getImageFileFromClasspath(String fileName) {
+ ClassLoader classLoader = getClass().getClassLoader();
+ return new File(Objects.requireNonNull(classLoader.getResource(fileName)).getFile());
}
- }
- @Test
- @Order(2)
- void testListModels() {
- testEndpointReachability();
- try {
- assertNotNull(ollamaAPI.listModels());
- ollamaAPI.listModels().forEach(System.out::println);
- } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
- throw new RuntimeException(e);
+ @BeforeEach
+ void setUp() {
+ config = new Config();
+ ollamaAPI = new OllamaAPI(config.getOllamaURL());
+ ollamaAPI.setRequestTimeoutSeconds(config.getRequestTimeoutSeconds());
}
- }
- @Test
- @Order(2)
- void testPullModel() {
- testEndpointReachability();
- try {
- ollamaAPI.pullModel(config.getModel());
- boolean found =
- ollamaAPI.listModels().stream()
- .anyMatch(model -> model.getModel().equalsIgnoreCase(config.getModel()));
- assertTrue(found);
- } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
- throw new RuntimeException(e);
+ @Test
+ @Order(1)
+ void testWrongEndpoint() {
+ OllamaAPI ollamaAPI = new OllamaAPI("http://wrong-host:11434");
+ assertThrows(ConnectException.class, ollamaAPI::listModels);
}
- }
- @Test
- @Order(3)
- void testListDtails() {
- testEndpointReachability();
- try {
- ModelDetail modelDetails = ollamaAPI.getModelDetails(config.getModel());
- assertNotNull(modelDetails);
- System.out.println(modelDetails);
- } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
- throw new RuntimeException(e);
+ @Test
+ @Order(1)
+ void testEndpointReachability() {
+ try {
+ assertNotNull(ollamaAPI.listModels());
+ } catch (HttpConnectTimeoutException e) {
+ fail(e.getMessage());
+ } catch (Exception e) {
+ fail(e);
+ }
}
- }
- @Test
- @Order(3)
- void testAskModelWithDefaultOptions() {
- testEndpointReachability();
- try {
- OllamaResult result =
- ollamaAPI.generate(
- config.getModel(),
- "What is the capital of France? And what's France's connection with Mona Lisa?",
- new OptionsBuilder().build());
- assertNotNull(result);
- assertNotNull(result.getResponse());
- assertFalse(result.getResponse().isEmpty());
- } catch (IOException | OllamaBaseException | InterruptedException e) {
- throw new RuntimeException(e);
+ @Test
+ @Order(2)
+ void testListModels() {
+ testEndpointReachability();
+ try {
+ assertNotNull(ollamaAPI.listModels());
+ ollamaAPI.listModels().forEach(System.out::println);
+ } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
+ fail(e);
+ }
}
- }
- @Test
- @Order(3)
- void testAskModelWithDefaultOptionsStreamed() {
- testEndpointReachability();
- try {
-
- StringBuffer sb = new StringBuffer("");
-
- OllamaResult result = ollamaAPI.generate(config.getModel(),
- "What is the capital of France? And what's France's connection with Mona Lisa?",
- new OptionsBuilder().build(), (s) -> {
- LOG.info(s);
- String substring = s.substring(sb.toString().length(), s.length());
- LOG.info(substring);
- sb.append(substring);
- });
-
- assertNotNull(result);
- assertNotNull(result.getResponse());
- assertFalse(result.getResponse().isEmpty());
- assertEquals(sb.toString().trim(), result.getResponse().trim());
- } catch (IOException | OllamaBaseException | InterruptedException e) {
- throw new RuntimeException(e);
+ @Test
+ @Order(2)
+ void testPullModel() {
+ testEndpointReachability();
+ try {
+ ollamaAPI.pullModel(config.getModel());
+ boolean found =
+ ollamaAPI.listModels().stream()
+ .anyMatch(model -> model.getModel().equalsIgnoreCase(config.getModel()));
+ assertTrue(found);
+ } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
+ fail(e);
+ }
}
- }
- @Test
- @Order(3)
- void testAskModelWithOptions() {
- testEndpointReachability();
- try {
- OllamaResult result =
- ollamaAPI.generate(
- config.getModel(),
- "What is the capital of France? And what's France's connection with Mona Lisa?",
- new OptionsBuilder().setTemperature(0.9f).build());
- assertNotNull(result);
- assertNotNull(result.getResponse());
- assertFalse(result.getResponse().isEmpty());
- } catch (IOException | OllamaBaseException | InterruptedException e) {
- throw new RuntimeException(e);
+ @Test
+ @Order(3)
+ void testListDtails() {
+ testEndpointReachability();
+ try {
+ ModelDetail modelDetails = ollamaAPI.getModelDetails(config.getModel());
+ assertNotNull(modelDetails);
+ System.out.println(modelDetails);
+ } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
+ fail(e);
+ }
}
- }
- @Test
- @Order(3)
- void testChat() {
- testEndpointReachability();
- try {
- OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
- OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What is the capital of France?")
- .withMessage(OllamaChatMessageRole.ASSISTANT, "Should be Paris!")
- .withMessage(OllamaChatMessageRole.USER,"And what is the second larges city?")
- .build();
-
- OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
- assertNotNull(chatResult);
- assertFalse(chatResult.getResponse().isBlank());
- assertEquals(4,chatResult.getChatHistory().size());
- } catch (IOException | OllamaBaseException | InterruptedException e) {
- throw new RuntimeException(e);
+ @Test
+ @Order(3)
+ void testAskModelWithDefaultOptions() {
+ testEndpointReachability();
+ try {
+ OllamaResult result =
+ ollamaAPI.generate(
+ config.getModel(),
+ "What is the capital of France? And what's France's connection with Mona Lisa?",
+ false,
+ new OptionsBuilder().build());
+ assertNotNull(result);
+ assertNotNull(result.getResponse());
+ assertFalse(result.getResponse().isEmpty());
+ } catch (IOException | OllamaBaseException | InterruptedException e) {
+ fail(e);
+ }
}
- }
- @Test
- @Order(3)
- void testChatWithSystemPrompt() {
- testEndpointReachability();
- try {
- OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
- OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM,
- "You are a silent bot that only says 'NI'. Do not say anything else under any circumstances!")
- .withMessage(OllamaChatMessageRole.USER,
- "What is the capital of France? And what's France's connection with Mona Lisa?")
- .build();
+ @Test
+ @Order(3)
+ void testAskModelWithDefaultOptionsStreamed() {
+ testEndpointReachability();
+ try {
+ StringBuffer sb = new StringBuffer("");
+ OllamaResult result = ollamaAPI.generate(config.getModel(),
+ "What is the capital of France? And what's France's connection with Mona Lisa?",
+ false,
+ new OptionsBuilder().build(), (s) -> {
+ LOG.info(s);
+ String substring = s.substring(sb.toString().length(), s.length());
+ LOG.info(substring);
+ sb.append(substring);
+ });
- OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
- assertNotNull(chatResult);
- assertFalse(chatResult.getResponse().isBlank());
- assertTrue(chatResult.getResponse().startsWith("NI"));
- assertEquals(3, chatResult.getChatHistory().size());
- } catch (IOException | OllamaBaseException | InterruptedException e) {
- throw new RuntimeException(e);
+ assertNotNull(result);
+ assertNotNull(result.getResponse());
+ assertFalse(result.getResponse().isEmpty());
+ assertEquals(sb.toString().trim(), result.getResponse().trim());
+ } catch (IOException | OllamaBaseException | InterruptedException e) {
+ fail(e);
+ }
}
- }
- @Test
- @Order(3)
- void testChatWithStream() {
- testEndpointReachability();
- try {
- OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
- OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER,
- "What is the capital of France? And what's France's connection with Mona Lisa?")
- .build();
-
- StringBuffer sb = new StringBuffer("");
-
- OllamaChatResult chatResult = ollamaAPI.chat(requestModel,(s) -> {
- LOG.info(s);
- String substring = s.substring(sb.toString().length(), s.length());
- LOG.info(substring);
- sb.append(substring);
- });
- assertNotNull(chatResult);
- assertEquals(sb.toString().trim(), chatResult.getResponse().trim());
- } catch (IOException | OllamaBaseException | InterruptedException e) {
- throw new RuntimeException(e);
+ @Test
+ @Order(3)
+ void testAskModelWithOptions() {
+ testEndpointReachability();
+ try {
+ OllamaResult result =
+ ollamaAPI.generate(
+ config.getModel(),
+ "What is the capital of France? And what's France's connection with Mona Lisa?",
+ true,
+ new OptionsBuilder().setTemperature(0.9f).build());
+ assertNotNull(result);
+ assertNotNull(result.getResponse());
+ assertFalse(result.getResponse().isEmpty());
+ } catch (IOException | OllamaBaseException | InterruptedException e) {
+ fail(e);
+ }
}
- }
- @Test
- @Order(3)
- void testChatWithImageFromFileWithHistoryRecognition() {
- testEndpointReachability();
- try {
- OllamaChatRequestBuilder builder =
- OllamaChatRequestBuilder.getInstance(config.getImageModel());
- OllamaChatRequestModel requestModel =
- builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
- List.of(getImageFileFromClasspath("dog-on-a-boat.jpg"))).build();
+ @Test
+ @Order(3)
+ void testChat() {
+ testEndpointReachability();
+ try {
+ OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
+ OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What is the capital of France?")
+ .withMessage(OllamaChatMessageRole.ASSISTANT, "Should be Paris!")
+ .withMessage(OllamaChatMessageRole.USER, "And what is the second larges city?")
+ .build();
- OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
- assertNotNull(chatResult);
- assertNotNull(chatResult.getResponse());
-
- builder.reset();
-
- requestModel =
- builder.withMessages(chatResult.getChatHistory())
- .withMessage(OllamaChatMessageRole.USER, "What's the dogs breed?").build();
-
- chatResult = ollamaAPI.chat(requestModel);
- assertNotNull(chatResult);
- assertNotNull(chatResult.getResponse());
-
-
- } catch (IOException | OllamaBaseException | InterruptedException e) {
- throw new RuntimeException(e);
+ OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
+ assertNotNull(chatResult);
+ assertFalse(chatResult.getResponse().isBlank());
+ assertEquals(4, chatResult.getChatHistory().size());
+ } catch (IOException | OllamaBaseException | InterruptedException e) {
+ fail(e);
+ }
}
- }
- @Test
- @Order(3)
- void testChatWithImageFromURL() {
- testEndpointReachability();
- try {
- OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getImageModel());
- OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
- "https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg")
- .build();
+ @Test
+ @Order(3)
+ void testChatWithSystemPrompt() {
+ testEndpointReachability();
+ try {
+ OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
+ OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.SYSTEM,
+ "You are a silent bot that only says 'NI'. Do not say anything else under any circumstances!")
+ .withMessage(OllamaChatMessageRole.USER,
+ "What is the capital of France? And what's France's connection with Mona Lisa?")
+ .build();
- OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
- assertNotNull(chatResult);
- } catch (IOException | OllamaBaseException | InterruptedException e) {
- throw new RuntimeException(e);
+ OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
+ assertNotNull(chatResult);
+ assertFalse(chatResult.getResponse().isBlank());
+ assertTrue(chatResult.getResponse().startsWith("NI"));
+ assertEquals(3, chatResult.getChatHistory().size());
+ } catch (IOException | OllamaBaseException | InterruptedException e) {
+ fail(e);
+ }
}
- }
- @Test
- @Order(3)
- void testAskModelWithOptionsAndImageFiles() {
- testEndpointReachability();
- File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg");
- try {
- OllamaResult result =
- ollamaAPI.generateWithImageFiles(
- config.getImageModel(),
- "What is in this image?",
- List.of(imageFile),
- new OptionsBuilder().build());
- assertNotNull(result);
- assertNotNull(result.getResponse());
- assertFalse(result.getResponse().isEmpty());
- } catch (IOException | OllamaBaseException | InterruptedException e) {
- throw new RuntimeException(e);
+ @Test
+ @Order(3)
+ void testChatWithStream() {
+ testEndpointReachability();
+ try {
+ OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getModel());
+ OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER,
+ "What is the capital of France? And what's France's connection with Mona Lisa?")
+ .build();
+
+ StringBuffer sb = new StringBuffer("");
+
+ OllamaChatResult chatResult = ollamaAPI.chat(requestModel, (s) -> {
+ LOG.info(s);
+ String substring = s.substring(sb.toString().length(), s.length());
+ LOG.info(substring);
+ sb.append(substring);
+ });
+ assertNotNull(chatResult);
+ assertEquals(sb.toString().trim(), chatResult.getResponse().trim());
+ } catch (IOException | OllamaBaseException | InterruptedException e) {
+ fail(e);
+ }
}
- }
- @Test
- @Order(3)
- void testAskModelWithOptionsAndImageFilesStreamed() {
- testEndpointReachability();
- File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg");
- try {
- StringBuffer sb = new StringBuffer("");
+ @Test
+ @Order(3)
+ void testChatWithImageFromFileWithHistoryRecognition() {
+ testEndpointReachability();
+ try {
+ OllamaChatRequestBuilder builder =
+ OllamaChatRequestBuilder.getInstance(config.getImageModel());
+ OllamaChatRequestModel requestModel =
+ builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
+ List.of(getImageFileFromClasspath("dog-on-a-boat.jpg"))).build();
- OllamaResult result = ollamaAPI.generateWithImageFiles(config.getImageModel(),
- "What is in this image?", List.of(imageFile), new OptionsBuilder().build(), (s) -> {
- LOG.info(s);
- String substring = s.substring(sb.toString().length(), s.length());
- LOG.info(substring);
- sb.append(substring);
- });
- assertNotNull(result);
- assertNotNull(result.getResponse());
- assertFalse(result.getResponse().isEmpty());
- assertEquals(sb.toString().trim(), result.getResponse().trim());
- } catch (IOException | OllamaBaseException | InterruptedException e) {
- throw new RuntimeException(e);
+ OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
+ assertNotNull(chatResult);
+ assertNotNull(chatResult.getResponse());
+
+ builder.reset();
+
+ requestModel =
+ builder.withMessages(chatResult.getChatHistory())
+ .withMessage(OllamaChatMessageRole.USER, "What's the dogs breed?").build();
+
+ chatResult = ollamaAPI.chat(requestModel);
+ assertNotNull(chatResult);
+ assertNotNull(chatResult.getResponse());
+
+
+ } catch (IOException | OllamaBaseException | InterruptedException e) {
+ fail(e);
+ }
}
- }
- @Test
- @Order(3)
- void testAskModelWithOptionsAndImageURLs() {
- testEndpointReachability();
- try {
- OllamaResult result =
- ollamaAPI.generateWithImageURLs(
- config.getImageModel(),
- "What is in this image?",
- List.of(
- "https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg"),
- new OptionsBuilder().build());
- assertNotNull(result);
- assertNotNull(result.getResponse());
- assertFalse(result.getResponse().isEmpty());
- } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
- throw new RuntimeException(e);
+ @Test
+ @Order(3)
+ void testChatWithImageFromURL() {
+ testEndpointReachability();
+ try {
+ OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getImageModel());
+ OllamaChatRequestModel requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
+ "https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg")
+ .build();
+
+ OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
+ assertNotNull(chatResult);
+ } catch (IOException | OllamaBaseException | InterruptedException e) {
+ fail(e);
+ }
+ }
+
+ @Test
+ @Order(3)
+ void testAskModelWithOptionsAndImageFiles() {
+ testEndpointReachability();
+ File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg");
+ try {
+ OllamaResult result =
+ ollamaAPI.generateWithImageFiles(
+ config.getImageModel(),
+ "What is in this image?",
+ List.of(imageFile),
+ new OptionsBuilder().build());
+ assertNotNull(result);
+ assertNotNull(result.getResponse());
+ assertFalse(result.getResponse().isEmpty());
+ } catch (IOException | OllamaBaseException | InterruptedException e) {
+ fail(e);
+ }
+ }
+
+ @Test
+ @Order(3)
+ void testAskModelWithOptionsAndImageFilesStreamed() {
+ testEndpointReachability();
+ File imageFile = getImageFileFromClasspath("dog-on-a-boat.jpg");
+ try {
+ StringBuffer sb = new StringBuffer("");
+
+ OllamaResult result = ollamaAPI.generateWithImageFiles(config.getImageModel(),
+ "What is in this image?", List.of(imageFile), new OptionsBuilder().build(), (s) -> {
+ LOG.info(s);
+ String substring = s.substring(sb.toString().length(), s.length());
+ LOG.info(substring);
+ sb.append(substring);
+ });
+ assertNotNull(result);
+ assertNotNull(result.getResponse());
+ assertFalse(result.getResponse().isEmpty());
+ assertEquals(sb.toString().trim(), result.getResponse().trim());
+ } catch (IOException | OllamaBaseException | InterruptedException e) {
+ fail(e);
+ }
+ }
+
+ @Test
+ @Order(3)
+ void testAskModelWithOptionsAndImageURLs() {
+ testEndpointReachability();
+ try {
+ OllamaResult result =
+ ollamaAPI.generateWithImageURLs(
+ config.getImageModel(),
+ "What is in this image?",
+ List.of(
+ "https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg"),
+ new OptionsBuilder().build());
+ assertNotNull(result);
+ assertNotNull(result.getResponse());
+ assertFalse(result.getResponse().isEmpty());
+ } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
+ fail(e);
+ }
+ }
+
+ @Test
+ @Order(3)
+ public void testEmbedding() {
+ testEndpointReachability();
+ try {
+ OllamaEmbeddingsRequestModel request = OllamaEmbeddingsRequestBuilder
+ .getInstance(config.getModel(), "What is the capital of France?").build();
+
+ List embeddings = ollamaAPI.generateEmbeddings(request);
+
+ assertNotNull(embeddings);
+ assertFalse(embeddings.isEmpty());
+ } catch (IOException | OllamaBaseException | InterruptedException e) {
+ fail(e);
+ }
}
- }
}
@Data
class Config {
- private String ollamaURL;
- private String model;
- private String imageModel;
- private int requestTimeoutSeconds;
+ private String ollamaURL;
+ private String model;
+ private String imageModel;
+ private int requestTimeoutSeconds;
- public Config() {
- Properties properties = new Properties();
- try (InputStream input =
- getClass().getClassLoader().getResourceAsStream("test-config.properties")) {
- if (input == null) {
- throw new RuntimeException("Sorry, unable to find test-config.properties");
- }
- properties.load(input);
- this.ollamaURL = properties.getProperty("ollama.url");
- this.model = properties.getProperty("ollama.model");
- this.imageModel = properties.getProperty("ollama.model.image");
- this.requestTimeoutSeconds =
- Integer.parseInt(properties.getProperty("ollama.request-timeout-seconds"));
- } catch (IOException e) {
- throw new RuntimeException("Error loading properties", e);
+ public Config() {
+ Properties properties = new Properties();
+ try (InputStream input =
+ getClass().getClassLoader().getResourceAsStream("test-config.properties")) {
+ if (input == null) {
+ throw new RuntimeException("Sorry, unable to find test-config.properties");
+ }
+ properties.load(input);
+ this.ollamaURL = properties.getProperty("ollama.url");
+ this.model = properties.getProperty("ollama.model");
+ this.imageModel = properties.getProperty("ollama.model.image");
+ this.requestTimeoutSeconds =
+ Integer.parseInt(properties.getProperty("ollama.request-timeout-seconds"));
+ } catch (IOException e) {
+ throw new RuntimeException("Error loading properties", e);
+ }
}
- }
}
diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/TestMockedAPIs.java b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/TestMockedAPIs.java
index 879c67c..c5d60e1 100644
--- a/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/TestMockedAPIs.java
+++ b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/TestMockedAPIs.java
@@ -1,7 +1,5 @@
package io.github.amithkoujalgi.ollama4j.unittests;
-import static org.mockito.Mockito.*;
-
import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.ModelDetail;
@@ -9,155 +7,158 @@ import io.github.amithkoujalgi.ollama4j.core.models.OllamaAsyncResultCallback;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
import io.github.amithkoujalgi.ollama4j.core.types.OllamaModelType;
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
+import org.junit.jupiter.api.Test;
+import org.mockito.Mockito;
+
import java.io.IOException;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.Collections;
-import org.junit.jupiter.api.Test;
-import org.mockito.Mockito;
+
+import static org.mockito.Mockito.*;
class TestMockedAPIs {
- @Test
- void testPullModel() {
- OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
- String model = OllamaModelType.LLAMA2;
- try {
- doNothing().when(ollamaAPI).pullModel(model);
- ollamaAPI.pullModel(model);
- verify(ollamaAPI, times(1)).pullModel(model);
- } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
- throw new RuntimeException(e);
+ @Test
+ void testPullModel() {
+ OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
+ String model = OllamaModelType.LLAMA2;
+ try {
+ doNothing().when(ollamaAPI).pullModel(model);
+ ollamaAPI.pullModel(model);
+ verify(ollamaAPI, times(1)).pullModel(model);
+ } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
+ throw new RuntimeException(e);
+ }
}
- }
- @Test
- void testListModels() {
- OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
- try {
- when(ollamaAPI.listModels()).thenReturn(new ArrayList<>());
- ollamaAPI.listModels();
- verify(ollamaAPI, times(1)).listModels();
- } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
- throw new RuntimeException(e);
+ @Test
+ void testListModels() {
+ OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
+ try {
+ when(ollamaAPI.listModels()).thenReturn(new ArrayList<>());
+ ollamaAPI.listModels();
+ verify(ollamaAPI, times(1)).listModels();
+ } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
+ throw new RuntimeException(e);
+ }
}
- }
- @Test
- void testCreateModel() {
- OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
- String model = OllamaModelType.LLAMA2;
- String modelFilePath = "FROM llama2\nSYSTEM You are mario from Super Mario Bros.";
- try {
- doNothing().when(ollamaAPI).createModelWithModelFileContents(model, modelFilePath);
- ollamaAPI.createModelWithModelFileContents(model, modelFilePath);
- verify(ollamaAPI, times(1)).createModelWithModelFileContents(model, modelFilePath);
- } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
- throw new RuntimeException(e);
+ @Test
+ void testCreateModel() {
+ OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
+ String model = OllamaModelType.LLAMA2;
+ String modelFilePath = "FROM llama2\nSYSTEM You are mario from Super Mario Bros.";
+ try {
+ doNothing().when(ollamaAPI).createModelWithModelFileContents(model, modelFilePath);
+ ollamaAPI.createModelWithModelFileContents(model, modelFilePath);
+ verify(ollamaAPI, times(1)).createModelWithModelFileContents(model, modelFilePath);
+ } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
+ throw new RuntimeException(e);
+ }
}
- }
- @Test
- void testDeleteModel() {
- OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
- String model = OllamaModelType.LLAMA2;
- try {
- doNothing().when(ollamaAPI).deleteModel(model, true);
- ollamaAPI.deleteModel(model, true);
- verify(ollamaAPI, times(1)).deleteModel(model, true);
- } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
- throw new RuntimeException(e);
+ @Test
+ void testDeleteModel() {
+ OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
+ String model = OllamaModelType.LLAMA2;
+ try {
+ doNothing().when(ollamaAPI).deleteModel(model, true);
+ ollamaAPI.deleteModel(model, true);
+ verify(ollamaAPI, times(1)).deleteModel(model, true);
+ } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
+ throw new RuntimeException(e);
+ }
}
- }
- @Test
- void testGetModelDetails() {
- OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
- String model = OllamaModelType.LLAMA2;
- try {
- when(ollamaAPI.getModelDetails(model)).thenReturn(new ModelDetail());
- ollamaAPI.getModelDetails(model);
- verify(ollamaAPI, times(1)).getModelDetails(model);
- } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
- throw new RuntimeException(e);
+ @Test
+ void testGetModelDetails() {
+ OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
+ String model = OllamaModelType.LLAMA2;
+ try {
+ when(ollamaAPI.getModelDetails(model)).thenReturn(new ModelDetail());
+ ollamaAPI.getModelDetails(model);
+ verify(ollamaAPI, times(1)).getModelDetails(model);
+ } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
+ throw new RuntimeException(e);
+ }
}
- }
- @Test
- void testGenerateEmbeddings() {
- OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
- String model = OllamaModelType.LLAMA2;
- String prompt = "some prompt text";
- try {
- when(ollamaAPI.generateEmbeddings(model, prompt)).thenReturn(new ArrayList<>());
- ollamaAPI.generateEmbeddings(model, prompt);
- verify(ollamaAPI, times(1)).generateEmbeddings(model, prompt);
- } catch (IOException | OllamaBaseException | InterruptedException e) {
- throw new RuntimeException(e);
+ @Test
+ void testGenerateEmbeddings() {
+ OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
+ String model = OllamaModelType.LLAMA2;
+ String prompt = "some prompt text";
+ try {
+ when(ollamaAPI.generateEmbeddings(model, prompt)).thenReturn(new ArrayList<>());
+ ollamaAPI.generateEmbeddings(model, prompt);
+ verify(ollamaAPI, times(1)).generateEmbeddings(model, prompt);
+ } catch (IOException | OllamaBaseException | InterruptedException e) {
+ throw new RuntimeException(e);
+ }
}
- }
- @Test
- void testAsk() {
- OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
- String model = OllamaModelType.LLAMA2;
- String prompt = "some prompt text";
- OptionsBuilder optionsBuilder = new OptionsBuilder();
- try {
- when(ollamaAPI.generate(model, prompt, optionsBuilder.build()))
- .thenReturn(new OllamaResult("", 0, 200));
- ollamaAPI.generate(model, prompt, optionsBuilder.build());
- verify(ollamaAPI, times(1)).generate(model, prompt, optionsBuilder.build());
- } catch (IOException | OllamaBaseException | InterruptedException e) {
- throw new RuntimeException(e);
+ @Test
+ void testAsk() {
+ OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
+ String model = OllamaModelType.LLAMA2;
+ String prompt = "some prompt text";
+ OptionsBuilder optionsBuilder = new OptionsBuilder();
+ try {
+ when(ollamaAPI.generate(model, prompt, false, optionsBuilder.build()))
+ .thenReturn(new OllamaResult("", 0, 200));
+ ollamaAPI.generate(model, prompt, false, optionsBuilder.build());
+ verify(ollamaAPI, times(1)).generate(model, prompt, false, optionsBuilder.build());
+ } catch (IOException | OllamaBaseException | InterruptedException e) {
+ throw new RuntimeException(e);
+ }
}
- }
- @Test
- void testAskWithImageFiles() {
- OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
- String model = OllamaModelType.LLAMA2;
- String prompt = "some prompt text";
- try {
- when(ollamaAPI.generateWithImageFiles(
- model, prompt, Collections.emptyList(), new OptionsBuilder().build()))
- .thenReturn(new OllamaResult("", 0, 200));
- ollamaAPI.generateWithImageFiles(
- model, prompt, Collections.emptyList(), new OptionsBuilder().build());
- verify(ollamaAPI, times(1))
- .generateWithImageFiles(
- model, prompt, Collections.emptyList(), new OptionsBuilder().build());
- } catch (IOException | OllamaBaseException | InterruptedException e) {
- throw new RuntimeException(e);
+ @Test
+ void testAskWithImageFiles() {
+ OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
+ String model = OllamaModelType.LLAMA2;
+ String prompt = "some prompt text";
+ try {
+ when(ollamaAPI.generateWithImageFiles(
+ model, prompt, Collections.emptyList(), new OptionsBuilder().build()))
+ .thenReturn(new OllamaResult("", 0, 200));
+ ollamaAPI.generateWithImageFiles(
+ model, prompt, Collections.emptyList(), new OptionsBuilder().build());
+ verify(ollamaAPI, times(1))
+ .generateWithImageFiles(
+ model, prompt, Collections.emptyList(), new OptionsBuilder().build());
+ } catch (IOException | OllamaBaseException | InterruptedException e) {
+ throw new RuntimeException(e);
+ }
}
- }
- @Test
- void testAskWithImageURLs() {
- OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
- String model = OllamaModelType.LLAMA2;
- String prompt = "some prompt text";
- try {
- when(ollamaAPI.generateWithImageURLs(
- model, prompt, Collections.emptyList(), new OptionsBuilder().build()))
- .thenReturn(new OllamaResult("", 0, 200));
- ollamaAPI.generateWithImageURLs(
- model, prompt, Collections.emptyList(), new OptionsBuilder().build());
- verify(ollamaAPI, times(1))
- .generateWithImageURLs(
- model, prompt, Collections.emptyList(), new OptionsBuilder().build());
- } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
- throw new RuntimeException(e);
+ @Test
+ void testAskWithImageURLs() {
+ OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
+ String model = OllamaModelType.LLAMA2;
+ String prompt = "some prompt text";
+ try {
+ when(ollamaAPI.generateWithImageURLs(
+ model, prompt, Collections.emptyList(), new OptionsBuilder().build()))
+ .thenReturn(new OllamaResult("", 0, 200));
+ ollamaAPI.generateWithImageURLs(
+ model, prompt, Collections.emptyList(), new OptionsBuilder().build());
+ verify(ollamaAPI, times(1))
+ .generateWithImageURLs(
+ model, prompt, Collections.emptyList(), new OptionsBuilder().build());
+ } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
+ throw new RuntimeException(e);
+ }
}
- }
- @Test
- void testAskAsync() {
- OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
- String model = OllamaModelType.LLAMA2;
- String prompt = "some prompt text";
- when(ollamaAPI.generateAsync(model, prompt))
- .thenReturn(new OllamaAsyncResultCallback(null, null, 3));
- ollamaAPI.generateAsync(model, prompt);
- verify(ollamaAPI, times(1)).generateAsync(model, prompt);
- }
+ @Test
+ void testAskAsync() {
+ OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
+ String model = OllamaModelType.LLAMA2;
+ String prompt = "some prompt text";
+ when(ollamaAPI.generateAsync(model, prompt, false))
+ .thenReturn(new OllamaAsyncResultCallback(null, null, 3));
+ ollamaAPI.generateAsync(model, prompt, false);
+ verify(ollamaAPI, times(1)).generateAsync(model, prompt, false);
+ }
}
diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/AbstractSerializationTest.java b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/AbstractSerializationTest.java
new file mode 100644
index 0000000..d0ffc2c
--- /dev/null
+++ b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/AbstractSerializationTest.java
@@ -0,0 +1,35 @@
+package io.github.amithkoujalgi.ollama4j.unittests.jackson;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.fail;
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
+
+public abstract class AbstractSerializationTest {
+
+ protected ObjectMapper mapper = Utils.getObjectMapper();
+
+ protected String serialize(T obj) {
+ try {
+ return mapper.writeValueAsString(obj);
+ } catch (JsonProcessingException e) {
+ fail("Could not serialize request!", e);
+ return null;
+ }
+ }
+
+ protected T deserialize(String jsonObject, Class deserializationClass) {
+ try {
+ return mapper.readValue(jsonObject, deserializationClass);
+ } catch (JsonProcessingException e) {
+ fail("Could not deserialize jsonObject!", e);
+ return null;
+ }
+ }
+
+ protected void assertEqualsAfterUnmarshalling(T unmarshalledObject,
+ T req) {
+ assertEquals(req, unmarshalledObject);
+ }
+}
diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestChatRequestSerialization.java b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestChatRequestSerialization.java
index f5fa5c9..3ad049c 100644
--- a/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestChatRequestSerialization.java
+++ b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestChatRequestSerialization.java
@@ -1,7 +1,6 @@
package io.github.amithkoujalgi.ollama4j.unittests.jackson;
import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.fail;
import java.io.File;
import java.util.List;
@@ -10,21 +9,15 @@ import org.json.JSONObject;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
-import com.fasterxml.jackson.core.JsonProcessingException;
-import com.fasterxml.jackson.databind.ObjectMapper;
-
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessageRole;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder;
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel;
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
-import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
-public class TestChatRequestSerialization {
+public class TestChatRequestSerialization extends AbstractSerializationTest {
private OllamaChatRequestBuilder builder;
- private ObjectMapper mapper = Utils.getObjectMapper();
-
@BeforeEach
public void init() {
builder = OllamaChatRequestBuilder.getInstance("DummyModel");
@@ -32,10 +25,9 @@ public class TestChatRequestSerialization {
@Test
public void testRequestOnlyMandatoryFields() {
- OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt",
- List.of(new File("src/test/resources/dog-on-a-boat.jpg"))).build();
- String jsonRequest = serializeRequest(req);
- assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req);
+ OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt").build();
+ String jsonRequest = serialize(req);
+ assertEqualsAfterUnmarshalling(deserialize(jsonRequest,OllamaChatRequestModel.class), req);
}
@Test
@@ -43,28 +35,43 @@ public class TestChatRequestSerialization {
OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.SYSTEM, "System prompt")
.withMessage(OllamaChatMessageRole.USER, "Some prompt")
.build();
- String jsonRequest = serializeRequest(req);
- assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req);
+ String jsonRequest = serialize(req);
+ assertEqualsAfterUnmarshalling(deserialize(jsonRequest,OllamaChatRequestModel.class), req);
}
@Test
public void testRequestWithMessageAndImage() {
OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt",
List.of(new File("src/test/resources/dog-on-a-boat.jpg"))).build();
- String jsonRequest = serializeRequest(req);
- assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req);
+ String jsonRequest = serialize(req);
+ assertEqualsAfterUnmarshalling(deserialize(jsonRequest,OllamaChatRequestModel.class), req);
}
@Test
public void testRequestWithOptions() {
OptionsBuilder b = new OptionsBuilder();
OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt")
- .withOptions(b.setMirostat(1).build()).build();
+ .withOptions(b.setMirostat(1).build())
+ .withOptions(b.setTemperature(1L).build())
+ .withOptions(b.setMirostatEta(1L).build())
+ .withOptions(b.setMirostatTau(1L).build())
+ .withOptions(b.setNumGpu(1).build())
+ .withOptions(b.setSeed(1).build())
+ .withOptions(b.setTopK(1).build())
+ .withOptions(b.setTopP(1).build())
+ .build();
- String jsonRequest = serializeRequest(req);
- OllamaChatRequestModel deserializeRequest = deserializeRequest(jsonRequest);
+ String jsonRequest = serialize(req);
+ OllamaChatRequestModel deserializeRequest = deserialize(jsonRequest, OllamaChatRequestModel.class);
assertEqualsAfterUnmarshalling(deserializeRequest, req);
assertEquals(1, deserializeRequest.getOptions().get("mirostat"));
+ assertEquals(1.0, deserializeRequest.getOptions().get("temperature"));
+ assertEquals(1.0, deserializeRequest.getOptions().get("mirostat_eta"));
+ assertEquals(1.0, deserializeRequest.getOptions().get("mirostat_tau"));
+ assertEquals(1, deserializeRequest.getOptions().get("num_gpu"));
+ assertEquals(1, deserializeRequest.getOptions().get("seed"));
+ assertEquals(1, deserializeRequest.getOptions().get("top_k"));
+ assertEquals(1.0, deserializeRequest.getOptions().get("top_p"));
}
@Test
@@ -72,7 +79,7 @@ public class TestChatRequestSerialization {
OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt")
.withGetJsonResponse().build();
- String jsonRequest = serializeRequest(req);
+ String jsonRequest = serialize(req);
// no jackson deserialization as format property is not boolean ==> omit as deserialization
// of request is never used in real code anyways
JSONObject jsonObject = new JSONObject(jsonRequest);
@@ -80,27 +87,27 @@ public class TestChatRequestSerialization {
assertEquals("json", requestFormatProperty);
}
- private String serializeRequest(OllamaChatRequestModel req) {
- try {
- return mapper.writeValueAsString(req);
- } catch (JsonProcessingException e) {
- fail("Could not serialize request!", e);
- return null;
- }
+ @Test
+ public void testWithTemplate() {
+ OllamaChatRequestModel req = builder.withTemplate("System Template")
+ .build();
+ String jsonRequest = serialize(req);
+ assertEqualsAfterUnmarshalling(deserialize(jsonRequest, OllamaChatRequestModel.class), req);
}
- private OllamaChatRequestModel deserializeRequest(String jsonRequest) {
- try {
- return mapper.readValue(jsonRequest, OllamaChatRequestModel.class);
- } catch (JsonProcessingException e) {
- fail("Could not deserialize jsonRequest!", e);
- return null;
- }
+ @Test
+ public void testWithStreaming() {
+ OllamaChatRequestModel req = builder.withStreaming().build();
+ String jsonRequest = serialize(req);
+ assertEquals(deserialize(jsonRequest, OllamaChatRequestModel.class).isStream(), true);
}
- private void assertEqualsAfterUnmarshalling(OllamaChatRequestModel unmarshalledRequest,
- OllamaChatRequestModel req) {
- assertEquals(req, unmarshalledRequest);
+ @Test
+ public void testWithKeepAlive() {
+ String expectedKeepAlive = "5m";
+ OllamaChatRequestModel req = builder.withKeepAlive(expectedKeepAlive)
+ .build();
+ String jsonRequest = serialize(req);
+ assertEquals(deserialize(jsonRequest, OllamaChatRequestModel.class).getKeepAlive(), expectedKeepAlive);
}
-
}
diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestEmbeddingsRequestSerialization.java b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestEmbeddingsRequestSerialization.java
new file mode 100644
index 0000000..a546d6d
--- /dev/null
+++ b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestEmbeddingsRequestSerialization.java
@@ -0,0 +1,37 @@
+package io.github.amithkoujalgi.ollama4j.unittests.jackson;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestModel;
+import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestBuilder;
+import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
+
+public class TestEmbeddingsRequestSerialization extends AbstractSerializationTest {
+
+ private OllamaEmbeddingsRequestBuilder builder;
+
+ @BeforeEach
+ public void init() {
+ builder = OllamaEmbeddingsRequestBuilder.getInstance("DummyModel","DummyPrompt");
+ }
+
+ @Test
+ public void testRequestOnlyMandatoryFields() {
+ OllamaEmbeddingsRequestModel req = builder.build();
+ String jsonRequest = serialize(req);
+ assertEqualsAfterUnmarshalling(deserialize(jsonRequest,OllamaEmbeddingsRequestModel.class), req);
+ }
+
+ @Test
+ public void testRequestWithOptions() {
+ OptionsBuilder b = new OptionsBuilder();
+ OllamaEmbeddingsRequestModel req = builder
+ .withOptions(b.setMirostat(1).build()).build();
+
+ String jsonRequest = serialize(req);
+ OllamaEmbeddingsRequestModel deserializeRequest = deserialize(jsonRequest,OllamaEmbeddingsRequestModel.class);
+ assertEqualsAfterUnmarshalling(deserializeRequest, req);
+ assertEquals(1, deserializeRequest.getOptions().get("mirostat"));
+ }
+}
diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestGenerateRequestSerialization.java b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestGenerateRequestSerialization.java
index 7cf0513..8e95288 100644
--- a/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestGenerateRequestSerialization.java
+++ b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestGenerateRequestSerialization.java
@@ -1,26 +1,20 @@
package io.github.amithkoujalgi.ollama4j.unittests.jackson;
import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.fail;
import org.json.JSONObject;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
-import com.fasterxml.jackson.core.JsonProcessingException;
-import com.fasterxml.jackson.databind.ObjectMapper;
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestBuilder;
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel;
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
-import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
-public class TestGenerateRequestSerialization {
+public class TestGenerateRequestSerialization extends AbstractSerializationTest {
private OllamaGenerateRequestBuilder builder;
- private ObjectMapper mapper = Utils.getObjectMapper();
-
@BeforeEach
public void init() {
builder = OllamaGenerateRequestBuilder.getInstance("DummyModel");
@@ -30,8 +24,8 @@ public class TestGenerateRequestSerialization {
public void testRequestOnlyMandatoryFields() {
OllamaGenerateRequestModel req = builder.withPrompt("Some prompt").build();
- String jsonRequest = serializeRequest(req);
- assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req);
+ String jsonRequest = serialize(req);
+ assertEqualsAfterUnmarshalling(deserialize(jsonRequest, OllamaGenerateRequestModel.class), req);
}
@Test
@@ -40,8 +34,8 @@ public class TestGenerateRequestSerialization {
OllamaGenerateRequestModel req =
builder.withPrompt("Some prompt").withOptions(b.setMirostat(1).build()).build();
- String jsonRequest = serializeRequest(req);
- OllamaGenerateRequestModel deserializeRequest = deserializeRequest(jsonRequest);
+ String jsonRequest = serialize(req);
+ OllamaGenerateRequestModel deserializeRequest = deserialize(jsonRequest, OllamaGenerateRequestModel.class);
assertEqualsAfterUnmarshalling(deserializeRequest, req);
assertEquals(1, deserializeRequest.getOptions().get("mirostat"));
}
@@ -51,7 +45,7 @@ public class TestGenerateRequestSerialization {
OllamaGenerateRequestModel req =
builder.withPrompt("Some prompt").withGetJsonResponse().build();
- String jsonRequest = serializeRequest(req);
+ String jsonRequest = serialize(req);
// no jackson deserialization as format property is not boolean ==> omit as deserialization
// of request is never used in real code anyways
JSONObject jsonObject = new JSONObject(jsonRequest);
@@ -59,27 +53,4 @@ public class TestGenerateRequestSerialization {
assertEquals("json", requestFormatProperty);
}
- private String serializeRequest(OllamaGenerateRequestModel req) {
- try {
- return mapper.writeValueAsString(req);
- } catch (JsonProcessingException e) {
- fail("Could not serialize request!", e);
- return null;
- }
- }
-
- private OllamaGenerateRequestModel deserializeRequest(String jsonRequest) {
- try {
- return mapper.readValue(jsonRequest, OllamaGenerateRequestModel.class);
- } catch (JsonProcessingException e) {
- fail("Could not deserialize jsonRequest!", e);
- return null;
- }
- }
-
- private void assertEqualsAfterUnmarshalling(OllamaGenerateRequestModel unmarshalledRequest,
- OllamaGenerateRequestModel req) {
- assertEquals(req, unmarshalledRequest);
- }
-
}
diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestModelRequestSerialization.java b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestModelRequestSerialization.java
new file mode 100644
index 0000000..712e507
--- /dev/null
+++ b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/jackson/TestModelRequestSerialization.java
@@ -0,0 +1,42 @@
+package io.github.amithkoujalgi.ollama4j.unittests.jackson;
+
+import io.github.amithkoujalgi.ollama4j.core.models.Model;
+import org.junit.jupiter.api.Test;
+
+public class TestModelRequestSerialization extends AbstractSerializationTest {
+
+ @Test
+ public void testDeserializationOfModelResponseWithOffsetTime(){
+ String serializedTestStringWithOffsetTime = "{\n"
+ + "\"name\": \"codellama:13b\",\n"
+ + "\"modified_at\": \"2023-11-04T14:56:49.277302595-07:00\",\n"
+ + "\"size\": 7365960935,\n"
+ + "\"digest\": \"9f438cb9cd581fc025612d27f7c1a6669ff83a8bb0ed86c94fcf4c5440555697\",\n"
+ + "\"details\": {\n"
+ + "\"format\": \"gguf\",\n"
+ + "\"family\": \"llama\",\n"
+ + "\"families\": null,\n"
+ + "\"parameter_size\": \"13B\",\n"
+ + "\"quantization_level\": \"Q4_0\"\n"
+ + "}}";
+ deserialize(serializedTestStringWithOffsetTime,Model.class);
+ }
+
+ @Test
+ public void testDeserializationOfModelResponseWithZuluTime(){
+ String serializedTestStringWithZuluTimezone = "{\n"
+ + "\"name\": \"codellama:13b\",\n"
+ + "\"modified_at\": \"2023-11-04T14:56:49.277302595Z\",\n"
+ + "\"size\": 7365960935,\n"
+ + "\"digest\": \"9f438cb9cd581fc025612d27f7c1a6669ff83a8bb0ed86c94fcf4c5440555697\",\n"
+ + "\"details\": {\n"
+ + "\"format\": \"gguf\",\n"
+ + "\"family\": \"llama\",\n"
+ + "\"families\": null,\n"
+ + "\"parameter_size\": \"13B\",\n"
+ + "\"quantization_level\": \"Q4_0\"\n"
+ + "}}";
+ deserialize(serializedTestStringWithZuluTimezone,Model.class);
+ }
+
+}