diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java index e48add1..7e82ee8 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java @@ -6,6 +6,7 @@ import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessage; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel; import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResult; +import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel; import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFileContentsRequest; import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFilePathRequest; import io.github.amithkoujalgi.ollama4j.core.models.request.ModelEmbeddingsRequest; @@ -345,7 +346,7 @@ public class OllamaAPI { */ public OllamaResult generate(String model, String prompt, Options options) throws OllamaBaseException, IOException, InterruptedException { - OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt); + OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt); ollamaRequestModel.setOptions(options.getOptionsMap()); return generateSyncForOllamaRequestModel(ollamaRequestModel); } @@ -360,7 +361,7 @@ public class OllamaAPI { * @return the ollama async result callback handle */ public OllamaAsyncResultCallback generateAsync(String model, String prompt) { - OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt); + OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt); URI uri = URI.create(this.host + "/api/generate"); OllamaAsyncResultCallback ollamaAsyncResultCallback = @@ -389,7 +390,7 @@ public class OllamaAPI { for (File imageFile : imageFiles) { images.add(encodeFileToBase64(imageFile)); } - OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt, images); + OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt, images); ollamaRequestModel.setOptions(options.getOptionsMap()); return generateSyncForOllamaRequestModel(ollamaRequestModel); } @@ -413,7 +414,7 @@ public class OllamaAPI { for (String imageURL : imageURLs) { images.add(encodeByteArrayToBase64(Utils.loadImageBytesFromUrl(imageURL))); } - OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt, images); + OllamaGenerateRequestModel ollamaRequestModel = new OllamaGenerateRequestModel(model, prompt, images); ollamaRequestModel.setOptions(options.getOptionsMap()); return generateSyncForOllamaRequestModel(ollamaRequestModel); } @@ -486,7 +487,7 @@ public class OllamaAPI { return Base64.getEncoder().encodeToString(bytes); } - private OllamaResult generateSyncForOllamaRequestModel(OllamaRequestModel ollamaRequestModel) + private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequestModel ollamaRequestModel) throws OllamaBaseException, IOException, InterruptedException { OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose); return requestCaller.callSync(ollamaRequestModel); diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java index 74b8c49..136f1c6 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java @@ -1,6 +1,8 @@ package io.github.amithkoujalgi.ollama4j.core.models; import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; +import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel; +import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateResponseModel; import io.github.amithkoujalgi.ollama4j.core.utils.Utils; import java.io.BufferedReader; import java.io.IOException; @@ -22,7 +24,7 @@ import lombok.Getter; @SuppressWarnings("unused") public class OllamaAsyncResultCallback extends Thread { private final HttpRequest.Builder requestBuilder; - private final OllamaRequestModel ollamaRequestModel; + private final OllamaGenerateRequestModel ollamaRequestModel; private final Queue queue = new LinkedList<>(); private String result; private boolean isDone; @@ -47,7 +49,7 @@ public class OllamaAsyncResultCallback extends Thread { public OllamaAsyncResultCallback( HttpRequest.Builder requestBuilder, - OllamaRequestModel ollamaRequestModel, + OllamaGenerateRequestModel ollamaRequestModel, long requestTimeoutSeconds) { this.requestBuilder = requestBuilder; this.ollamaRequestModel = ollamaRequestModel; @@ -87,8 +89,8 @@ public class OllamaAsyncResultCallback extends Thread { queue.add(ollamaResponseModel.getError()); responseBuffer.append(ollamaResponseModel.getError()); } else { - OllamaResponseModel ollamaResponseModel = - Utils.getObjectMapper().readValue(line, OllamaResponseModel.class); + OllamaGenerateResponseModel ollamaResponseModel = + Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class); queue.add(ollamaResponseModel.getResponse()); if (!ollamaResponseModel.isDone()) { responseBuffer.append(ollamaResponseModel.getResponse()); diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaCommonRequestModel.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaCommonRequestModel.java new file mode 100644 index 0000000..f554638 --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaCommonRequestModel.java @@ -0,0 +1,26 @@ +package io.github.amithkoujalgi.ollama4j.core.models; + +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.annotation.JsonSerialize; + +import io.github.amithkoujalgi.ollama4j.core.utils.BooleanToJsonFormatFlagSerializer; +import lombok.Data; +import lombok.NonNull; +import lombok.RequiredArgsConstructor; + +@Data +@RequiredArgsConstructor +public abstract class OllamaCommonRequestModel { + + @NonNull + protected String model; + @JsonSerialize(using = BooleanToJsonFormatFlagSerializer.class) + protected boolean returnFormatJson; + protected Map options; + protected String template; + protected boolean stream; + @JsonProperty(value = "keep_alive") + protected String keepAlive; +} diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestBuilder.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestBuilder.java index 5abbcde..e07722f 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestBuilder.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestBuilder.java @@ -83,12 +83,12 @@ public class OllamaChatRequestBuilder { } public OllamaChatRequestBuilder withOptions(Options options){ - this.request.setOptions(options); + this.request.setOptions(options.getOptionsMap()); return this; } - public OllamaChatRequestBuilder withFormat(String format){ - this.request.setFormat(format); + public OllamaChatRequestBuilder withGetJsonResponse(){ + this.request.setReturnFormatJson(true); return this; } diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestModel.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestModel.java index 2f947a2..82c9010 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestModel.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/chat/OllamaChatRequestModel.java @@ -4,15 +4,14 @@ import java.util.List; import com.fasterxml.jackson.core.JsonProcessingException; +import io.github.amithkoujalgi.ollama4j.core.models.OllamaCommonRequestModel; import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody; -import io.github.amithkoujalgi.ollama4j.core.utils.Options; import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper; -import lombok.AllArgsConstructor; import lombok.Data; +import lombok.EqualsAndHashCode; import lombok.NonNull; -import lombok.RequiredArgsConstructor; /** * Defines a Request to use against the ollama /api/chat endpoint. @@ -22,19 +21,15 @@ import lombok.RequiredArgsConstructor; * Chat Completion */ @Data -@AllArgsConstructor -@RequiredArgsConstructor -public class OllamaChatRequestModel implements OllamaRequestBody { - - @NonNull private String model; +@EqualsAndHashCode(callSuper = true) +public class OllamaChatRequestModel extends OllamaCommonRequestModel implements OllamaRequestBody { @NonNull private List messages; - private String format; - private Options options; - private String template; - private boolean stream; - private String keepAlive; + public OllamaChatRequestModel(String model,List messages){ + super(model); + this.messages = messages; + } @Override public String toString() { diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateRequestBuilder.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateRequestBuilder.java new file mode 100644 index 0000000..48b4d18 --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateRequestBuilder.java @@ -0,0 +1,55 @@ +package io.github.amithkoujalgi.ollama4j.core.models.generate; + +import io.github.amithkoujalgi.ollama4j.core.utils.Options; + +/** + * Helper class for creating {@link io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel} + * objects using the builder-pattern. + */ +public class OllamaGenerateRequestBuilder { + + private OllamaGenerateRequestBuilder(String model, String prompt){ + request = new OllamaGenerateRequestModel(model, prompt); + } + + private OllamaGenerateRequestModel request; + + public static OllamaGenerateRequestBuilder getInstance(String model){ + return new OllamaGenerateRequestBuilder(model,""); + } + + public OllamaGenerateRequestModel build(){ + return request; + } + + public OllamaGenerateRequestBuilder withPrompt(String prompt){ + request.setPrompt(prompt); + return this; + } + + public OllamaGenerateRequestBuilder withGetJsonResponse(){ + this.request.setReturnFormatJson(true); + return this; + } + + public OllamaGenerateRequestBuilder withOptions(Options options){ + this.request.setOptions(options.getOptionsMap()); + return this; + } + + public OllamaGenerateRequestBuilder withTemplate(String template){ + this.request.setTemplate(template); + return this; + } + + public OllamaGenerateRequestBuilder withStreaming(){ + this.request.setStream(true); + return this; + } + + public OllamaGenerateRequestBuilder withKeepAlive(String keepAlive){ + this.request.setKeepAlive(keepAlive); + return this; + } + +} diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaRequestModel.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateRequestModel.java similarity index 50% rename from src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaRequestModel.java rename to src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateRequestModel.java index 9c88698..861be1a 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaRequestModel.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateRequestModel.java @@ -1,30 +1,37 @@ -package io.github.amithkoujalgi.ollama4j.core.models; +package io.github.amithkoujalgi.ollama4j.core.models.generate; import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper; import com.fasterxml.jackson.core.JsonProcessingException; +import io.github.amithkoujalgi.ollama4j.core.models.OllamaCommonRequestModel; import io.github.amithkoujalgi.ollama4j.core.utils.OllamaRequestBody; import java.util.List; -import java.util.Map; import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NonNull; @Data -public class OllamaRequestModel implements OllamaRequestBody{ +@EqualsAndHashCode(callSuper = true) +public class OllamaGenerateRequestModel extends OllamaCommonRequestModel implements OllamaRequestBody{ - private String model; + @NonNull private String prompt; private List images; - private Map options; - public OllamaRequestModel(String model, String prompt) { - this.model = model; + private String system; + private String context; + private boolean raw; + + + public OllamaGenerateRequestModel(String model, String prompt) { + super(model); this.prompt = prompt; } - public OllamaRequestModel(String model, String prompt, List images) { - this.model = model; + public OllamaGenerateRequestModel(String model, String prompt, List images) { + super(model); this.prompt = prompt; this.images = images; } diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaResponseModel.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateResponseModel.java similarity index 88% rename from src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaResponseModel.java rename to src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateResponseModel.java index 9481224..a575a7a 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaResponseModel.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/generate/OllamaGenerateResponseModel.java @@ -1,4 +1,4 @@ -package io.github.amithkoujalgi.ollama4j.core.models; +package io.github.amithkoujalgi.ollama4j.core.models.generate; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; @@ -8,7 +8,7 @@ import lombok.Data; @Data @JsonIgnoreProperties(ignoreUnknown = true) -public class OllamaResponseModel { +public class OllamaGenerateResponseModel { private String model; private @JsonProperty("created_at") String createdAt; private String response; diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java index 8d54db3..ba55159 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java @@ -6,7 +6,7 @@ import org.slf4j.LoggerFactory; import com.fasterxml.jackson.core.JsonProcessingException; import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth; -import io.github.amithkoujalgi.ollama4j.core.models.OllamaResponseModel; +import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateResponseModel; import io.github.amithkoujalgi.ollama4j.core.utils.Utils; public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{ @@ -25,7 +25,7 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{ @Override protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) { try { - OllamaResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaResponseModel.class); + OllamaGenerateResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaGenerateResponseModel.class); responseBuffer.append(ollamaResponseModel.getResponse()); return ollamaResponseModel.isDone(); } catch (JsonProcessingException e) { diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/BooleanToJsonFormatFlagSerializer.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/BooleanToJsonFormatFlagSerializer.java new file mode 100644 index 0000000..7513fc3 --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/BooleanToJsonFormatFlagSerializer.java @@ -0,0 +1,18 @@ +package io.github.amithkoujalgi.ollama4j.core.utils; + +import java.io.IOException; + +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.databind.JsonSerializer; +import com.fasterxml.jackson.databind.SerializerProvider; + +public class BooleanToJsonFormatFlagSerializer extends JsonSerializer{ + + @Override + public void serialize(Boolean value, JsonGenerator gen, SerializerProvider serializers) throws IOException { + if(value){ + gen.writeString("json"); + } + } + +}