mirror of
https://github.com/amithkoujalgi/ollama4j.git
synced 2025-05-15 11:57:12 +02:00
Merge pull request #28 from AgentSchmecker/feature/advanced_params_for_generate
Adds advanced request parameters to chat/generate requests
This commit is contained in:
commit
6487756764
6
pom.xml
6
pom.xml
@ -174,6 +174,12 @@
|
|||||||
<version>4.1.0</version>
|
<version>4.1.0</version>
|
||||||
<scope>test</scope>
|
<scope>test</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.json</groupId>
|
||||||
|
<artifactId>json</artifactId>
|
||||||
|
<version>20240205</version>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
|
|
||||||
<distributionManagement>
|
<distributionManagement>
|
||||||
|
@ -6,6 +6,7 @@ import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessage;
|
|||||||
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder;
|
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder;
|
||||||
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel;
|
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel;
|
||||||
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResult;
|
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResult;
|
||||||
|
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel;
|
||||||
import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFileContentsRequest;
|
import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFileContentsRequest;
|
||||||
import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFilePathRequest;
|
import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFilePathRequest;
|
||||||
import io.github.amithkoujalgi.ollama4j.core.models.request.ModelEmbeddingsRequest;
|
import io.github.amithkoujalgi.ollama4j.core.models.request.ModelEmbeddingsRequest;
|
||||||
@ -345,7 +346,7 @@ public class OllamaAPI {
|
|||||||
*/
|
*/
|
||||||
public OllamaResult generate(String model, String prompt, Options options)
|
public OllamaResult generate(String model, String prompt, Options options)
|
||||||
throws OllamaBaseException, IOException, InterruptedException {
|
throws OllamaBaseException, IOException, InterruptedException {
|
||||||
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt);
|
OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt);
|
||||||
ollamaRequestModel.setOptions(options.getOptionsMap());
|
ollamaRequestModel.setOptions(options.getOptionsMap());
|
||||||
return generateSyncForOllamaRequestModel(ollamaRequestModel);
|
return generateSyncForOllamaRequestModel(ollamaRequestModel);
|
||||||
}
|
}
|
||||||
@ -360,7 +361,7 @@ public class OllamaAPI {
|
|||||||
* @return the ollama async result callback handle
|
* @return the ollama async result callback handle
|
||||||
*/
|
*/
|
||||||
public OllamaAsyncResultCallback generateAsync(String model, String prompt) {
|
public OllamaAsyncResultCallback generateAsync(String model, String prompt) {
|
||||||
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt);
|
OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt);
|
||||||
|
|
||||||
URI uri = URI.create(this.host + "/api/generate");
|
URI uri = URI.create(this.host + "/api/generate");
|
||||||
OllamaAsyncResultCallback ollamaAsyncResultCallback =
|
OllamaAsyncResultCallback ollamaAsyncResultCallback =
|
||||||
@ -389,7 +390,7 @@ public class OllamaAPI {
|
|||||||
for (File imageFile : imageFiles) {
|
for (File imageFile : imageFiles) {
|
||||||
images.add(encodeFileToBase64(imageFile));
|
images.add(encodeFileToBase64(imageFile));
|
||||||
}
|
}
|
||||||
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt, images);
|
OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt, images);
|
||||||
ollamaRequestModel.setOptions(options.getOptionsMap());
|
ollamaRequestModel.setOptions(options.getOptionsMap());
|
||||||
return generateSyncForOllamaRequestModel(ollamaRequestModel);
|
return generateSyncForOllamaRequestModel(ollamaRequestModel);
|
||||||
}
|
}
|
||||||
@ -413,7 +414,7 @@ public class OllamaAPI {
|
|||||||
for (String imageURL : imageURLs) {
|
for (String imageURL : imageURLs) {
|
||||||
images.add(encodeByteArrayToBase64(Utils.loadImageBytesFromUrl(imageURL)));
|
images.add(encodeByteArrayToBase64(Utils.loadImageBytesFromUrl(imageURL)));
|
||||||
}
|
}
|
||||||
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt, images);
|
OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt, images);
|
||||||
ollamaRequestModel.setOptions(options.getOptionsMap());
|
ollamaRequestModel.setOptions(options.getOptionsMap());
|
||||||
return generateSyncForOllamaRequestModel(ollamaRequestModel);
|
return generateSyncForOllamaRequestModel(ollamaRequestModel);
|
||||||
}
|
}
|
||||||
@ -448,7 +449,7 @@ public class OllamaAPI {
|
|||||||
* @throws InterruptedException in case the server is not reachable or network issues happen
|
* @throws InterruptedException in case the server is not reachable or network issues happen
|
||||||
*/
|
*/
|
||||||
public OllamaChatResult chat(OllamaChatRequestModel request) throws OllamaBaseException, IOException, InterruptedException{
|
public OllamaChatResult chat(OllamaChatRequestModel request) throws OllamaBaseException, IOException, InterruptedException{
|
||||||
return chat(request);
|
return chat(request,null);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -486,7 +487,7 @@ public class OllamaAPI {
|
|||||||
return Base64.getEncoder().encodeToString(bytes);
|
return Base64.getEncoder().encodeToString(bytes);
|
||||||
}
|
}
|
||||||
|
|
||||||
private OllamaResult generateSyncForOllamaRequestModel(OllamaRequestModel ollamaRequestModel)
|
private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequestModel ollamaRequestModel)
|
||||||
throws OllamaBaseException, IOException, InterruptedException {
|
throws OllamaBaseException, IOException, InterruptedException {
|
||||||
OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
|
OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
|
||||||
return requestCaller.callSync(ollamaRequestModel);
|
return requestCaller.callSync(ollamaRequestModel);
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
package io.github.amithkoujalgi.ollama4j.core.models;
|
package io.github.amithkoujalgi.ollama4j.core.models;
|
||||||
|
|
||||||
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
|
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
|
||||||
|
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel;
|
||||||
|
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateResponseModel;
|
||||||
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
|
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
|
||||||
import java.io.BufferedReader;
|
import java.io.BufferedReader;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
@ -22,7 +24,7 @@ import lombok.Getter;
|
|||||||
@SuppressWarnings("unused")
|
@SuppressWarnings("unused")
|
||||||
public class OllamaAsyncResultCallback extends Thread {
|
public class OllamaAsyncResultCallback extends Thread {
|
||||||
private final HttpRequest.Builder requestBuilder;
|
private final HttpRequest.Builder requestBuilder;
|
||||||
private final OllamaRequestModel ollamaRequestModel;
|
private final OllamaGenerateRequestModel ollamaRequestModel;
|
||||||
private final Queue<String> queue = new LinkedList<>();
|
private final Queue<String> queue = new LinkedList<>();
|
||||||
private String result;
|
private String result;
|
||||||
private boolean isDone;
|
private boolean isDone;
|
||||||
@ -47,7 +49,7 @@ public class OllamaAsyncResultCallback extends Thread {
|
|||||||
|
|
||||||
public OllamaAsyncResultCallback(
|
public OllamaAsyncResultCallback(
|
||||||
HttpRequest.Builder requestBuilder,
|
HttpRequest.Builder requestBuilder,
|
||||||
OllamaRequestModel ollamaRequestModel,
|
OllamaGenerateRequestModel ollamaRequestModel,
|
||||||
long requestTimeoutSeconds) {
|
long requestTimeoutSeconds) {
|
||||||
this.requestBuilder = requestBuilder;
|
this.requestBuilder = requestBuilder;
|
||||||
this.ollamaRequestModel = ollamaRequestModel;
|
this.ollamaRequestModel = ollamaRequestModel;
|
||||||
@ -87,8 +89,8 @@ public class OllamaAsyncResultCallback extends Thread {
|
|||||||
queue.add(ollamaResponseModel.getError());
|
queue.add(ollamaResponseModel.getError());
|
||||||
responseBuffer.append(ollamaResponseModel.getError());
|
responseBuffer.append(ollamaResponseModel.getError());
|
||||||
} else {
|
} else {
|
||||||
OllamaResponseModel ollamaResponseModel =
|
OllamaGenerateResponseModel ollamaResponseModel =
|
||||||
Utils.getObjectMapper().readValue(line, OllamaResponseModel.class);
|
Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class);
|
||||||
queue.add(ollamaResponseModel.getResponse());
|
queue.add(ollamaResponseModel.getResponse());
|
||||||
if (!ollamaResponseModel.isDone()) {
|
if (!ollamaResponseModel.isDone()) {
|
||||||
responseBuffer.append(ollamaResponseModel.getResponse());
|
responseBuffer.append(ollamaResponseModel.getResponse());
|
||||||
|
@ -0,0 +1,35 @@
|
|||||||
|
package io.github.amithkoujalgi.ollama4j.core.models;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
|
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
|
||||||
|
|
||||||
|
import io.github.amithkoujalgi.ollama4j.core.utils.BooleanToJsonFormatFlagSerializer;
|
||||||
|
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||||
|
public abstract class OllamaCommonRequestModel {
|
||||||
|
|
||||||
|
protected String model;
|
||||||
|
@JsonSerialize(using = BooleanToJsonFormatFlagSerializer.class)
|
||||||
|
@JsonProperty(value = "format")
|
||||||
|
protected Boolean returnFormatJson;
|
||||||
|
protected Map<String, Object> options;
|
||||||
|
protected String template;
|
||||||
|
protected boolean stream;
|
||||||
|
@JsonProperty(value = "keep_alive")
|
||||||
|
protected String keepAlive;
|
||||||
|
|
||||||
|
|
||||||
|
public String toString() {
|
||||||
|
try {
|
||||||
|
return Utils.getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
|
||||||
|
} catch (JsonProcessingException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -1,39 +0,0 @@
|
|||||||
package io.github.amithkoujalgi.ollama4j.core.models;
|
|
||||||
|
|
||||||
import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper;
|
|
||||||
|
|
||||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
|
||||||
|
|
||||||
import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
@Data
|
|
||||||
public class OllamaRequestModel implements OllamaRequestBody{
|
|
||||||
|
|
||||||
private String model;
|
|
||||||
private String prompt;
|
|
||||||
private List<String> images;
|
|
||||||
private Map<String, Object> options;
|
|
||||||
|
|
||||||
public OllamaRequestModel(String model, String prompt) {
|
|
||||||
this.model = model;
|
|
||||||
this.prompt = prompt;
|
|
||||||
}
|
|
||||||
|
|
||||||
public OllamaRequestModel(String model, String prompt, List<String> images) {
|
|
||||||
this.model = model;
|
|
||||||
this.prompt = prompt;
|
|
||||||
this.images = images;
|
|
||||||
}
|
|
||||||
|
|
||||||
public String toString() {
|
|
||||||
try {
|
|
||||||
return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
|
|
||||||
} catch (JsonProcessingException e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
@ -83,12 +83,12 @@ public class OllamaChatRequestBuilder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public OllamaChatRequestBuilder withOptions(Options options){
|
public OllamaChatRequestBuilder withOptions(Options options){
|
||||||
this.request.setOptions(options);
|
this.request.setOptions(options.getOptionsMap());
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
public OllamaChatRequestBuilder withFormat(String format){
|
public OllamaChatRequestBuilder withGetJsonResponse(){
|
||||||
this.request.setFormat(format);
|
this.request.setReturnFormatJson(true);
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,47 +1,39 @@
|
|||||||
package io.github.amithkoujalgi.ollama4j.core.models.chat;
|
package io.github.amithkoujalgi.ollama4j.core.models.chat;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import io.github.amithkoujalgi.ollama4j.core.models.OllamaCommonRequestModel;
|
||||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
|
||||||
|
|
||||||
import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
|
import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
|
||||||
import io.github.amithkoujalgi.ollama4j.core.utils.Options;
|
|
||||||
|
|
||||||
import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper;
|
import lombok.Getter;
|
||||||
|
import lombok.Setter;
|
||||||
import lombok.AllArgsConstructor;
|
|
||||||
import lombok.Data;
|
|
||||||
import lombok.NonNull;
|
|
||||||
import lombok.RequiredArgsConstructor;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Defines a Request to use against the ollama /api/chat endpoint.
|
* Defines a Request to use against the ollama /api/chat endpoint.
|
||||||
*
|
*
|
||||||
* @see <a
|
* @see <a href=
|
||||||
* href="https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion">Generate
|
* "https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion">Generate
|
||||||
* Chat Completion</a>
|
* Chat Completion</a>
|
||||||
*/
|
*/
|
||||||
@Data
|
@Getter
|
||||||
@AllArgsConstructor
|
@Setter
|
||||||
@RequiredArgsConstructor
|
public class OllamaChatRequestModel extends OllamaCommonRequestModel implements OllamaRequestBody {
|
||||||
public class OllamaChatRequestModel implements OllamaRequestBody {
|
|
||||||
|
|
||||||
@NonNull private String model;
|
private List<OllamaChatMessage> messages;
|
||||||
|
|
||||||
@NonNull private List<OllamaChatMessage> messages;
|
public OllamaChatRequestModel() {}
|
||||||
|
|
||||||
private String format;
|
public OllamaChatRequestModel(String model, List<OllamaChatMessage> messages) {
|
||||||
private Options options;
|
this.model = model;
|
||||||
private String template;
|
this.messages = messages;
|
||||||
private boolean stream;
|
}
|
||||||
private String keepAlive;
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String toString() {
|
public boolean equals(Object o) {
|
||||||
try {
|
if (!(o instanceof OllamaChatRequestModel)) {
|
||||||
return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
|
return false;
|
||||||
} catch (JsonProcessingException e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return this.toString().equals(o.toString());
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,55 @@
|
|||||||
|
package io.github.amithkoujalgi.ollama4j.core.models.generate;
|
||||||
|
|
||||||
|
import io.github.amithkoujalgi.ollama4j.core.utils.Options;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Helper class for creating {@link io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel}
|
||||||
|
* objects using the builder-pattern.
|
||||||
|
*/
|
||||||
|
public class OllamaGenerateRequestBuilder {
|
||||||
|
|
||||||
|
private OllamaGenerateRequestBuilder(String model, String prompt){
|
||||||
|
request = new OllamaGenerateRequestModel(model, prompt);
|
||||||
|
}
|
||||||
|
|
||||||
|
private OllamaGenerateRequestModel request;
|
||||||
|
|
||||||
|
public static OllamaGenerateRequestBuilder getInstance(String model){
|
||||||
|
return new OllamaGenerateRequestBuilder(model,"");
|
||||||
|
}
|
||||||
|
|
||||||
|
public OllamaGenerateRequestModel build(){
|
||||||
|
return request;
|
||||||
|
}
|
||||||
|
|
||||||
|
public OllamaGenerateRequestBuilder withPrompt(String prompt){
|
||||||
|
request.setPrompt(prompt);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public OllamaGenerateRequestBuilder withGetJsonResponse(){
|
||||||
|
this.request.setReturnFormatJson(true);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public OllamaGenerateRequestBuilder withOptions(Options options){
|
||||||
|
this.request.setOptions(options.getOptionsMap());
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public OllamaGenerateRequestBuilder withTemplate(String template){
|
||||||
|
this.request.setTemplate(template);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public OllamaGenerateRequestBuilder withStreaming(){
|
||||||
|
this.request.setStream(true);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public OllamaGenerateRequestBuilder withKeepAlive(String keepAlive){
|
||||||
|
this.request.setKeepAlive(keepAlive);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,46 @@
|
|||||||
|
package io.github.amithkoujalgi.ollama4j.core.models.generate;
|
||||||
|
|
||||||
|
|
||||||
|
import io.github.amithkoujalgi.ollama4j.core.models.OllamaCommonRequestModel;
|
||||||
|
import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import lombok.Getter;
|
||||||
|
import lombok.Setter;
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
@Setter
|
||||||
|
public class OllamaGenerateRequestModel extends OllamaCommonRequestModel implements OllamaRequestBody{
|
||||||
|
|
||||||
|
private String prompt;
|
||||||
|
private List<String> images;
|
||||||
|
|
||||||
|
private String system;
|
||||||
|
private String context;
|
||||||
|
private boolean raw;
|
||||||
|
|
||||||
|
public OllamaGenerateRequestModel() {
|
||||||
|
}
|
||||||
|
|
||||||
|
public OllamaGenerateRequestModel(String model, String prompt) {
|
||||||
|
this.model = model;
|
||||||
|
this.prompt = prompt;
|
||||||
|
}
|
||||||
|
|
||||||
|
public OllamaGenerateRequestModel(String model, String prompt, List<String> images) {
|
||||||
|
this.model = model;
|
||||||
|
this.prompt = prompt;
|
||||||
|
this.images = images;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object o) {
|
||||||
|
if (!(o instanceof OllamaGenerateRequestModel)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return this.toString().equals(o.toString());
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -1,4 +1,4 @@
|
|||||||
package io.github.amithkoujalgi.ollama4j.core.models;
|
package io.github.amithkoujalgi.ollama4j.core.models.generate;
|
||||||
|
|
||||||
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
@ -8,7 +8,7 @@ import lombok.Data;
|
|||||||
|
|
||||||
@Data
|
@Data
|
||||||
@JsonIgnoreProperties(ignoreUnknown = true)
|
@JsonIgnoreProperties(ignoreUnknown = true)
|
||||||
public class OllamaResponseModel {
|
public class OllamaGenerateResponseModel {
|
||||||
private String model;
|
private String model;
|
||||||
private @JsonProperty("created_at") String createdAt;
|
private @JsonProperty("created_at") String createdAt;
|
||||||
private String response;
|
private String response;
|
@ -6,7 +6,7 @@ import org.slf4j.LoggerFactory;
|
|||||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
|
|
||||||
import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth;
|
import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth;
|
||||||
import io.github.amithkoujalgi.ollama4j.core.models.OllamaResponseModel;
|
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateResponseModel;
|
||||||
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
|
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
|
||||||
|
|
||||||
public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{
|
public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{
|
||||||
@ -25,7 +25,7 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{
|
|||||||
@Override
|
@Override
|
||||||
protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) {
|
protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) {
|
||||||
try {
|
try {
|
||||||
OllamaResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaResponseModel.class);
|
OllamaGenerateResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class);
|
||||||
responseBuffer.append(ollamaResponseModel.getResponse());
|
responseBuffer.append(ollamaResponseModel.getResponse());
|
||||||
return ollamaResponseModel.isDone();
|
return ollamaResponseModel.isDone();
|
||||||
} catch (JsonProcessingException e) {
|
} catch (JsonProcessingException e) {
|
||||||
|
@ -0,0 +1,21 @@
|
|||||||
|
package io.github.amithkoujalgi.ollama4j.core.utils;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.core.JsonGenerator;
|
||||||
|
import com.fasterxml.jackson.databind.JsonSerializer;
|
||||||
|
import com.fasterxml.jackson.databind.SerializerProvider;
|
||||||
|
|
||||||
|
public class BooleanToJsonFormatFlagSerializer extends JsonSerializer<Boolean>{
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void serialize(Boolean value, JsonGenerator gen, SerializerProvider serializers) throws IOException {
|
||||||
|
gen.writeString("json");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isEmpty(SerializerProvider provider,Boolean value){
|
||||||
|
return !value;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -1,8 +1,6 @@
|
|||||||
package io.github.amithkoujalgi.ollama4j.core.utils;
|
package io.github.amithkoujalgi.ollama4j.core.utils;
|
||||||
|
|
||||||
import java.io.ByteArrayOutputStream;
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.io.ObjectOutputStream;
|
|
||||||
import java.util.Base64;
|
import java.util.Base64;
|
||||||
import java.util.Collection;
|
import java.util.Collection;
|
||||||
|
|
||||||
@ -20,11 +18,4 @@ public class FileToBase64Serializer extends JsonSerializer<Collection<byte[]>> {
|
|||||||
}
|
}
|
||||||
jsonGenerator.writeEndArray();
|
jsonGenerator.writeEndArray();
|
||||||
}
|
}
|
||||||
|
|
||||||
public static byte[] serialize(Object obj) throws IOException {
|
|
||||||
ByteArrayOutputStream out = new ByteArrayOutputStream();
|
|
||||||
ObjectOutputStream os = new ObjectOutputStream(out);
|
|
||||||
os.writeObject(obj);
|
|
||||||
return out.toByteArray();
|
|
||||||
}
|
|
||||||
}
|
}
|
@ -0,0 +1,106 @@
|
|||||||
|
package io.github.amithkoujalgi.ollama4j.unittests.jackson;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.junit.jupiter.api.Assertions.fail;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import org.json.JSONObject;
|
||||||
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
|
||||||
|
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessageRole;
|
||||||
|
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder;
|
||||||
|
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel;
|
||||||
|
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
|
||||||
|
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
|
||||||
|
|
||||||
|
public class TestChatRequestSerialization {
|
||||||
|
|
||||||
|
private OllamaChatRequestBuilder builder;
|
||||||
|
|
||||||
|
private ObjectMapper mapper = Utils.getObjectMapper();
|
||||||
|
|
||||||
|
@BeforeEach
|
||||||
|
public void init() {
|
||||||
|
builder = OllamaChatRequestBuilder.getInstance("DummyModel");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testRequestOnlyMandatoryFields() {
|
||||||
|
OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt",
|
||||||
|
List.of(new File("src/test/resources/dog-on-a-boat.jpg"))).build();
|
||||||
|
String jsonRequest = serializeRequest(req);
|
||||||
|
assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testRequestMultipleMessages() {
|
||||||
|
OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.SYSTEM, "System prompt")
|
||||||
|
.withMessage(OllamaChatMessageRole.USER, "Some prompt")
|
||||||
|
.build();
|
||||||
|
String jsonRequest = serializeRequest(req);
|
||||||
|
assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testRequestWithMessageAndImage() {
|
||||||
|
OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt",
|
||||||
|
List.of(new File("src/test/resources/dog-on-a-boat.jpg"))).build();
|
||||||
|
String jsonRequest = serializeRequest(req);
|
||||||
|
assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testRequestWithOptions() {
|
||||||
|
OptionsBuilder b = new OptionsBuilder();
|
||||||
|
OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt")
|
||||||
|
.withOptions(b.setMirostat(1).build()).build();
|
||||||
|
|
||||||
|
String jsonRequest = serializeRequest(req);
|
||||||
|
OllamaChatRequestModel deserializeRequest = deserializeRequest(jsonRequest);
|
||||||
|
assertEqualsAfterUnmarshalling(deserializeRequest, req);
|
||||||
|
assertEquals(1, deserializeRequest.getOptions().get("mirostat"));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testWithJsonFormat() {
|
||||||
|
OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt")
|
||||||
|
.withGetJsonResponse().build();
|
||||||
|
|
||||||
|
String jsonRequest = serializeRequest(req);
|
||||||
|
// no jackson deserialization as format property is not boolean ==> omit as deserialization
|
||||||
|
// of request is never used in real code anyways
|
||||||
|
JSONObject jsonObject = new JSONObject(jsonRequest);
|
||||||
|
String requestFormatProperty = jsonObject.getString("format");
|
||||||
|
assertEquals("json", requestFormatProperty);
|
||||||
|
}
|
||||||
|
|
||||||
|
private String serializeRequest(OllamaChatRequestModel req) {
|
||||||
|
try {
|
||||||
|
return mapper.writeValueAsString(req);
|
||||||
|
} catch (JsonProcessingException e) {
|
||||||
|
fail("Could not serialize request!", e);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private OllamaChatRequestModel deserializeRequest(String jsonRequest) {
|
||||||
|
try {
|
||||||
|
return mapper.readValue(jsonRequest, OllamaChatRequestModel.class);
|
||||||
|
} catch (JsonProcessingException e) {
|
||||||
|
fail("Could not deserialize jsonRequest!", e);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void assertEqualsAfterUnmarshalling(OllamaChatRequestModel unmarshalledRequest,
|
||||||
|
OllamaChatRequestModel req) {
|
||||||
|
assertEquals(req, unmarshalledRequest);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,85 @@
|
|||||||
|
package io.github.amithkoujalgi.ollama4j.unittests.jackson;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.junit.jupiter.api.Assertions.fail;
|
||||||
|
|
||||||
|
import org.json.JSONObject;
|
||||||
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
|
||||||
|
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestBuilder;
|
||||||
|
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel;
|
||||||
|
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
|
||||||
|
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
|
||||||
|
|
||||||
|
public class TestGenerateRequestSerialization {
|
||||||
|
|
||||||
|
private OllamaGenerateRequestBuilder builder;
|
||||||
|
|
||||||
|
private ObjectMapper mapper = Utils.getObjectMapper();
|
||||||
|
|
||||||
|
@BeforeEach
|
||||||
|
public void init() {
|
||||||
|
builder = OllamaGenerateRequestBuilder.getInstance("DummyModel");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testRequestOnlyMandatoryFields() {
|
||||||
|
OllamaGenerateRequestModel req = builder.withPrompt("Some prompt").build();
|
||||||
|
|
||||||
|
String jsonRequest = serializeRequest(req);
|
||||||
|
assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testRequestWithOptions() {
|
||||||
|
OptionsBuilder b = new OptionsBuilder();
|
||||||
|
OllamaGenerateRequestModel req =
|
||||||
|
builder.withPrompt("Some prompt").withOptions(b.setMirostat(1).build()).build();
|
||||||
|
|
||||||
|
String jsonRequest = serializeRequest(req);
|
||||||
|
OllamaGenerateRequestModel deserializeRequest = deserializeRequest(jsonRequest);
|
||||||
|
assertEqualsAfterUnmarshalling(deserializeRequest, req);
|
||||||
|
assertEquals(1, deserializeRequest.getOptions().get("mirostat"));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testWithJsonFormat() {
|
||||||
|
OllamaGenerateRequestModel req =
|
||||||
|
builder.withPrompt("Some prompt").withGetJsonResponse().build();
|
||||||
|
|
||||||
|
String jsonRequest = serializeRequest(req);
|
||||||
|
// no jackson deserialization as format property is not boolean ==> omit as deserialization
|
||||||
|
// of request is never used in real code anyways
|
||||||
|
JSONObject jsonObject = new JSONObject(jsonRequest);
|
||||||
|
String requestFormatProperty = jsonObject.getString("format");
|
||||||
|
assertEquals("json", requestFormatProperty);
|
||||||
|
}
|
||||||
|
|
||||||
|
private String serializeRequest(OllamaGenerateRequestModel req) {
|
||||||
|
try {
|
||||||
|
return mapper.writeValueAsString(req);
|
||||||
|
} catch (JsonProcessingException e) {
|
||||||
|
fail("Could not serialize request!", e);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private OllamaGenerateRequestModel deserializeRequest(String jsonRequest) {
|
||||||
|
try {
|
||||||
|
return mapper.readValue(jsonRequest, OllamaGenerateRequestModel.class);
|
||||||
|
} catch (JsonProcessingException e) {
|
||||||
|
fail("Could not deserialize jsonRequest!", e);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void assertEqualsAfterUnmarshalling(OllamaGenerateRequestModel unmarshalledRequest,
|
||||||
|
OllamaGenerateRequestModel req) {
|
||||||
|
assertEquals(req, unmarshalledRequest);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user