Compare commits

..

30 Commits

Author SHA1 Message Date
amithkoujalgi
bb0785140b [maven-release-plugin] prepare for next development iteration 2024-05-20 14:55:02 +00:00
amithkoujalgi
e33ad1a1e3 [maven-release-plugin] prepare release v1.0.72 2024-05-20 14:55:01 +00:00
Amith Koujalgi
cd60c506cb Merge pull request #47 from kelvinwatson/kelvinwatson/gradleDependencyInReadMe
Update README to include gradle project set up options
2024-05-20 20:24:03 +05:30
amithkoujalgi
b55925df28 [maven-release-plugin] prepare for next development iteration 2024-05-20 14:53:55 +00:00
amithkoujalgi
3a9b8c309d [maven-release-plugin] prepare release v1.0.71 2024-05-20 14:53:54 +00:00
Amith Koujalgi
bf07159522 Merge pull request #49 from AgentSchmecker/bugfix/48
Changes Datatype of ModelResponse fields to OffsetTime
2024-05-20 20:22:56 +05:30
AgentSchmecker
f8ca4d041d Changes DateTime types of Model.java to OffsetDatetime
Fixes #48
2024-05-20 11:10:03 +00:00
AgentSchmecker
9c6a55f7b0 Generalizes Abstract Serialization Test Class
Removes the "Request" naming context as this base class technically serves for general serialization purposes.
2024-05-20 11:08:49 +00:00
Kelvin Watson
2866d83a2f Update README.md 2024-05-19 11:58:08 -07:00
Kelvin Watson
45e5d07581 update README to include gradle options 2024-05-19 11:57:09 -07:00
amithkoujalgi
3a264cb6bb [maven-release-plugin] prepare for next development iteration 2024-05-19 13:57:34 +00:00
amithkoujalgi
e1b9d42771 [maven-release-plugin] prepare release v1.0.70 2024-05-19 13:57:32 +00:00
Amith Koujalgi
1a086c37c0 Merge pull request #46 from AgentSchmecker/model_update
Updates Model.java to be up to date with current OllamaAPI
2024-05-19 19:26:28 +05:30
Markus Klenke
54edba144c Merge pull request #2 from AgentSchmecker/model_update
Updates Model.java to be up to date with current OllamaAPI
2024-05-17 00:09:15 +02:00
AgentSchmecker
3ed3187ba9 Updates Model.java to be up to date with current OllamaAPI
Also adds Jackson-JSR310 for java.time JSON Mapping
2024-05-16 22:00:11 +00:00
amithkoujalgi
b7cd81a7f5 [maven-release-plugin] prepare for next development iteration 2024-05-14 05:42:56 +00:00
amithkoujalgi
e750c2d7f9 [maven-release-plugin] prepare release v1.0.69 2024-05-14 05:42:55 +00:00
Amith Koujalgi
62f16131f3 Merge remote-tracking branch 'origin/main' 2024-05-14 11:11:51 +05:30
Amith Koujalgi
2cbaf12d7c Updated library usages in README.md 2024-05-14 11:11:38 +05:30
amithkoujalgi
e2d555d404 [maven-release-plugin] prepare for next development iteration 2024-05-14 05:29:36 +00:00
amithkoujalgi
c296b34174 [maven-release-plugin] prepare release v1.0.68 2024-05-14 05:29:35 +00:00
Amith Koujalgi
e8f99f28ec Updated library usages in README.md 2024-05-14 10:58:29 +05:30
amithkoujalgi
250b1abc79 [maven-release-plugin] prepare for next development iteration 2024-05-14 05:07:20 +00:00
amithkoujalgi
42b15ad93f [maven-release-plugin] prepare release v1.0.67 2024-05-14 05:07:18 +00:00
Amith Koujalgi
6f7a714bae Merge remote-tracking branch 'origin/main' 2024-05-14 10:36:04 +05:30
Amith Koujalgi
92618e5084 Updated OllamaChatResponseModel to include done_reason field. Refer to the Ollama version: https://github.com/ollama/ollama/releases/tag/v0.1.37 2024-05-14 10:35:55 +05:30
amithkoujalgi
391a9242c3 [maven-release-plugin] prepare for next development iteration 2024-05-14 04:59:08 +00:00
amithkoujalgi
e1b6dc3b54 [maven-release-plugin] prepare release v1.0.66 2024-05-14 04:59:07 +00:00
Amith Koujalgi
04124cf978 Updated default request timeout to 10 seconds 2024-05-14 10:27:56 +05:30
amithkoujalgi
e4e717b747 [maven-release-plugin] prepare for next development iteration 2024-05-13 15:36:38 +00:00
14 changed files with 769 additions and 694 deletions

View File

@@ -67,10 +67,29 @@ In your Maven project, add this dependency:
<dependency> <dependency>
<groupId>io.github.amithkoujalgi</groupId> <groupId>io.github.amithkoujalgi</groupId>
<artifactId>ollama4j</artifactId> <artifactId>ollama4j</artifactId>
<version>1.0.57</version> <version>1.0.70</version>
</dependency> </dependency>
``` ```
or
In your Gradle project, add the dependency using the Kotlin DSL or the Groovy DSL:
```kotlin
dependencies {
val ollama4jVersion = "1.0.70"
implementation("io.github.amithkoujalgi:ollama4j:$ollama4jVersion")
}
```
```groovy
dependencies {
implementation("io.github.amithkoujalgi:ollama4j:1.0.70")
}
```
Latest release: Latest release:
![Maven Central](https://img.shields.io/maven-central/v/io.github.amithkoujalgi/ollama4j) ![Maven Central](https://img.shields.io/maven-central/v/io.github.amithkoujalgi/ollama4j)
@@ -110,6 +129,16 @@ make it
Releases (newer artifact versions) are done automatically on pushing the code to the `main` branch through GitHub Releases (newer artifact versions) are done automatically on pushing the code to the `main` branch through GitHub
Actions CI workflow. Actions CI workflow.
#### Who's using Ollama4j?
- `Datafaker`: a library to generate fake data
- https://github.com/datafaker-net/datafaker-experimental/tree/main/ollama-api
- `Vaadin Web UI`: UI-Tester for Interactions with Ollama via ollama4j
- https://github.com/TEAMPB/ollama4j-vaadin-ui
- `ollama-translator`: Minecraft 1.20.6 spigot plugin allows to easily break language barriers by using ollama on the
server to translate all messages into a specfic target language.
- https://github.com/liebki/ollama-translator
#### Traction #### Traction
[![Star History Chart](https://api.star-history.com/svg?repos=amithkoujalgi/ollama4j&type=Date)](https://star-history.com/#amithkoujalgi/ollama4j&Date) [![Star History Chart](https://api.star-history.com/svg?repos=amithkoujalgi/ollama4j&type=Date)](https://star-history.com/#amithkoujalgi/ollama4j&Date)
@@ -150,4 +179,4 @@ project.
### References ### References
- [Ollama REST APIs](https://github.com/jmorganca/ollama/blob/main/docs/api.md) - [Ollama REST APIs](https://github.com/jmorganca/ollama/blob/main/docs/api.md)

View File

@@ -112,7 +112,7 @@ You will get a response similar to:
## Use a simple Console Output Stream Handler ## Use a simple Console Output Stream Handler
``` ```java
import io.github.amithkoujalgi.ollama4j.core.impl.ConsoleOutputStreamHandler; import io.github.amithkoujalgi.ollama4j.core.impl.ConsoleOutputStreamHandler;
public class Main { public class Main {

11
pom.xml
View File

@@ -4,7 +4,7 @@
<groupId>io.github.amithkoujalgi</groupId> <groupId>io.github.amithkoujalgi</groupId>
<artifactId>ollama4j</artifactId> <artifactId>ollama4j</artifactId>
<version>1.0.65</version> <version>1.0.73-SNAPSHOT</version>
<name>Ollama4j</name> <name>Ollama4j</name>
<description>Java library for interacting with Ollama API.</description> <description>Java library for interacting with Ollama API.</description>
@@ -39,7 +39,7 @@
<connection>scm:git:git@github.com:amithkoujalgi/ollama4j.git</connection> <connection>scm:git:git@github.com:amithkoujalgi/ollama4j.git</connection>
<developerConnection>scm:git:https://github.com/amithkoujalgi/ollama4j.git</developerConnection> <developerConnection>scm:git:https://github.com/amithkoujalgi/ollama4j.git</developerConnection>
<url>https://github.com/amithkoujalgi/ollama4j</url> <url>https://github.com/amithkoujalgi/ollama4j</url>
<tag>v1.0.65</tag> <tag>v1.0.16</tag>
</scm> </scm>
<build> <build>
@@ -149,7 +149,12 @@
<dependency> <dependency>
<groupId>com.fasterxml.jackson.core</groupId> <groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId> <artifactId>jackson-databind</artifactId>
<version>2.15.3</version> <version>2.17.1</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.datatype</groupId>
<artifactId>jackson-datatype-jsr310</artifactId>
<version>2.17.1</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>ch.qos.logback</groupId> <groupId>ch.qos.logback</groupId>

View File

@@ -1,5 +1,8 @@
package io.github.amithkoujalgi.ollama4j.core.models; package io.github.amithkoujalgi.ollama4j.core.models;
import java.time.LocalDateTime;
import java.time.OffsetDateTime;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils; import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
@@ -11,7 +14,9 @@ public class Model {
private String name; private String name;
private String model; private String model;
@JsonProperty("modified_at") @JsonProperty("modified_at")
private String modifiedAt; private OffsetDateTime modifiedAt;
@JsonProperty("expires_at")
private OffsetDateTime expiresAt;
private String digest; private String digest;
private long size; private long size;
@JsonProperty("details") @JsonProperty("details")

View File

@@ -1,14 +1,15 @@
package io.github.amithkoujalgi.ollama4j.core.models.chat; package io.github.amithkoujalgi.ollama4j.core.models.chat;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Data;
import java.util.List; import java.util.List;
import lombok.Data;
@Data @Data
public class OllamaChatResponseModel { public class OllamaChatResponseModel {
private String model; private String model;
private @JsonProperty("created_at") String createdAt; private @JsonProperty("created_at") String createdAt;
private @JsonProperty("done_reason") String doneReason;
private OllamaChatMessage message; private OllamaChatMessage message;
private boolean done; private boolean done;
private String error; private String error;

View File

@@ -1,12 +1,6 @@
package io.github.amithkoujalgi.ollama4j.core.models.request; package io.github.amithkoujalgi.ollama4j.core.models.request;
import java.io.IOException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler; import io.github.amithkoujalgi.ollama4j.core.OllamaStreamHandler;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth; import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth;
@@ -15,11 +9,15 @@ import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResponseModel
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatStreamObserver; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatStreamObserver;
import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody; import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils; import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
/** /**
* Specialization class for requests * Specialization class for requests
*/ */
public class OllamaChatEndpointCaller extends OllamaEndpointCaller{ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
private static final Logger LOG = LoggerFactory.getLogger(OllamaChatEndpointCaller.class); private static final Logger LOG = LoggerFactory.getLogger(OllamaChatEndpointCaller.class);
@@ -39,14 +37,14 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller{
try { try {
OllamaChatResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class); OllamaChatResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class);
responseBuffer.append(ollamaResponseModel.getMessage().getContent()); responseBuffer.append(ollamaResponseModel.getMessage().getContent());
if(streamObserver != null) { if (streamObserver != null) {
streamObserver.notify(ollamaResponseModel); streamObserver.notify(ollamaResponseModel);
} }
return ollamaResponseModel.isDone(); return ollamaResponseModel.isDone();
} catch (JsonProcessingException e) { } catch (JsonProcessingException e) {
LOG.error("Error parsing the Ollama chat response!",e); LOG.error("Error parsing the Ollama chat response!", e);
return true; return true;
} }
} }
public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler) public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler)
@@ -54,7 +52,4 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller{
streamObserver = new OllamaChatStreamObserver(streamHandler); streamObserver = new OllamaChatStreamObserver(streamHandler);
return super.callSync(body); return super.callSync(body);
} }
} }

View File

@@ -1,5 +1,15 @@
package io.github.amithkoujalgi.ollama4j.core.models.request; package io.github.amithkoujalgi.ollama4j.core.models.request;
import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaErrorResponseModel;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.BufferedReader; import java.io.BufferedReader;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
@@ -12,22 +22,11 @@ import java.nio.charset.StandardCharsets;
import java.time.Duration; import java.time.Duration;
import java.util.Base64; import java.util.Base64;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaErrorResponseModel;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
/** /**
* Abstract helperclass to call the ollama api server. * Abstract helperclass to call the ollama api server.
*/ */
public abstract class OllamaEndpointCaller { public abstract class OllamaEndpointCaller {
private static final Logger LOG = LoggerFactory.getLogger(OllamaAPI.class); private static final Logger LOG = LoggerFactory.getLogger(OllamaAPI.class);
private String host; private String host;
@@ -49,107 +48,105 @@ public abstract class OllamaEndpointCaller {
/** /**
* Calls the api server on the given host and endpoint suffix asynchronously, aka waiting for the response. * Calls the api server on the given host and endpoint suffix asynchronously, aka waiting for the response.
* *
* @param body POST body payload * @param body POST body payload
* @return result answer given by the assistant * @return result answer given by the assistant
* @throws OllamaBaseException any response code than 200 has been returned * @throws OllamaBaseException any response code than 200 has been returned
* @throws IOException in case the responseStream can not be read * @throws IOException in case the responseStream can not be read
* @throws InterruptedException in case the server is not reachable or network issues happen * @throws InterruptedException in case the server is not reachable or network issues happen
*/ */
public OllamaResult callSync(OllamaRequestBody body) throws OllamaBaseException, IOException, InterruptedException{ public OllamaResult callSync(OllamaRequestBody body) throws OllamaBaseException, IOException, InterruptedException {
// Create Request // Create Request
long startTime = System.currentTimeMillis(); long startTime = System.currentTimeMillis();
HttpClient httpClient = HttpClient.newHttpClient(); HttpClient httpClient = HttpClient.newHttpClient();
URI uri = URI.create(this.host + getEndpointSuffix()); URI uri = URI.create(this.host + getEndpointSuffix());
HttpRequest.Builder requestBuilder = HttpRequest.Builder requestBuilder =
getRequestBuilderDefault(uri) getRequestBuilderDefault(uri)
.POST( .POST(
body.getBodyPublisher()); body.getBodyPublisher());
HttpRequest request = requestBuilder.build(); HttpRequest request = requestBuilder.build();
if (this.verbose) LOG.info("Asking model: " + body.toString()); if (this.verbose) LOG.info("Asking model: " + body.toString());
HttpResponse<InputStream> response = HttpResponse<InputStream> response =
httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream()); httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
int statusCode = response.statusCode(); int statusCode = response.statusCode();
InputStream responseBodyStream = response.body(); InputStream responseBodyStream = response.body();
StringBuilder responseBuffer = new StringBuilder(); StringBuilder responseBuffer = new StringBuilder();
try (BufferedReader reader = try (BufferedReader reader =
new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
String line; String line;
while ((line = reader.readLine()) != null) { while ((line = reader.readLine()) != null) {
if (statusCode == 404) { if (statusCode == 404) {
LOG.warn("Status code: 404 (Not Found)"); LOG.warn("Status code: 404 (Not Found)");
OllamaErrorResponseModel ollamaResponseModel = OllamaErrorResponseModel ollamaResponseModel =
Utils.getObjectMapper().readValue(line, OllamaErrorResponseModel.class); Utils.getObjectMapper().readValue(line, OllamaErrorResponseModel.class);
responseBuffer.append(ollamaResponseModel.getError()); responseBuffer.append(ollamaResponseModel.getError());
} else if (statusCode == 401) { } else if (statusCode == 401) {
LOG.warn("Status code: 401 (Unauthorized)"); LOG.warn("Status code: 401 (Unauthorized)");
OllamaErrorResponseModel ollamaResponseModel = OllamaErrorResponseModel ollamaResponseModel =
Utils.getObjectMapper() Utils.getObjectMapper()
.readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponseModel.class); .readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponseModel.class);
responseBuffer.append(ollamaResponseModel.getError()); responseBuffer.append(ollamaResponseModel.getError());
} else if (statusCode == 400) { } else if (statusCode == 400) {
LOG.warn("Status code: 400 (Bad Request)"); LOG.warn("Status code: 400 (Bad Request)");
OllamaErrorResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaErrorResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line,
OllamaErrorResponseModel.class); OllamaErrorResponseModel.class);
responseBuffer.append(ollamaResponseModel.getError()); responseBuffer.append(ollamaResponseModel.getError());
} else { } else {
boolean finished = parseResponseAndAddToBuffer(line,responseBuffer); boolean finished = parseResponseAndAddToBuffer(line, responseBuffer);
if (finished) { if (finished) {
break; break;
}
}
} }
} }
}
}
if (statusCode != 200) { if (statusCode != 200) {
LOG.error("Status code " + statusCode); LOG.error("Status code " + statusCode);
throw new OllamaBaseException(responseBuffer.toString()); throw new OllamaBaseException(responseBuffer.toString());
} else { } else {
long endTime = System.currentTimeMillis(); long endTime = System.currentTimeMillis();
OllamaResult ollamaResult = OllamaResult ollamaResult =
new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode); new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode);
if (verbose) LOG.info("Model response: " + ollamaResult); if (verbose) LOG.info("Model response: " + ollamaResult);
return ollamaResult; return ollamaResult;
}
} }
}
/** /**
* Get default request builder. * Get default request builder.
* *
* @param uri URI to get a HttpRequest.Builder * @param uri URI to get a HttpRequest.Builder
* @return HttpRequest.Builder * @return HttpRequest.Builder
*/ */
private HttpRequest.Builder getRequestBuilderDefault(URI uri) { private HttpRequest.Builder getRequestBuilderDefault(URI uri) {
HttpRequest.Builder requestBuilder = HttpRequest.Builder requestBuilder =
HttpRequest.newBuilder(uri) HttpRequest.newBuilder(uri)
.header("Content-Type", "application/json") .header("Content-Type", "application/json")
.timeout(Duration.ofSeconds(this.requestTimeoutSeconds)); .timeout(Duration.ofSeconds(this.requestTimeoutSeconds));
if (isBasicAuthCredentialsSet()) { if (isBasicAuthCredentialsSet()) {
requestBuilder.header("Authorization", getBasicAuthHeaderValue()); requestBuilder.header("Authorization", getBasicAuthHeaderValue());
}
return requestBuilder;
} }
return requestBuilder;
}
/** /**
* Get basic authentication header value. * Get basic authentication header value.
* *
* @return basic authentication header value (encoded credentials) * @return basic authentication header value (encoded credentials)
*/ */
private String getBasicAuthHeaderValue() { private String getBasicAuthHeaderValue() {
String credentialsToEncode = this.basicAuth.getUsername() + ":" + this.basicAuth.getPassword(); String credentialsToEncode = this.basicAuth.getUsername() + ":" + this.basicAuth.getPassword();
return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes()); return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes());
} }
/**
* Check if Basic Auth credentials set.
*
* @return true when Basic Auth credentials set
*/
private boolean isBasicAuthCredentialsSet() {
return this.basicAuth != null;
}
/**
* Check if Basic Auth credentials set.
*
* @return true when Basic Auth credentials set
*/
private boolean isBasicAuthCredentialsSet() {
return this.basicAuth != null;
}
} }

View File

@@ -8,10 +8,18 @@ import java.net.URISyntaxException;
import java.net.URL; import java.net.URL;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
public class Utils { public class Utils {
private static ObjectMapper objectMapper;
public static ObjectMapper getObjectMapper() { public static ObjectMapper getObjectMapper() {
return new ObjectMapper(); if(objectMapper == null) {
objectMapper = new ObjectMapper();
objectMapper.registerModule(new JavaTimeModule());
}
return objectMapper;
} }
public static byte[] loadImageBytesFromUrl(String imageUrl) public static byte[] loadImageBytesFromUrl(String imageUrl)

View File

@@ -6,30 +6,30 @@ import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import io.github.amithkoujalgi.ollama4j.core.utils.Utils; import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
public abstract class AbstractRequestSerializationTest<T> { public abstract class AbstractSerializationTest<T> {
protected ObjectMapper mapper = Utils.getObjectMapper(); protected ObjectMapper mapper = Utils.getObjectMapper();
protected String serializeRequest(T req) { protected String serialize(T obj) {
try { try {
return mapper.writeValueAsString(req); return mapper.writeValueAsString(obj);
} catch (JsonProcessingException e) { } catch (JsonProcessingException e) {
fail("Could not serialize request!", e); fail("Could not serialize request!", e);
return null; return null;
} }
} }
protected T deserializeRequest(String jsonRequest, Class<T> requestClass) { protected T deserialize(String jsonObject, Class<T> deserializationClass) {
try { try {
return mapper.readValue(jsonRequest, requestClass); return mapper.readValue(jsonObject, deserializationClass);
} catch (JsonProcessingException e) { } catch (JsonProcessingException e) {
fail("Could not deserialize jsonRequest!", e); fail("Could not deserialize jsonObject!", e);
return null; return null;
} }
} }
protected void assertEqualsAfterUnmarshalling(T unmarshalledRequest, protected void assertEqualsAfterUnmarshalling(T unmarshalledObject,
T req) { T req) {
assertEquals(req, unmarshalledRequest); assertEquals(req, unmarshalledObject);
} }
} }

View File

@@ -14,7 +14,7 @@ import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilde
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel;
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder; import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
public class TestChatRequestSerialization extends AbstractRequestSerializationTest<OllamaChatRequestModel>{ public class TestChatRequestSerialization extends AbstractSerializationTest<OllamaChatRequestModel> {
private OllamaChatRequestBuilder builder; private OllamaChatRequestBuilder builder;
@@ -26,8 +26,8 @@ public class TestChatRequestSerialization extends AbstractRequestSerializationTe
@Test @Test
public void testRequestOnlyMandatoryFields() { public void testRequestOnlyMandatoryFields() {
OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt").build(); OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt").build();
String jsonRequest = serializeRequest(req); String jsonRequest = serialize(req);
assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest,OllamaChatRequestModel.class), req); assertEqualsAfterUnmarshalling(deserialize(jsonRequest,OllamaChatRequestModel.class), req);
} }
@Test @Test
@@ -35,16 +35,16 @@ public class TestChatRequestSerialization extends AbstractRequestSerializationTe
OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.SYSTEM, "System prompt") OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.SYSTEM, "System prompt")
.withMessage(OllamaChatMessageRole.USER, "Some prompt") .withMessage(OllamaChatMessageRole.USER, "Some prompt")
.build(); .build();
String jsonRequest = serializeRequest(req); String jsonRequest = serialize(req);
assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest,OllamaChatRequestModel.class), req); assertEqualsAfterUnmarshalling(deserialize(jsonRequest,OllamaChatRequestModel.class), req);
} }
@Test @Test
public void testRequestWithMessageAndImage() { public void testRequestWithMessageAndImage() {
OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt", OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt",
List.of(new File("src/test/resources/dog-on-a-boat.jpg"))).build(); List.of(new File("src/test/resources/dog-on-a-boat.jpg"))).build();
String jsonRequest = serializeRequest(req); String jsonRequest = serialize(req);
assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest,OllamaChatRequestModel.class), req); assertEqualsAfterUnmarshalling(deserialize(jsonRequest,OllamaChatRequestModel.class), req);
} }
@Test @Test
@@ -61,8 +61,8 @@ public class TestChatRequestSerialization extends AbstractRequestSerializationTe
.withOptions(b.setTopP(1).build()) .withOptions(b.setTopP(1).build())
.build(); .build();
String jsonRequest = serializeRequest(req); String jsonRequest = serialize(req);
OllamaChatRequestModel deserializeRequest = deserializeRequest(jsonRequest, OllamaChatRequestModel.class); OllamaChatRequestModel deserializeRequest = deserialize(jsonRequest, OllamaChatRequestModel.class);
assertEqualsAfterUnmarshalling(deserializeRequest, req); assertEqualsAfterUnmarshalling(deserializeRequest, req);
assertEquals(1, deserializeRequest.getOptions().get("mirostat")); assertEquals(1, deserializeRequest.getOptions().get("mirostat"));
assertEquals(1.0, deserializeRequest.getOptions().get("temperature")); assertEquals(1.0, deserializeRequest.getOptions().get("temperature"));
@@ -79,7 +79,7 @@ public class TestChatRequestSerialization extends AbstractRequestSerializationTe
OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt") OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt")
.withGetJsonResponse().build(); .withGetJsonResponse().build();
String jsonRequest = serializeRequest(req); String jsonRequest = serialize(req);
// no jackson deserialization as format property is not boolean ==> omit as deserialization // no jackson deserialization as format property is not boolean ==> omit as deserialization
// of request is never used in real code anyways // of request is never used in real code anyways
JSONObject jsonObject = new JSONObject(jsonRequest); JSONObject jsonObject = new JSONObject(jsonRequest);
@@ -91,15 +91,15 @@ public class TestChatRequestSerialization extends AbstractRequestSerializationTe
public void testWithTemplate() { public void testWithTemplate() {
OllamaChatRequestModel req = builder.withTemplate("System Template") OllamaChatRequestModel req = builder.withTemplate("System Template")
.build(); .build();
String jsonRequest = serializeRequest(req); String jsonRequest = serialize(req);
assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest, OllamaChatRequestModel.class), req); assertEqualsAfterUnmarshalling(deserialize(jsonRequest, OllamaChatRequestModel.class), req);
} }
@Test @Test
public void testWithStreaming() { public void testWithStreaming() {
OllamaChatRequestModel req = builder.withStreaming().build(); OllamaChatRequestModel req = builder.withStreaming().build();
String jsonRequest = serializeRequest(req); String jsonRequest = serialize(req);
assertEquals(deserializeRequest(jsonRequest, OllamaChatRequestModel.class).isStream(), true); assertEquals(deserialize(jsonRequest, OllamaChatRequestModel.class).isStream(), true);
} }
@Test @Test
@@ -107,7 +107,7 @@ public class TestChatRequestSerialization extends AbstractRequestSerializationTe
String expectedKeepAlive = "5m"; String expectedKeepAlive = "5m";
OllamaChatRequestModel req = builder.withKeepAlive(expectedKeepAlive) OllamaChatRequestModel req = builder.withKeepAlive(expectedKeepAlive)
.build(); .build();
String jsonRequest = serializeRequest(req); String jsonRequest = serialize(req);
assertEquals(deserializeRequest(jsonRequest, OllamaChatRequestModel.class).getKeepAlive(), expectedKeepAlive); assertEquals(deserialize(jsonRequest, OllamaChatRequestModel.class).getKeepAlive(), expectedKeepAlive);
} }
} }

View File

@@ -7,7 +7,7 @@ import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsR
import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestBuilder; import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestBuilder;
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder; import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
public class TestEmbeddingsRequestSerialization extends AbstractRequestSerializationTest<OllamaEmbeddingsRequestModel>{ public class TestEmbeddingsRequestSerialization extends AbstractSerializationTest<OllamaEmbeddingsRequestModel> {
private OllamaEmbeddingsRequestBuilder builder; private OllamaEmbeddingsRequestBuilder builder;
@@ -19,8 +19,8 @@ public class TestEmbeddingsRequestSerialization extends AbstractRequestSerializa
@Test @Test
public void testRequestOnlyMandatoryFields() { public void testRequestOnlyMandatoryFields() {
OllamaEmbeddingsRequestModel req = builder.build(); OllamaEmbeddingsRequestModel req = builder.build();
String jsonRequest = serializeRequest(req); String jsonRequest = serialize(req);
assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest,OllamaEmbeddingsRequestModel.class), req); assertEqualsAfterUnmarshalling(deserialize(jsonRequest,OllamaEmbeddingsRequestModel.class), req);
} }
@Test @Test
@@ -29,8 +29,8 @@ public class TestEmbeddingsRequestSerialization extends AbstractRequestSerializa
OllamaEmbeddingsRequestModel req = builder OllamaEmbeddingsRequestModel req = builder
.withOptions(b.setMirostat(1).build()).build(); .withOptions(b.setMirostat(1).build()).build();
String jsonRequest = serializeRequest(req); String jsonRequest = serialize(req);
OllamaEmbeddingsRequestModel deserializeRequest = deserializeRequest(jsonRequest,OllamaEmbeddingsRequestModel.class); OllamaEmbeddingsRequestModel deserializeRequest = deserialize(jsonRequest,OllamaEmbeddingsRequestModel.class);
assertEqualsAfterUnmarshalling(deserializeRequest, req); assertEqualsAfterUnmarshalling(deserializeRequest, req);
assertEquals(1, deserializeRequest.getOptions().get("mirostat")); assertEquals(1, deserializeRequest.getOptions().get("mirostat"));
} }

View File

@@ -11,7 +11,7 @@ import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateReque
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel; import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel;
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder; import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
public class TestGenerateRequestSerialization extends AbstractRequestSerializationTest<OllamaGenerateRequestModel>{ public class TestGenerateRequestSerialization extends AbstractSerializationTest<OllamaGenerateRequestModel> {
private OllamaGenerateRequestBuilder builder; private OllamaGenerateRequestBuilder builder;
@@ -24,8 +24,8 @@ public class TestGenerateRequestSerialization extends AbstractRequestSerializati
public void testRequestOnlyMandatoryFields() { public void testRequestOnlyMandatoryFields() {
OllamaGenerateRequestModel req = builder.withPrompt("Some prompt").build(); OllamaGenerateRequestModel req = builder.withPrompt("Some prompt").build();
String jsonRequest = serializeRequest(req); String jsonRequest = serialize(req);
assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest, OllamaGenerateRequestModel.class), req); assertEqualsAfterUnmarshalling(deserialize(jsonRequest, OllamaGenerateRequestModel.class), req);
} }
@Test @Test
@@ -34,8 +34,8 @@ public class TestGenerateRequestSerialization extends AbstractRequestSerializati
OllamaGenerateRequestModel req = OllamaGenerateRequestModel req =
builder.withPrompt("Some prompt").withOptions(b.setMirostat(1).build()).build(); builder.withPrompt("Some prompt").withOptions(b.setMirostat(1).build()).build();
String jsonRequest = serializeRequest(req); String jsonRequest = serialize(req);
OllamaGenerateRequestModel deserializeRequest = deserializeRequest(jsonRequest, OllamaGenerateRequestModel.class); OllamaGenerateRequestModel deserializeRequest = deserialize(jsonRequest, OllamaGenerateRequestModel.class);
assertEqualsAfterUnmarshalling(deserializeRequest, req); assertEqualsAfterUnmarshalling(deserializeRequest, req);
assertEquals(1, deserializeRequest.getOptions().get("mirostat")); assertEquals(1, deserializeRequest.getOptions().get("mirostat"));
} }
@@ -45,7 +45,7 @@ public class TestGenerateRequestSerialization extends AbstractRequestSerializati
OllamaGenerateRequestModel req = OllamaGenerateRequestModel req =
builder.withPrompt("Some prompt").withGetJsonResponse().build(); builder.withPrompt("Some prompt").withGetJsonResponse().build();
String jsonRequest = serializeRequest(req); String jsonRequest = serialize(req);
// no jackson deserialization as format property is not boolean ==> omit as deserialization // no jackson deserialization as format property is not boolean ==> omit as deserialization
// of request is never used in real code anyways // of request is never used in real code anyways
JSONObject jsonObject = new JSONObject(jsonRequest); JSONObject jsonObject = new JSONObject(jsonRequest);

View File

@@ -0,0 +1,42 @@
package io.github.amithkoujalgi.ollama4j.unittests.jackson;
import io.github.amithkoujalgi.ollama4j.core.models.Model;
import org.junit.jupiter.api.Test;
public class TestModelRequestSerialization extends AbstractSerializationTest<Model> {
@Test
public void testDeserializationOfModelResponseWithOffsetTime(){
String serializedTestStringWithOffsetTime = "{\n"
+ "\"name\": \"codellama:13b\",\n"
+ "\"modified_at\": \"2023-11-04T14:56:49.277302595-07:00\",\n"
+ "\"size\": 7365960935,\n"
+ "\"digest\": \"9f438cb9cd581fc025612d27f7c1a6669ff83a8bb0ed86c94fcf4c5440555697\",\n"
+ "\"details\": {\n"
+ "\"format\": \"gguf\",\n"
+ "\"family\": \"llama\",\n"
+ "\"families\": null,\n"
+ "\"parameter_size\": \"13B\",\n"
+ "\"quantization_level\": \"Q4_0\"\n"
+ "}}";
deserialize(serializedTestStringWithOffsetTime,Model.class);
}
@Test
public void testDeserializationOfModelResponseWithZuluTime(){
String serializedTestStringWithZuluTimezone = "{\n"
+ "\"name\": \"codellama:13b\",\n"
+ "\"modified_at\": \"2023-11-04T14:56:49.277302595Z\",\n"
+ "\"size\": 7365960935,\n"
+ "\"digest\": \"9f438cb9cd581fc025612d27f7c1a6669ff83a8bb0ed86c94fcf4c5440555697\",\n"
+ "\"details\": {\n"
+ "\"format\": \"gguf\",\n"
+ "\"family\": \"llama\",\n"
+ "\"families\": null,\n"
+ "\"parameter_size\": \"13B\",\n"
+ "\"quantization_level\": \"Q4_0\"\n"
+ "}}";
deserialize(serializedTestStringWithZuluTimezone,Model.class);
}
}