diff --git a/src/main/java/io/github/ollama4j/OllamaAPI.java b/src/main/java/io/github/ollama4j/OllamaAPI.java index b03af62..be48a62 100644 --- a/src/main/java/io/github/ollama4j/OllamaAPI.java +++ b/src/main/java/io/github/ollama4j/OllamaAPI.java @@ -1386,19 +1386,40 @@ public class OllamaAPI { OllamaGenerateRequest ollamaRequestModel, OllamaGenerateTokenHandler thinkingStreamHandler, OllamaGenerateTokenHandler responseStreamHandler) - throws OllamaBaseException, IOException, InterruptedException { - OllamaGenerateEndpointCaller requestCaller = - new OllamaGenerateEndpointCaller(host, auth, requestTimeoutSeconds); - OllamaResult result; - if (responseStreamHandler != null) { - ollamaRequestModel.setStream(true); - result = - requestCaller.call( - ollamaRequestModel, thinkingStreamHandler, responseStreamHandler); - } else { - result = requestCaller.callSync(ollamaRequestModel); + throws OllamaBaseException { + long startTime = System.currentTimeMillis(); + int statusCode = -1; + Object out = null; + try { + OllamaGenerateEndpointCaller requestCaller = + new OllamaGenerateEndpointCaller(host, auth, requestTimeoutSeconds); + OllamaResult result; + if (responseStreamHandler != null) { + ollamaRequestModel.setStream(true); + result = + requestCaller.call( + ollamaRequestModel, thinkingStreamHandler, responseStreamHandler); + } else { + result = requestCaller.callSync(ollamaRequestModel); + } + statusCode = result.getHttpStatusCode(); + out = result; + return result; + } catch (Exception e) { + throw new OllamaBaseException("Ping failed", e); + } finally { + MetricsRecorder.record( + OllamaGenerateEndpointCaller.endpoint, + ollamaRequestModel.getModel(), + ollamaRequestModel.isRaw(), + ollamaRequestModel.isThink(), + ollamaRequestModel.isStream(), + ollamaRequestModel.getOptions(), + ollamaRequestModel.getFormat(), + startTime, + statusCode, + out); } - return result; } /** diff --git a/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java b/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java index 4cf971b..c72f85d 100644 --- a/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java +++ b/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java @@ -29,13 +29,12 @@ import java.util.List; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** - * Specialization class for requests - */ +/** Specialization class for requests */ @SuppressWarnings("resource") public class OllamaChatEndpointCaller extends OllamaEndpointCaller { private static final Logger LOG = LoggerFactory.getLogger(OllamaChatEndpointCaller.class); + public static final String endpoint = "/api/chat"; private OllamaChatTokenHandler tokenHandler; @@ -43,19 +42,14 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller { super(host, auth, requestTimeoutSeconds); } - @Override - protected String getEndpointSuffix() { - return "/api/chat"; - } - /** - * Parses streamed Response line from ollama chat. - * Using {@link com.fasterxml.jackson.databind.ObjectMapper#readValue(String, TypeReference)} should throw - * {@link IllegalArgumentException} in case of null line or {@link com.fasterxml.jackson.core.JsonParseException} - * in case the JSON Object cannot be parsed to a {@link OllamaChatResponseModel}. Thus, the ResponseModel should - * never be null. + * Parses streamed Response line from ollama chat. Using {@link + * com.fasterxml.jackson.databind.ObjectMapper#readValue(String, TypeReference)} should throw + * {@link IllegalArgumentException} in case of null line or {@link + * com.fasterxml.jackson.core.JsonParseException} in case the JSON Object cannot be parsed to a + * {@link OllamaChatResponseModel}. Thus, the ResponseModel should never be null. * - * @param line streamed line of ollama stream response + * @param line streamed line of ollama stream response * @param responseBuffer Stringbuffer to add latest response message part to * @return TRUE, if ollama-Response has 'done' state */ @@ -97,7 +91,7 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller { throws OllamaBaseException, IOException, InterruptedException { long startTime = System.currentTimeMillis(); HttpClient httpClient = HttpClient.newHttpClient(); - URI uri = URI.create(getHost() + getEndpointSuffix()); + URI uri = URI.create(getHost() + endpoint); HttpRequest.Builder requestBuilder = getRequestBuilderDefault(uri).POST(body.getBodyPublisher()); HttpRequest request = requestBuilder.build(); @@ -136,7 +130,7 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller { } } MetricsRecorder.record( - getEndpointSuffix(), + endpoint, body.getModel(), false, body.isThink(), @@ -160,8 +154,8 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller { } /** - * Handles error status codes and appends error messages to the response buffer. - * Returns true if an error was handled, false otherwise. + * Handles error status codes and appends error messages to the response buffer. Returns true if + * an error was handled, false otherwise. */ private boolean handleErrorStatus(int statusCode, String line, StringBuilder responseBuffer) throws IOException { diff --git a/src/main/java/io/github/ollama4j/models/request/OllamaEndpointCaller.java b/src/main/java/io/github/ollama4j/models/request/OllamaEndpointCaller.java index 1d73185..01ee916 100644 --- a/src/main/java/io/github/ollama4j/models/request/OllamaEndpointCaller.java +++ b/src/main/java/io/github/ollama4j/models/request/OllamaEndpointCaller.java @@ -15,7 +15,7 @@ import java.time.Duration; import lombok.Getter; /** - * Abstract helperclass to call the ollama api server. + * Abstract helper class to call the ollama api server. */ @Getter public abstract class OllamaEndpointCaller { @@ -30,8 +30,6 @@ public abstract class OllamaEndpointCaller { this.requestTimeoutSeconds = requestTimeoutSeconds; } - protected abstract String getEndpointSuffix(); - protected abstract boolean parseResponseAndAddToBuffer( String line, StringBuilder responseBuffer, StringBuilder thinkingBuffer); diff --git a/src/main/java/io/github/ollama4j/models/request/OllamaGenerateEndpointCaller.java b/src/main/java/io/github/ollama4j/models/request/OllamaGenerateEndpointCaller.java index 9c3387a..237d5fb 100644 --- a/src/main/java/io/github/ollama4j/models/request/OllamaGenerateEndpointCaller.java +++ b/src/main/java/io/github/ollama4j/models/request/OllamaGenerateEndpointCaller.java @@ -33,6 +33,7 @@ import org.slf4j.LoggerFactory; public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller { private static final Logger LOG = LoggerFactory.getLogger(OllamaGenerateEndpointCaller.class); + public static final String endpoint = "/api/generate"; private OllamaGenerateStreamObserver responseStreamObserver; @@ -40,11 +41,6 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller { super(host, basicAuth, requestTimeoutSeconds); } - @Override - protected String getEndpointSuffix() { - return "/api/generate"; - } - @Override protected boolean parseResponseAndAddToBuffer( String line, StringBuilder responseBuffer, StringBuilder thinkingBuffer) { @@ -78,12 +74,13 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller { } /** - * Calls the api server on the given host and endpoint suffix asynchronously, aka waiting for the response. + * Calls the api server on the given host and endpoint suffix asynchronously, aka waiting for + * the response. * * @param body POST body payload * @return result answer given by the assistant - * @throws OllamaBaseException any response code than 200 has been returned - * @throws IOException in case the responseStream can not be read + * @throws OllamaBaseException any response code than 200 has been returned + * @throws IOException in case the responseStream can not be read * @throws InterruptedException in case the server is not reachable or network issues happen */ @SuppressWarnings("DuplicatedCode") @@ -92,7 +89,7 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller { // Create Request long startTime = System.currentTimeMillis(); HttpClient httpClient = HttpClient.newHttpClient(); - URI uri = URI.create(getHost() + getEndpointSuffix()); + URI uri = URI.create(getHost() + endpoint); HttpRequest.Builder requestBuilder = getRequestBuilderDefault(uri).POST(body.getBodyPublisher()); HttpRequest request = requestBuilder.build();