diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaChatEndpointCaller.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaChatEndpointCaller.java new file mode 100644 index 0000000..eb06c37 --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaChatEndpointCaller.java @@ -0,0 +1,44 @@ +package io.github.amithkoujalgi.ollama4j.core.models.request; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.core.JsonProcessingException; + +import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth; +import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResponseModel; +import io.github.amithkoujalgi.ollama4j.core.utils.Utils; + +/** + * Specialization class for requests + */ +public class OllamaChatEndpointCaller extends OllamaEndpointCaller{ + + private static final Logger LOG = LoggerFactory.getLogger(OllamaChatEndpointCaller.class); + + public OllamaChatEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) { + super(host, basicAuth, requestTimeoutSeconds, verbose); + } + + @Override + protected String getEndpointSuffix() { + return "/api/chat"; + } + + @Override + protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) { + try { + OllamaChatResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class); + responseBuffer.append(ollamaResponseModel.getMessage().getContent()); + return ollamaResponseModel.isDone(); + } catch (JsonProcessingException e) { + LOG.error("Error parsing the Ollama chat response!",e); + return true; + } + } + + + + + +} diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaChatRequestCaller.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaChatRequestCaller.java deleted file mode 100644 index b08c507..0000000 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaChatRequestCaller.java +++ /dev/null @@ -1,16 +0,0 @@ -package io.github.amithkoujalgi.ollama4j.core.models.request; - -import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth; - -public class OllamaChatRequestCaller extends OllamaServerCaller{ - - public OllamaChatRequestCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) { - super(host, basicAuth, requestTimeoutSeconds, verbose); - } - - @Override - protected String getEndpointSuffix() { - return "/api/generate"; - } - -} diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaServerCaller.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaEndpointCaller.java similarity index 92% rename from src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaServerCaller.java rename to src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaEndpointCaller.java index 6f9f27b..d99499f 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaServerCaller.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaEndpointCaller.java @@ -27,7 +27,7 @@ import io.github.amithkoujalgi.ollama4j.core.utils.Utils; /** * Abstract helperclass to call the ollama api server. */ -public abstract class OllamaServerCaller { +public abstract class OllamaEndpointCaller { private static final Logger LOG = LoggerFactory.getLogger(OllamaAPI.class); @@ -36,7 +36,7 @@ public abstract class OllamaServerCaller { private long requestTimeoutSeconds; private boolean verbose; - public OllamaServerCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) { + public OllamaEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) { this.host = host; this.basicAuth = basicAuth; this.requestTimeoutSeconds = requestTimeoutSeconds; @@ -44,6 +44,9 @@ public abstract class OllamaServerCaller { } protected abstract String getEndpointSuffix(); + + protected abstract boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer); + /** * Calls the api server on the given host and endpoint suffix asynchronously, aka waiting for the response. @@ -89,11 +92,10 @@ public abstract class OllamaServerCaller { .readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponseModel.class); responseBuffer.append(ollamaResponseModel.getError()); } else { - OllamaResponseModel ollamaResponseModel = - Utils.getObjectMapper().readValue(line, OllamaResponseModel.class); - if (!ollamaResponseModel.isDone()) { - responseBuffer.append(ollamaResponseModel.getResponse()); - } + boolean finished = parseResponseAndAddToBuffer(line,responseBuffer); + if (finished) { + break; + } } } } diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java new file mode 100644 index 0000000..8d54db3 --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateEndpointCaller.java @@ -0,0 +1,40 @@ +package io.github.amithkoujalgi.ollama4j.core.models.request; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.core.JsonProcessingException; + +import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth; +import io.github.amithkoujalgi.ollama4j.core.models.OllamaResponseModel; +import io.github.amithkoujalgi.ollama4j.core.utils.Utils; + +public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller{ + + private static final Logger LOG = LoggerFactory.getLogger(OllamaGenerateEndpointCaller.class); + + public OllamaGenerateEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) { + super(host, basicAuth, requestTimeoutSeconds, verbose); + } + + @Override + protected String getEndpointSuffix() { + return "/api/generate"; + } + + @Override + protected boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer) { + try { + OllamaResponseModel ollamaResponseModel = Utils.getObjectMapper().readValue(line, OllamaResponseModel.class); + responseBuffer.append(ollamaResponseModel.getResponse()); + return ollamaResponseModel.isDone(); + } catch (JsonProcessingException e) { + LOG.error("Error parsing the Ollama chat response!",e); + return true; + } + } + + + + +} diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateRequestCaller.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateRequestCaller.java deleted file mode 100644 index cc316db..0000000 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/request/OllamaGenerateRequestCaller.java +++ /dev/null @@ -1,18 +0,0 @@ -package io.github.amithkoujalgi.ollama4j.core.models.request; - -import io.github.amithkoujalgi.ollama4j.core.models.BasicAuth; - -public class OllamaGenerateRequestCaller extends OllamaServerCaller{ - - public OllamaGenerateRequestCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) { - super(host, basicAuth, requestTimeoutSeconds, verbose); - } - - @Override - protected String getEndpointSuffix() { - return "/api/generate"; - } - - - -}