mirror of
https://github.com/amithkoujalgi/ollama4j.git
synced 2025-05-15 03:47:13 +02:00
extends ollamaChatResult to have full access to OllamaChatResult
This commit is contained in:
parent
12bb10392e
commit
25694a8bc9
@ -602,7 +602,7 @@ public class OllamaAPI {
|
|||||||
OllamaResult result = generate(model, prompt, raw, options, null);
|
OllamaResult result = generate(model, prompt, raw, options, null);
|
||||||
toolResult.setModelResult(result);
|
toolResult.setModelResult(result);
|
||||||
|
|
||||||
String toolsResponse = result.getContent();
|
String toolsResponse = result.getResponse();
|
||||||
if (toolsResponse.contains("[TOOL_CALLS]")) {
|
if (toolsResponse.contains("[TOOL_CALLS]")) {
|
||||||
toolsResponse = toolsResponse.replace("[TOOL_CALLS]", "");
|
toolsResponse = toolsResponse.replace("[TOOL_CALLS]", "");
|
||||||
}
|
}
|
||||||
@ -767,7 +767,7 @@ public class OllamaAPI {
|
|||||||
*/
|
*/
|
||||||
public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
|
public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
|
||||||
OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
|
OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, basicAuth, requestTimeoutSeconds, verbose);
|
||||||
OllamaResult result;
|
OllamaChatResult result;
|
||||||
|
|
||||||
// add all registered tools to Request
|
// add all registered tools to Request
|
||||||
request.setTools(toolRegistry.getRegisteredSpecs().stream().map(Tools.ToolSpecification::getToolPrompt).collect(Collectors.toList()));
|
request.setTools(toolRegistry.getRegisteredSpecs().stream().map(Tools.ToolSpecification::getToolPrompt).collect(Collectors.toList()));
|
||||||
@ -779,7 +779,7 @@ public class OllamaAPI {
|
|||||||
result = requestCaller.callSync(request);
|
result = requestCaller.callSync(request);
|
||||||
}
|
}
|
||||||
|
|
||||||
return new OllamaChatResult(result.getContent(), result.getResponseTime(), result.getHttpStatusCode(), request.getMessages());
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void registerTool(Tools.ToolSpecification toolSpecification) {
|
public void registerTool(Tools.ToolSpecification toolSpecification) {
|
||||||
|
@ -10,6 +10,7 @@ import java.io.IOException;
|
|||||||
import java.net.URISyntaxException;
|
import java.net.URISyntaxException;
|
||||||
import java.nio.file.Files;
|
import java.nio.file.Files;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@ -38,6 +39,10 @@ public class OllamaChatRequestBuilder {
|
|||||||
request = new OllamaChatRequest(request.getModel(), new ArrayList<>());
|
request = new OllamaChatRequest(request.getModel(), new ArrayList<>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content){
|
||||||
|
return withMessage(role,content, Collections.emptyList());
|
||||||
|
}
|
||||||
|
|
||||||
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List<OllamaChatToolCalls> toolCalls,List<File> images) {
|
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List<OllamaChatToolCalls> toolCalls,List<File> images) {
|
||||||
List<OllamaChatMessage> messages = this.request.getMessages();
|
List<OllamaChatMessage> messages = this.request.getMessages();
|
||||||
|
|
||||||
|
@ -2,28 +2,40 @@ package io.github.ollama4j.models.chat;
|
|||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
import io.github.ollama4j.models.response.OllamaResult;
|
import io.github.ollama4j.models.response.OllamaResult;
|
||||||
|
import lombok.Getter;
|
||||||
|
|
||||||
|
import static io.github.ollama4j.utils.Utils.getObjectMapper;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Specific chat-API result that contains the chat history sent to the model and appends the answer as {@link OllamaChatResult} given by the
|
* Specific chat-API result that contains the chat history sent to the model and appends the answer as {@link OllamaChatResult} given by the
|
||||||
* {@link OllamaChatMessageRole#ASSISTANT} role.
|
* {@link OllamaChatMessageRole#ASSISTANT} role.
|
||||||
*/
|
*/
|
||||||
public class OllamaChatResult extends OllamaResult {
|
@Getter
|
||||||
|
public class OllamaChatResult {
|
||||||
|
|
||||||
|
|
||||||
private List<OllamaChatMessage> chatHistory;
|
private List<OllamaChatMessage> chatHistory;
|
||||||
|
|
||||||
public OllamaChatResult(String response, long responseTime, int httpStatusCode, List<OllamaChatMessage> chatHistory) {
|
private OllamaChatResponseModel response;
|
||||||
super(response, responseTime, httpStatusCode);
|
|
||||||
|
public OllamaChatResult(OllamaChatResponseModel response, List<OllamaChatMessage> chatHistory) {
|
||||||
this.chatHistory = chatHistory;
|
this.chatHistory = chatHistory;
|
||||||
|
this.response = response;
|
||||||
appendAnswerToChatHistory(response);
|
appendAnswerToChatHistory(response);
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<OllamaChatMessage> getChatHistory() {
|
private void appendAnswerToChatHistory(OllamaChatResponseModel response) {
|
||||||
return chatHistory;
|
this.chatHistory.add(response.getMessage());
|
||||||
}
|
}
|
||||||
|
|
||||||
private void appendAnswerToChatHistory(String answer) {
|
@Override
|
||||||
OllamaChatMessage assistantMessage = new OllamaChatMessage(OllamaChatMessageRole.ASSISTANT, answer);
|
public String toString() {
|
||||||
this.chatHistory.add(assistantMessage);
|
try {
|
||||||
|
return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
|
||||||
|
} catch (JsonProcessingException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -4,6 +4,9 @@ import com.fasterxml.jackson.core.JsonProcessingException;
|
|||||||
import com.fasterxml.jackson.core.type.TypeReference;
|
import com.fasterxml.jackson.core.type.TypeReference;
|
||||||
import io.github.ollama4j.exceptions.OllamaBaseException;
|
import io.github.ollama4j.exceptions.OllamaBaseException;
|
||||||
import io.github.ollama4j.models.chat.OllamaChatMessage;
|
import io.github.ollama4j.models.chat.OllamaChatMessage;
|
||||||
|
import io.github.ollama4j.models.chat.OllamaChatRequest;
|
||||||
|
import io.github.ollama4j.models.chat.OllamaChatResult;
|
||||||
|
import io.github.ollama4j.models.response.OllamaErrorResponse;
|
||||||
import io.github.ollama4j.models.response.OllamaResult;
|
import io.github.ollama4j.models.response.OllamaResult;
|
||||||
import io.github.ollama4j.models.chat.OllamaChatResponseModel;
|
import io.github.ollama4j.models.chat.OllamaChatResponseModel;
|
||||||
import io.github.ollama4j.models.chat.OllamaChatStreamObserver;
|
import io.github.ollama4j.models.chat.OllamaChatStreamObserver;
|
||||||
@ -13,7 +16,15 @@ import io.github.ollama4j.utils.Utils;
|
|||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
|
import java.io.BufferedReader;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.io.InputStream;
|
||||||
|
import java.io.InputStreamReader;
|
||||||
|
import java.net.URI;
|
||||||
|
import java.net.http.HttpClient;
|
||||||
|
import java.net.http.HttpRequest;
|
||||||
|
import java.net.http.HttpResponse;
|
||||||
|
import java.nio.charset.StandardCharsets;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Specialization class for requests
|
* Specialization class for requests
|
||||||
@ -64,9 +75,68 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler)
|
public OllamaChatResult call(OllamaChatRequest body, OllamaStreamHandler streamHandler)
|
||||||
throws OllamaBaseException, IOException, InterruptedException {
|
throws OllamaBaseException, IOException, InterruptedException {
|
||||||
streamObserver = new OllamaChatStreamObserver(streamHandler);
|
streamObserver = new OllamaChatStreamObserver(streamHandler);
|
||||||
return super.callSync(body);
|
return callSync(body);
|
||||||
|
}
|
||||||
|
|
||||||
|
public OllamaChatResult callSync(OllamaChatRequest body) throws OllamaBaseException, IOException, InterruptedException {
|
||||||
|
// Create Request
|
||||||
|
HttpClient httpClient = HttpClient.newHttpClient();
|
||||||
|
URI uri = URI.create(getHost() + getEndpointSuffix());
|
||||||
|
HttpRequest.Builder requestBuilder =
|
||||||
|
getRequestBuilderDefault(uri)
|
||||||
|
.POST(
|
||||||
|
body.getBodyPublisher());
|
||||||
|
HttpRequest request = requestBuilder.build();
|
||||||
|
if (isVerbose()) LOG.info("Asking model: " + body.toString());
|
||||||
|
HttpResponse<InputStream> response =
|
||||||
|
httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
|
||||||
|
|
||||||
|
int statusCode = response.statusCode();
|
||||||
|
InputStream responseBodyStream = response.body();
|
||||||
|
StringBuilder responseBuffer = new StringBuilder();
|
||||||
|
OllamaChatResponseModel ollamaChatResponseModel = null;
|
||||||
|
try (BufferedReader reader =
|
||||||
|
new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
|
||||||
|
|
||||||
|
String line;
|
||||||
|
while ((line = reader.readLine()) != null) {
|
||||||
|
if (statusCode == 404) {
|
||||||
|
LOG.warn("Status code: 404 (Not Found)");
|
||||||
|
OllamaErrorResponse ollamaResponseModel =
|
||||||
|
Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class);
|
||||||
|
responseBuffer.append(ollamaResponseModel.getError());
|
||||||
|
} else if (statusCode == 401) {
|
||||||
|
LOG.warn("Status code: 401 (Unauthorized)");
|
||||||
|
OllamaErrorResponse ollamaResponseModel =
|
||||||
|
Utils.getObjectMapper()
|
||||||
|
.readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponse.class);
|
||||||
|
responseBuffer.append(ollamaResponseModel.getError());
|
||||||
|
} else if (statusCode == 400) {
|
||||||
|
LOG.warn("Status code: 400 (Bad Request)");
|
||||||
|
OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line,
|
||||||
|
OllamaErrorResponse.class);
|
||||||
|
responseBuffer.append(ollamaResponseModel.getError());
|
||||||
|
} else {
|
||||||
|
boolean finished = parseResponseAndAddToBuffer(line, responseBuffer);
|
||||||
|
ollamaChatResponseModel = Utils.getObjectMapper().readValue(line, OllamaChatResponseModel.class);
|
||||||
|
if (finished && body.stream) {
|
||||||
|
ollamaChatResponseModel.getMessage().setContent(responseBuffer.toString());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (statusCode != 200) {
|
||||||
|
LOG.error("Status code " + statusCode);
|
||||||
|
throw new OllamaBaseException(responseBuffer.toString());
|
||||||
|
} else {
|
||||||
|
OllamaChatResult ollamaResult =
|
||||||
|
new OllamaChatResult(ollamaChatResponseModel,body.getMessages());
|
||||||
|
if (isVerbose()) LOG.info("Model response: " + ollamaResult);
|
||||||
|
return ollamaResult;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -6,6 +6,7 @@ import io.github.ollama4j.models.response.OllamaErrorResponse;
|
|||||||
import io.github.ollama4j.models.response.OllamaResult;
|
import io.github.ollama4j.models.response.OllamaResult;
|
||||||
import io.github.ollama4j.utils.OllamaRequestBody;
|
import io.github.ollama4j.utils.OllamaRequestBody;
|
||||||
import io.github.ollama4j.utils.Utils;
|
import io.github.ollama4j.utils.Utils;
|
||||||
|
import lombok.Getter;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
@ -24,14 +25,15 @@ import java.util.Base64;
|
|||||||
/**
|
/**
|
||||||
* Abstract helperclass to call the ollama api server.
|
* Abstract helperclass to call the ollama api server.
|
||||||
*/
|
*/
|
||||||
|
@Getter
|
||||||
public abstract class OllamaEndpointCaller {
|
public abstract class OllamaEndpointCaller {
|
||||||
|
|
||||||
private static final Logger LOG = LoggerFactory.getLogger(OllamaAPI.class);
|
private static final Logger LOG = LoggerFactory.getLogger(OllamaAPI.class);
|
||||||
|
|
||||||
private String host;
|
private final String host;
|
||||||
private BasicAuth basicAuth;
|
private final BasicAuth basicAuth;
|
||||||
private long requestTimeoutSeconds;
|
private final long requestTimeoutSeconds;
|
||||||
private boolean verbose;
|
private final boolean verbose;
|
||||||
|
|
||||||
public OllamaEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) {
|
public OllamaEndpointCaller(String host, BasicAuth basicAuth, long requestTimeoutSeconds, boolean verbose) {
|
||||||
this.host = host;
|
this.host = host;
|
||||||
@ -45,80 +47,13 @@ public abstract class OllamaEndpointCaller {
|
|||||||
protected abstract boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer);
|
protected abstract boolean parseResponseAndAddToBuffer(String line, StringBuilder responseBuffer);
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Calls the api server on the given host and endpoint suffix asynchronously, aka waiting for the response.
|
|
||||||
*
|
|
||||||
* @param body POST body payload
|
|
||||||
* @return result answer given by the assistant
|
|
||||||
* @throws OllamaBaseException any response code than 200 has been returned
|
|
||||||
* @throws IOException in case the responseStream can not be read
|
|
||||||
* @throws InterruptedException in case the server is not reachable or network issues happen
|
|
||||||
*/
|
|
||||||
public OllamaResult callSync(OllamaRequestBody body) throws OllamaBaseException, IOException, InterruptedException {
|
|
||||||
// Create Request
|
|
||||||
long startTime = System.currentTimeMillis();
|
|
||||||
HttpClient httpClient = HttpClient.newHttpClient();
|
|
||||||
URI uri = URI.create(this.host + getEndpointSuffix());
|
|
||||||
HttpRequest.Builder requestBuilder =
|
|
||||||
getRequestBuilderDefault(uri)
|
|
||||||
.POST(
|
|
||||||
body.getBodyPublisher());
|
|
||||||
HttpRequest request = requestBuilder.build();
|
|
||||||
if (this.verbose) LOG.info("Asking model: " + body.toString());
|
|
||||||
HttpResponse<InputStream> response =
|
|
||||||
httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
|
|
||||||
|
|
||||||
int statusCode = response.statusCode();
|
|
||||||
InputStream responseBodyStream = response.body();
|
|
||||||
StringBuilder responseBuffer = new StringBuilder();
|
|
||||||
try (BufferedReader reader =
|
|
||||||
new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
|
|
||||||
String line;
|
|
||||||
while ((line = reader.readLine()) != null) {
|
|
||||||
if (statusCode == 404) {
|
|
||||||
LOG.warn("Status code: 404 (Not Found)");
|
|
||||||
OllamaErrorResponse ollamaResponseModel =
|
|
||||||
Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class);
|
|
||||||
responseBuffer.append(ollamaResponseModel.getError());
|
|
||||||
} else if (statusCode == 401) {
|
|
||||||
LOG.warn("Status code: 401 (Unauthorized)");
|
|
||||||
OllamaErrorResponse ollamaResponseModel =
|
|
||||||
Utils.getObjectMapper()
|
|
||||||
.readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponse.class);
|
|
||||||
responseBuffer.append(ollamaResponseModel.getError());
|
|
||||||
} else if (statusCode == 400) {
|
|
||||||
LOG.warn("Status code: 400 (Bad Request)");
|
|
||||||
OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line,
|
|
||||||
OllamaErrorResponse.class);
|
|
||||||
responseBuffer.append(ollamaResponseModel.getError());
|
|
||||||
} else {
|
|
||||||
boolean finished = parseResponseAndAddToBuffer(line, responseBuffer);
|
|
||||||
if (finished) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (statusCode != 200) {
|
|
||||||
LOG.error("Status code " + statusCode);
|
|
||||||
throw new OllamaBaseException(responseBuffer.toString());
|
|
||||||
} else {
|
|
||||||
long endTime = System.currentTimeMillis();
|
|
||||||
OllamaResult ollamaResult =
|
|
||||||
new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode);
|
|
||||||
if (verbose) LOG.info("Model response: " + ollamaResult);
|
|
||||||
return ollamaResult;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get default request builder.
|
* Get default request builder.
|
||||||
*
|
*
|
||||||
* @param uri URI to get a HttpRequest.Builder
|
* @param uri URI to get a HttpRequest.Builder
|
||||||
* @return HttpRequest.Builder
|
* @return HttpRequest.Builder
|
||||||
*/
|
*/
|
||||||
private HttpRequest.Builder getRequestBuilderDefault(URI uri) {
|
protected HttpRequest.Builder getRequestBuilderDefault(URI uri) {
|
||||||
HttpRequest.Builder requestBuilder =
|
HttpRequest.Builder requestBuilder =
|
||||||
HttpRequest.newBuilder(uri)
|
HttpRequest.newBuilder(uri)
|
||||||
.header("Content-Type", "application/json")
|
.header("Content-Type", "application/json")
|
||||||
@ -134,7 +69,7 @@ public abstract class OllamaEndpointCaller {
|
|||||||
*
|
*
|
||||||
* @return basic authentication header value (encoded credentials)
|
* @return basic authentication header value (encoded credentials)
|
||||||
*/
|
*/
|
||||||
private String getBasicAuthHeaderValue() {
|
protected String getBasicAuthHeaderValue() {
|
||||||
String credentialsToEncode = this.basicAuth.getUsername() + ":" + this.basicAuth.getPassword();
|
String credentialsToEncode = this.basicAuth.getUsername() + ":" + this.basicAuth.getPassword();
|
||||||
return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes());
|
return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes());
|
||||||
}
|
}
|
||||||
@ -144,7 +79,7 @@ public abstract class OllamaEndpointCaller {
|
|||||||
*
|
*
|
||||||
* @return true when Basic Auth credentials set
|
* @return true when Basic Auth credentials set
|
||||||
*/
|
*/
|
||||||
private boolean isBasicAuthCredentialsSet() {
|
protected boolean isBasicAuthCredentialsSet() {
|
||||||
return this.basicAuth != null;
|
return this.basicAuth != null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2,6 +2,7 @@ package io.github.ollama4j.models.request;
|
|||||||
|
|
||||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
import io.github.ollama4j.exceptions.OllamaBaseException;
|
import io.github.ollama4j.exceptions.OllamaBaseException;
|
||||||
|
import io.github.ollama4j.models.response.OllamaErrorResponse;
|
||||||
import io.github.ollama4j.models.response.OllamaResult;
|
import io.github.ollama4j.models.response.OllamaResult;
|
||||||
import io.github.ollama4j.models.generate.OllamaGenerateResponseModel;
|
import io.github.ollama4j.models.generate.OllamaGenerateResponseModel;
|
||||||
import io.github.ollama4j.models.generate.OllamaGenerateStreamObserver;
|
import io.github.ollama4j.models.generate.OllamaGenerateStreamObserver;
|
||||||
@ -11,7 +12,15 @@ import io.github.ollama4j.utils.Utils;
|
|||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
|
import java.io.BufferedReader;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.io.InputStream;
|
||||||
|
import java.io.InputStreamReader;
|
||||||
|
import java.net.URI;
|
||||||
|
import java.net.http.HttpClient;
|
||||||
|
import java.net.http.HttpRequest;
|
||||||
|
import java.net.http.HttpResponse;
|
||||||
|
import java.nio.charset.StandardCharsets;
|
||||||
|
|
||||||
public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
|
public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
|
||||||
|
|
||||||
@ -46,6 +55,73 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
|
|||||||
public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler)
|
public OllamaResult call(OllamaRequestBody body, OllamaStreamHandler streamHandler)
|
||||||
throws OllamaBaseException, IOException, InterruptedException {
|
throws OllamaBaseException, IOException, InterruptedException {
|
||||||
streamObserver = new OllamaGenerateStreamObserver(streamHandler);
|
streamObserver = new OllamaGenerateStreamObserver(streamHandler);
|
||||||
return super.callSync(body);
|
return callSync(body);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calls the api server on the given host and endpoint suffix asynchronously, aka waiting for the response.
|
||||||
|
*
|
||||||
|
* @param body POST body payload
|
||||||
|
* @return result answer given by the assistant
|
||||||
|
* @throws OllamaBaseException any response code than 200 has been returned
|
||||||
|
* @throws IOException in case the responseStream can not be read
|
||||||
|
* @throws InterruptedException in case the server is not reachable or network issues happen
|
||||||
|
*/
|
||||||
|
public OllamaResult callSync(OllamaRequestBody body) throws OllamaBaseException, IOException, InterruptedException {
|
||||||
|
// Create Request
|
||||||
|
long startTime = System.currentTimeMillis();
|
||||||
|
HttpClient httpClient = HttpClient.newHttpClient();
|
||||||
|
URI uri = URI.create(getHost() + getEndpointSuffix());
|
||||||
|
HttpRequest.Builder requestBuilder =
|
||||||
|
getRequestBuilderDefault(uri)
|
||||||
|
.POST(
|
||||||
|
body.getBodyPublisher());
|
||||||
|
HttpRequest request = requestBuilder.build();
|
||||||
|
if (isVerbose()) LOG.info("Asking model: " + body.toString());
|
||||||
|
HttpResponse<InputStream> response =
|
||||||
|
httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
|
||||||
|
|
||||||
|
int statusCode = response.statusCode();
|
||||||
|
InputStream responseBodyStream = response.body();
|
||||||
|
StringBuilder responseBuffer = new StringBuilder();
|
||||||
|
try (BufferedReader reader =
|
||||||
|
new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
|
||||||
|
String line;
|
||||||
|
while ((line = reader.readLine()) != null) {
|
||||||
|
if (statusCode == 404) {
|
||||||
|
LOG.warn("Status code: 404 (Not Found)");
|
||||||
|
OllamaErrorResponse ollamaResponseModel =
|
||||||
|
Utils.getObjectMapper().readValue(line, OllamaErrorResponse.class);
|
||||||
|
responseBuffer.append(ollamaResponseModel.getError());
|
||||||
|
} else if (statusCode == 401) {
|
||||||
|
LOG.warn("Status code: 401 (Unauthorized)");
|
||||||
|
OllamaErrorResponse ollamaResponseModel =
|
||||||
|
Utils.getObjectMapper()
|
||||||
|
.readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponse.class);
|
||||||
|
responseBuffer.append(ollamaResponseModel.getError());
|
||||||
|
} else if (statusCode == 400) {
|
||||||
|
LOG.warn("Status code: 400 (Bad Request)");
|
||||||
|
OllamaErrorResponse ollamaResponseModel = Utils.getObjectMapper().readValue(line,
|
||||||
|
OllamaErrorResponse.class);
|
||||||
|
responseBuffer.append(ollamaResponseModel.getError());
|
||||||
|
} else {
|
||||||
|
boolean finished = parseResponseAndAddToBuffer(line, responseBuffer);
|
||||||
|
if (finished) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (statusCode != 200) {
|
||||||
|
LOG.error("Status code " + statusCode);
|
||||||
|
throw new OllamaBaseException(responseBuffer.toString());
|
||||||
|
} else {
|
||||||
|
long endTime = System.currentTimeMillis();
|
||||||
|
OllamaResult ollamaResult =
|
||||||
|
new OllamaResult(responseBuffer.toString().trim(), endTime - startTime, statusCode);
|
||||||
|
if (isVerbose()) LOG.info("Model response: " + ollamaResult);
|
||||||
|
return ollamaResult;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -17,7 +17,7 @@ public class OllamaResult {
|
|||||||
*
|
*
|
||||||
* @return String completion/response text
|
* @return String completion/response text
|
||||||
*/
|
*/
|
||||||
private final String content;
|
private final String response;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* -- GETTER --
|
* -- GETTER --
|
||||||
@ -35,8 +35,8 @@ public class OllamaResult {
|
|||||||
*/
|
*/
|
||||||
private long responseTime = 0;
|
private long responseTime = 0;
|
||||||
|
|
||||||
public OllamaResult(String content, long responseTime, int httpStatusCode) {
|
public OllamaResult(String response, long responseTime, int httpStatusCode) {
|
||||||
this.content = content;
|
this.response = response;
|
||||||
this.responseTime = responseTime;
|
this.responseTime = responseTime;
|
||||||
this.httpStatusCode = httpStatusCode;
|
this.httpStatusCode = httpStatusCode;
|
||||||
}
|
}
|
||||||
|
@ -2,12 +2,9 @@ package io.github.ollama4j.integrationtests;
|
|||||||
|
|
||||||
import io.github.ollama4j.OllamaAPI;
|
import io.github.ollama4j.OllamaAPI;
|
||||||
import io.github.ollama4j.exceptions.OllamaBaseException;
|
import io.github.ollama4j.exceptions.OllamaBaseException;
|
||||||
|
import io.github.ollama4j.models.chat.*;
|
||||||
import io.github.ollama4j.models.response.ModelDetail;
|
import io.github.ollama4j.models.response.ModelDetail;
|
||||||
import io.github.ollama4j.models.chat.OllamaChatRequest;
|
|
||||||
import io.github.ollama4j.models.response.OllamaResult;
|
import io.github.ollama4j.models.response.OllamaResult;
|
||||||
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
|
|
||||||
import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
|
|
||||||
import io.github.ollama4j.models.chat.OllamaChatResult;
|
|
||||||
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestBuilder;
|
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestBuilder;
|
||||||
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
|
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
|
||||||
import io.github.ollama4j.tools.ToolFunction;
|
import io.github.ollama4j.tools.ToolFunction;
|
||||||
@ -47,6 +44,7 @@ class TestRealAPIs {
|
|||||||
config = new Config();
|
config = new Config();
|
||||||
ollamaAPI = new OllamaAPI(config.getOllamaURL());
|
ollamaAPI = new OllamaAPI(config.getOllamaURL());
|
||||||
ollamaAPI.setRequestTimeoutSeconds(config.getRequestTimeoutSeconds());
|
ollamaAPI.setRequestTimeoutSeconds(config.getRequestTimeoutSeconds());
|
||||||
|
ollamaAPI.setVerbose(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@ -196,7 +194,9 @@ class TestRealAPIs {
|
|||||||
|
|
||||||
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
|
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
|
||||||
assertNotNull(chatResult);
|
assertNotNull(chatResult);
|
||||||
assertFalse(chatResult.getResponse().isBlank());
|
assertNotNull(chatResult.getResponse());
|
||||||
|
assertNotNull(chatResult.getResponse().getMessage());
|
||||||
|
assertFalse(chatResult.getResponse().getMessage().getContent().isBlank());
|
||||||
assertEquals(4, chatResult.getChatHistory().size());
|
assertEquals(4, chatResult.getChatHistory().size());
|
||||||
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||||
fail(e);
|
fail(e);
|
||||||
@ -217,8 +217,10 @@ class TestRealAPIs {
|
|||||||
|
|
||||||
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
|
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
|
||||||
assertNotNull(chatResult);
|
assertNotNull(chatResult);
|
||||||
assertFalse(chatResult.getResponse().isBlank());
|
assertNotNull(chatResult.getResponse());
|
||||||
assertTrue(chatResult.getResponse().startsWith("NI"));
|
assertNotNull(chatResult.getResponse().getMessage());
|
||||||
|
assertFalse(chatResult.getResponse().getMessage().getContent().isBlank());
|
||||||
|
assertTrue(chatResult.getResponse().getMessage().getContent().startsWith("NI"));
|
||||||
assertEquals(3, chatResult.getChatHistory().size());
|
assertEquals(3, chatResult.getChatHistory().size());
|
||||||
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||||
fail(e);
|
fail(e);
|
||||||
@ -267,9 +269,17 @@ class TestRealAPIs {
|
|||||||
.build();
|
.build();
|
||||||
|
|
||||||
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
|
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
|
||||||
System.err.println("Response: " + chatResult);
|
|
||||||
assertNotNull(chatResult);
|
assertNotNull(chatResult);
|
||||||
assertFalse(chatResult.getResponse().isBlank());
|
assertNotNull(chatResult.getResponse());
|
||||||
|
assertNotNull(chatResult.getResponse().getMessage());
|
||||||
|
assertEquals(OllamaChatMessageRole.ASSISTANT.getRoleName(),chatResult.getResponse().getMessage().getRole().getRoleName());
|
||||||
|
List<OllamaChatToolCalls> toolCalls = chatResult.getResponse().getMessage().getToolCalls();
|
||||||
|
assertEquals(1, toolCalls.size());
|
||||||
|
assertEquals("get-employee-details",toolCalls.get(0).getFunction().getName());
|
||||||
|
assertEquals(1, toolCalls.get(0).getFunction().getArguments().size());
|
||||||
|
String employeeName = toolCalls.get(0).getFunction().getArguments().get("employee-name");
|
||||||
|
assertNotNull(employeeName);
|
||||||
|
assertEquals("Rahul Kumar",employeeName);
|
||||||
assertEquals(2, chatResult.getChatHistory().size());
|
assertEquals(2, chatResult.getChatHistory().size());
|
||||||
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||||
fail(e);
|
fail(e);
|
||||||
@ -295,7 +305,10 @@ class TestRealAPIs {
|
|||||||
sb.append(substring);
|
sb.append(substring);
|
||||||
});
|
});
|
||||||
assertNotNull(chatResult);
|
assertNotNull(chatResult);
|
||||||
assertEquals(sb.toString().trim(), chatResult.getResponse().trim());
|
assertNotNull(chatResult.getResponse());
|
||||||
|
assertNotNull(chatResult.getResponse().getMessage());
|
||||||
|
assertNotNull(chatResult.getResponse().getMessage().getContent());
|
||||||
|
assertEquals(sb.toString().trim(), chatResult.getResponse().getMessage().getContent().trim());
|
||||||
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
} catch (IOException | OllamaBaseException | InterruptedException e) {
|
||||||
fail(e);
|
fail(e);
|
||||||
}
|
}
|
||||||
@ -309,7 +322,7 @@ class TestRealAPIs {
|
|||||||
OllamaChatRequestBuilder builder =
|
OllamaChatRequestBuilder builder =
|
||||||
OllamaChatRequestBuilder.getInstance(config.getImageModel());
|
OllamaChatRequestBuilder.getInstance(config.getImageModel());
|
||||||
OllamaChatRequest requestModel =
|
OllamaChatRequest requestModel =
|
||||||
builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
|
builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",Collections.emptyList(),
|
||||||
List.of(getImageFileFromClasspath("dog-on-a-boat.jpg"))).build();
|
List.of(getImageFileFromClasspath("dog-on-a-boat.jpg"))).build();
|
||||||
|
|
||||||
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
|
OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
|
||||||
@ -338,7 +351,7 @@ class TestRealAPIs {
|
|||||||
testEndpointReachability();
|
testEndpointReachability();
|
||||||
try {
|
try {
|
||||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getImageModel());
|
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(config.getImageModel());
|
||||||
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",
|
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What's in the picture?",Collections.emptyList(),
|
||||||
"https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg")
|
"https://t3.ftcdn.net/jpg/02/96/63/80/360_F_296638053_0gUVA4WVBKceGsIr7LNqRWSnkusi07dq.jpg")
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
@ -4,6 +4,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
|
|||||||
import static org.junit.jupiter.api.Assertions.assertThrowsExactly;
|
import static org.junit.jupiter.api.Assertions.assertThrowsExactly;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import io.github.ollama4j.models.chat.OllamaChatRequest;
|
import io.github.ollama4j.models.chat.OllamaChatRequest;
|
||||||
@ -42,7 +43,7 @@ public class TestChatRequestSerialization extends AbstractSerializationTest<Olla
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRequestWithMessageAndImage() {
|
public void testRequestWithMessageAndImage() {
|
||||||
OllamaChatRequest req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt",
|
OllamaChatRequest req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt", Collections.emptyList(),
|
||||||
List.of(new File("src/test/resources/dog-on-a-boat.jpg"))).build();
|
List.of(new File("src/test/resources/dog-on-a-boat.jpg"))).build();
|
||||||
String jsonRequest = serialize(req);
|
String jsonRequest = serialize(req);
|
||||||
assertEqualsAfterUnmarshalling(deserialize(jsonRequest, OllamaChatRequest.class), req);
|
assertEqualsAfterUnmarshalling(deserialize(jsonRequest, OllamaChatRequest.class), req);
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
ollama.url=http://localhost:11434
|
ollama.url=http://localhost:11434
|
||||||
ollama.model=qwen:0.5b
|
ollama.model=llama3.2:1b
|
||||||
ollama.model.image=llava
|
ollama.model.image=llava:latest
|
||||||
ollama.request-timeout-seconds=120
|
ollama.request-timeout-seconds=120
|
Loading…
x
Reference in New Issue
Block a user