Refactor OllamaAPI and chat models to support 'thinking' responses

- Introduced a 'thinking' field in OllamaChatMessage to capture intermediate reasoning.
- Updated OllamaChatRequest to include a 'think' parameter for chat requests.
- Modified OllamaChatRequestBuilder to facilitate setting the 'think' parameter.
- Enhanced response handling in OllamaChatStreamObserver and OllamaGenerateStreamObserver to manage 'thinking' content.
- Updated integration tests to validate the new 'thinking' functionality in chat and generation methods.
This commit is contained in:
amithkoujalgi 2025-08-28 12:44:43 +05:30
parent 14642e9856
commit 8d9ee006ee
No known key found for this signature in database
GPG Key ID: E29A37746AF94B70
9 changed files with 182 additions and 207 deletions

View File

@ -137,8 +137,7 @@ public class OllamaAPI {
HttpClient httpClient = HttpClient.newHttpClient();
HttpRequest httpRequest = null;
try {
httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json")
.header("Content-type", "application/json").GET().build();
httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build();
} catch (URISyntaxException e) {
throw new RuntimeException(e);
}
@ -168,8 +167,7 @@ public class OllamaAPI {
HttpClient httpClient = HttpClient.newHttpClient();
HttpRequest httpRequest = null;
try {
httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json")
.header("Content-type", "application/json").GET().build();
httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build();
} catch (URISyntaxException e) {
throw new RuntimeException(e);
}
@ -196,8 +194,7 @@ public class OllamaAPI {
public List<Model> listModels() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
String url = this.host + "/api/tags";
HttpClient httpClient = HttpClient.newHttpClient();
HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json")
.header("Content-type", "application/json").GET().build();
HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build();
HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode();
String responseString = response.body();
@ -225,12 +222,10 @@ public class OllamaAPI {
* @throws URISyntaxException If there is an error creating the URI for the
* HTTP request.
*/
public List<LibraryModel> listModelsFromLibrary()
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
public List<LibraryModel> listModelsFromLibrary() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
String url = "https://ollama.com/library";
HttpClient httpClient = HttpClient.newHttpClient();
HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json")
.header("Content-type", "application/json").GET().build();
HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build();
HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode();
String responseString = response.body();
@ -245,8 +240,7 @@ public class OllamaAPI {
Elements pullCounts = e.select("div:nth-of-type(2) > p > span:first-of-type > span:first-of-type");
Elements popularTags = e.select("div > div > span");
Elements totalTags = e.select("div:nth-of-type(2) > p > span:nth-of-type(2) > span:first-of-type");
Elements lastUpdatedTime = e
.select("div:nth-of-type(2) > p > span:nth-of-type(3) > span:nth-of-type(2)");
Elements lastUpdatedTime = e.select("div:nth-of-type(2) > p > span:nth-of-type(3) > span:nth-of-type(2)");
if (names.first() == null || names.isEmpty()) {
// if name cannot be extracted, skip.
@ -254,12 +248,9 @@ public class OllamaAPI {
}
Optional.ofNullable(names.first()).map(Element::text).ifPresent(model::setName);
model.setDescription(Optional.ofNullable(desc.first()).map(Element::text).orElse(""));
model.setPopularTags(Optional.of(popularTags)
.map(tags -> tags.stream().map(Element::text).collect(Collectors.toList()))
.orElse(new ArrayList<>()));
model.setPopularTags(Optional.of(popularTags).map(tags -> tags.stream().map(Element::text).collect(Collectors.toList())).orElse(new ArrayList<>()));
model.setPullCount(Optional.ofNullable(pullCounts.first()).map(Element::text).orElse(""));
model.setTotalTags(
Optional.ofNullable(totalTags.first()).map(Element::text).map(Integer::parseInt).orElse(0));
model.setTotalTags(Optional.ofNullable(totalTags.first()).map(Element::text).map(Integer::parseInt).orElse(0));
model.setLastUpdated(Optional.ofNullable(lastUpdatedTime.first()).map(Element::text).orElse(""));
models.add(model);
@ -292,12 +283,10 @@ public class OllamaAPI {
* the HTTP response.
* @throws URISyntaxException if the URI format is incorrect or invalid.
*/
public LibraryModelDetail getLibraryModelDetails(LibraryModel libraryModel)
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
public LibraryModelDetail getLibraryModelDetails(LibraryModel libraryModel) throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
String url = String.format("https://ollama.com/library/%s/tags", libraryModel.getName());
HttpClient httpClient = HttpClient.newHttpClient();
HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json")
.header("Content-type", "application/json").GET().build();
HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build();
HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode();
String responseString = response.body();
@ -305,8 +294,7 @@ public class OllamaAPI {
List<LibraryModelTag> libraryModelTags = new ArrayList<>();
if (statusCode == 200) {
Document doc = Jsoup.parse(responseString);
Elements tagSections = doc
.select("html > body > main > div > section > div > div > div:nth-child(n+2) > div");
Elements tagSections = doc.select("html > body > main > div > section > div > div > div:nth-child(n+2) > div");
for (Element e : tagSections) {
Elements tags = e.select("div > a > div");
Elements tagsMetas = e.select("div > span");
@ -319,11 +307,8 @@ public class OllamaAPI {
}
libraryModelTag.setName(libraryModel.getName());
Optional.ofNullable(tags.first()).map(Element::text).ifPresent(libraryModelTag::setTag);
libraryModelTag.setSize(Optional.ofNullable(tagsMetas.first()).map(element -> element.text().split(""))
.filter(parts -> parts.length > 1).map(parts -> parts[1].trim()).orElse(""));
libraryModelTag
.setLastUpdated(Optional.ofNullable(tagsMetas.first()).map(element -> element.text().split(""))
.filter(parts -> parts.length > 1).map(parts -> parts[2].trim()).orElse(""));
libraryModelTag.setSize(Optional.ofNullable(tagsMetas.first()).map(element -> element.text().split("")).filter(parts -> parts.length > 1).map(parts -> parts[1].trim()).orElse(""));
libraryModelTag.setLastUpdated(Optional.ofNullable(tagsMetas.first()).map(element -> element.text().split("")).filter(parts -> parts.length > 1).map(parts -> parts[2].trim()).orElse(""));
libraryModelTags.add(libraryModelTag);
}
LibraryModelDetail libraryModelDetail = new LibraryModelDetail();
@ -356,17 +341,11 @@ public class OllamaAPI {
* @throws InterruptedException If the operation is interrupted.
* @throws NoSuchElementException If the model or the tag is not found.
*/
public LibraryModelTag findModelTagFromLibrary(String modelName, String tag)
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
public LibraryModelTag findModelTagFromLibrary(String modelName, String tag) throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
List<LibraryModel> libraryModels = this.listModelsFromLibrary();
LibraryModel libraryModel = libraryModels.stream().filter(model -> model.getName().equals(modelName))
.findFirst().orElseThrow(
() -> new NoSuchElementException(String.format("Model by name '%s' not found", modelName)));
LibraryModel libraryModel = libraryModels.stream().filter(model -> model.getName().equals(modelName)).findFirst().orElseThrow(() -> new NoSuchElementException(String.format("Model by name '%s' not found", modelName)));
LibraryModelDetail libraryModelDetail = this.getLibraryModelDetails(libraryModel);
LibraryModelTag libraryModelTag = libraryModelDetail.getTags().stream()
.filter(tagName -> tagName.getTag().equals(tag)).findFirst()
.orElseThrow(() -> new NoSuchElementException(
String.format("Tag '%s' for model '%s' not found", tag, modelName)));
LibraryModelTag libraryModelTag = libraryModelDetail.getTags().stream().filter(tagName -> tagName.getTag().equals(tag)).findFirst().orElseThrow(() -> new NoSuchElementException(String.format("Tag '%s' for model '%s' not found", tag, modelName)));
return libraryModelTag;
}
@ -380,8 +359,7 @@ public class OllamaAPI {
* @throws InterruptedException if the operation is interrupted
* @throws URISyntaxException if the URI for the request is malformed
*/
public void pullModel(String modelName)
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
public void pullModel(String modelName) throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
if (numberOfRetriesForModelPull == 0) {
this.doPullModel(modelName);
} else {
@ -395,28 +373,21 @@ public class OllamaAPI {
numberOfRetries++;
}
}
throw new OllamaBaseException(
"Failed to pull model " + modelName + " after " + numberOfRetriesForModelPull + " retries");
throw new OllamaBaseException("Failed to pull model " + modelName + " after " + numberOfRetriesForModelPull + " retries");
}
}
private void doPullModel(String modelName)
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
private void doPullModel(String modelName) throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
String url = this.host + "/api/pull";
String jsonData = new ModelRequest(modelName).toString();
HttpRequest request = getRequestBuilderDefault(new URI(url))
.POST(HttpRequest.BodyPublishers.ofString(jsonData))
.header("Accept", "application/json")
.header("Content-type", "application/json")
.build();
HttpRequest request = getRequestBuilderDefault(new URI(url)).POST(HttpRequest.BodyPublishers.ofString(jsonData)).header("Accept", "application/json").header("Content-type", "application/json").build();
HttpClient client = HttpClient.newHttpClient();
HttpResponse<InputStream> response = client.send(request, HttpResponse.BodyHandlers.ofInputStream());
int statusCode = response.statusCode();
InputStream responseBodyStream = response.body();
String responseString = "";
boolean success = false; // Flag to check the pull success.
try (BufferedReader reader = new BufferedReader(
new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
try (BufferedReader reader = new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
String line;
while ((line = reader.readLine()) != null) {
ModelPullResponse modelPullResponse = Utils.getObjectMapper().readValue(line, ModelPullResponse.class);
@ -452,8 +423,7 @@ public class OllamaAPI {
public String getVersion() throws URISyntaxException, IOException, InterruptedException, OllamaBaseException {
String url = this.host + "/api/version";
HttpClient httpClient = HttpClient.newHttpClient();
HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json")
.header("Content-type", "application/json").GET().build();
HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build();
HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode();
String responseString = response.body();
@ -478,8 +448,7 @@ public class OllamaAPI {
* @throws InterruptedException if the operation is interrupted
* @throws URISyntaxException if the URI for the request is malformed
*/
public void pullModel(LibraryModelTag libraryModelTag)
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
public void pullModel(LibraryModelTag libraryModelTag) throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
String tagToPull = String.format("%s:%s", libraryModelTag.getName(), libraryModelTag.getTag());
pullModel(tagToPull);
}
@ -494,12 +463,10 @@ public class OllamaAPI {
* @throws InterruptedException if the operation is interrupted
* @throws URISyntaxException if the URI for the request is malformed
*/
public ModelDetail getModelDetails(String modelName)
throws IOException, OllamaBaseException, InterruptedException, URISyntaxException {
public ModelDetail getModelDetails(String modelName) throws IOException, OllamaBaseException, InterruptedException, URISyntaxException {
String url = this.host + "/api/show";
String jsonData = new ModelRequest(modelName).toString();
HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json")
.header("Content-type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
HttpClient client = HttpClient.newHttpClient();
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode();
@ -525,13 +492,10 @@ public class OllamaAPI {
* @throws URISyntaxException if the URI for the request is malformed
*/
@Deprecated
public void createModelWithFilePath(String modelName, String modelFilePath)
throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
public void createModelWithFilePath(String modelName, String modelFilePath) throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
String url = this.host + "/api/create";
String jsonData = new CustomModelFilePathRequest(modelName, modelFilePath).toString();
HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json")
.header("Content-Type", "application/json")
.POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build();
HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-Type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build();
HttpClient client = HttpClient.newHttpClient();
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode();
@ -565,13 +529,10 @@ public class OllamaAPI {
* @throws URISyntaxException if the URI for the request is malformed
*/
@Deprecated
public void createModelWithModelFileContents(String modelName, String modelFileContents)
throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
public void createModelWithModelFileContents(String modelName, String modelFileContents) throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
String url = this.host + "/api/create";
String jsonData = new CustomModelFileContentsRequest(modelName, modelFileContents).toString();
HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json")
.header("Content-Type", "application/json")
.POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build();
HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-Type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build();
HttpClient client = HttpClient.newHttpClient();
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode();
@ -598,13 +559,10 @@ public class OllamaAPI {
* @throws InterruptedException if the operation is interrupted
* @throws URISyntaxException if the URI for the request is malformed
*/
public void createModel(CustomModelRequest customModelRequest)
throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
public void createModel(CustomModelRequest customModelRequest) throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
String url = this.host + "/api/create";
String jsonData = customModelRequest.toString();
HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json")
.header("Content-Type", "application/json")
.POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build();
HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-Type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build();
HttpClient client = HttpClient.newHttpClient();
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode();
@ -631,13 +589,10 @@ public class OllamaAPI {
* @throws InterruptedException if the operation is interrupted
* @throws URISyntaxException if the URI for the request is malformed
*/
public void deleteModel(String modelName, boolean ignoreIfNotPresent)
throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
public void deleteModel(String modelName, boolean ignoreIfNotPresent) throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
String url = this.host + "/api/delete";
String jsonData = new ModelRequest(modelName).toString();
HttpRequest request = getRequestBuilderDefault(new URI(url))
.method("DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8))
.header("Accept", "application/json").header("Content-type", "application/json").build();
HttpRequest request = getRequestBuilderDefault(new URI(url)).method("DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).header("Accept", "application/json").header("Content-type", "application/json").build();
HttpClient client = HttpClient.newHttpClient();
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode();
@ -662,8 +617,7 @@ public class OllamaAPI {
* @deprecated Use {@link #embed(String, List)} instead.
*/
@Deprecated
public List<Double> generateEmbeddings(String model, String prompt)
throws IOException, InterruptedException, OllamaBaseException {
public List<Double> generateEmbeddings(String model, String prompt) throws IOException, InterruptedException, OllamaBaseException {
return generateEmbeddings(new OllamaEmbeddingsRequestModel(model, prompt));
}
@ -678,20 +632,17 @@ public class OllamaAPI {
* @deprecated Use {@link #embed(OllamaEmbedRequestModel)} instead.
*/
@Deprecated
public List<Double> generateEmbeddings(OllamaEmbeddingsRequestModel modelRequest)
throws IOException, InterruptedException, OllamaBaseException {
public List<Double> generateEmbeddings(OllamaEmbeddingsRequestModel modelRequest) throws IOException, InterruptedException, OllamaBaseException {
URI uri = URI.create(this.host + "/api/embeddings");
String jsonData = modelRequest.toString();
HttpClient httpClient = HttpClient.newHttpClient();
HttpRequest.Builder requestBuilder = getRequestBuilderDefault(uri).header("Accept", "application/json")
.POST(HttpRequest.BodyPublishers.ofString(jsonData));
HttpRequest.Builder requestBuilder = getRequestBuilderDefault(uri).header("Accept", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData));
HttpRequest request = requestBuilder.build();
HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode();
String responseBody = response.body();
if (statusCode == 200) {
OllamaEmbeddingResponseModel embeddingResponse = Utils.getObjectMapper().readValue(responseBody,
OllamaEmbeddingResponseModel.class);
OllamaEmbeddingResponseModel embeddingResponse = Utils.getObjectMapper().readValue(responseBody, OllamaEmbeddingResponseModel.class);
return embeddingResponse.getEmbedding();
} else {
throw new OllamaBaseException(statusCode + " - " + responseBody);
@ -708,8 +659,7 @@ public class OllamaAPI {
* @throws IOException if an I/O error occurs during the HTTP request
* @throws InterruptedException if the operation is interrupted
*/
public OllamaEmbedResponseModel embed(String model, List<String> inputs)
throws IOException, InterruptedException, OllamaBaseException {
public OllamaEmbedResponseModel embed(String model, List<String> inputs) throws IOException, InterruptedException, OllamaBaseException {
return embed(new OllamaEmbedRequestModel(model, inputs));
}
@ -722,14 +672,12 @@ public class OllamaAPI {
* @throws IOException if an I/O error occurs during the HTTP request
* @throws InterruptedException if the operation is interrupted
*/
public OllamaEmbedResponseModel embed(OllamaEmbedRequestModel modelRequest)
throws IOException, InterruptedException, OllamaBaseException {
public OllamaEmbedResponseModel embed(OllamaEmbedRequestModel modelRequest) throws IOException, InterruptedException, OllamaBaseException {
URI uri = URI.create(this.host + "/api/embed");
String jsonData = Utils.getObjectMapper().writeValueAsString(modelRequest);
HttpClient httpClient = HttpClient.newHttpClient();
HttpRequest request = HttpRequest.newBuilder(uri).header("Accept", "application/json")
.POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
HttpRequest request = HttpRequest.newBuilder(uri).header("Accept", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode();
@ -763,8 +711,7 @@ public class OllamaAPI {
* @throws IOException if an I/O error occurs during the HTTP request
* @throws InterruptedException if the operation is interrupted
*/
public OllamaResult generate(String model, String prompt, boolean raw, boolean think, Options options,
OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
public OllamaResult generate(String model, String prompt, boolean raw, boolean think, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt);
ollamaRequestModel.setRaw(raw);
ollamaRequestModel.setThink(think);
@ -794,13 +741,14 @@ public class OllamaAPI {
* @throws IOException if an I/O error occurs during the HTTP request
* @throws InterruptedException if the operation is interrupted
*/
public OllamaResult generate(String model, String prompt, boolean raw, boolean think, Options options)
throws OllamaBaseException, IOException, InterruptedException {
public OllamaResult generate(String model, String prompt, boolean raw, boolean think, Options options) throws OllamaBaseException, IOException, InterruptedException {
return generate(model, prompt, raw, think, options, null);
}
/**
* Generates structured output from the specified AI model and prompt.
* <p>
* Note: When formatting is specified, the 'think' parameter is not allowed.
*
* @param model The name or identifier of the AI model to use for generating
* the response.
@ -813,8 +761,8 @@ public class OllamaAPI {
* @throws IOException if an I/O error occurs during the HTTP request.
* @throws InterruptedException if the operation is interrupted.
*/
public OllamaResult generate(String model, String prompt, Map<String, Object> format)
throws OllamaBaseException, IOException, InterruptedException {
@SuppressWarnings("LoggingSimilarMessage")
public OllamaResult generate(String model, String prompt, Map<String, Object> format) throws OllamaBaseException, IOException, InterruptedException {
URI uri = URI.create(this.host + "/api/generate");
Map<String, Object> requestBody = new HashMap<>();
@ -826,23 +774,30 @@ public class OllamaAPI {
String jsonData = Utils.getObjectMapper().writeValueAsString(requestBody);
HttpClient httpClient = HttpClient.newHttpClient();
HttpRequest request = getRequestBuilderDefault(uri)
.header("Accept", "application/json")
.header("Content-type", "application/json")
.POST(HttpRequest.BodyPublishers.ofString(jsonData))
.build();
HttpRequest request = getRequestBuilderDefault(uri).header("Accept", "application/json").header("Content-type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
if (verbose) {
try {
String prettyJson = Utils.getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(Utils.getObjectMapper().readValue(jsonData, Object.class));
logger.info("Asking model:\n{}", prettyJson);
} catch (Exception e) {
logger.info("Asking model: {}", jsonData);
}
}
HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode();
String responseBody = response.body();
if (statusCode == 200) {
OllamaStructuredResult structuredResult = Utils.getObjectMapper().readValue(responseBody,
OllamaStructuredResult.class);
OllamaResult ollamaResult = new OllamaResult(structuredResult.getResponse(), structuredResult.getThinking(),
structuredResult.getResponseTime(), statusCode);
OllamaStructuredResult structuredResult = Utils.getObjectMapper().readValue(responseBody, OllamaStructuredResult.class);
OllamaResult ollamaResult = new OllamaResult(structuredResult.getResponse(), structuredResult.getThinking(), structuredResult.getResponseTime(), statusCode);
if (verbose) {
logger.info("Model response:\n{}", ollamaResult);
}
return ollamaResult;
} else {
if (verbose) {
logger.info("Model response:\n{}", Utils.getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(responseBody));
}
throw new OllamaBaseException(statusCode + " - " + responseBody);
}
}
@ -866,8 +821,7 @@ public class OllamaAPI {
* @throws IOException if an I/O error occurs during the HTTP request
* @throws InterruptedException if the operation is interrupted
*/
public OllamaToolsResult generateWithTools(String model, String prompt, boolean think, Options options)
throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
public OllamaToolsResult generateWithTools(String model, String prompt, boolean think, Options options) throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
boolean raw = true;
OllamaToolsResult toolResult = new OllamaToolsResult();
Map<ToolFunctionCallSpec, Object> toolResults = new HashMap<>();
@ -900,9 +854,7 @@ public class OllamaAPI {
logger.warn("Response from model does not contain any tool calls. Returning the response as is.");
return toolResult;
}
toolFunctionCallSpecs = objectMapper.readValue(
toolsResponse,
objectMapper.getTypeFactory().constructCollectionType(List.class, ToolFunctionCallSpec.class));
toolFunctionCallSpecs = objectMapper.readValue(toolsResponse, objectMapper.getTypeFactory().constructCollectionType(List.class, ToolFunctionCallSpec.class));
}
for (ToolFunctionCallSpec toolFunctionCallSpec : toolFunctionCallSpecs) {
toolResults.put(toolFunctionCallSpec, invokeTool(toolFunctionCallSpec));
@ -926,8 +878,7 @@ public class OllamaAPI {
OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt);
ollamaRequestModel.setRaw(raw);
URI uri = URI.create(this.host + "/api/generate");
OllamaAsyncResultStreamer ollamaAsyncResultStreamer = new OllamaAsyncResultStreamer(
getRequestBuilderDefault(uri), ollamaRequestModel, requestTimeoutSeconds);
OllamaAsyncResultStreamer ollamaAsyncResultStreamer = new OllamaAsyncResultStreamer(getRequestBuilderDefault(uri), ollamaRequestModel, requestTimeoutSeconds);
ollamaAsyncResultStreamer.start();
return ollamaAsyncResultStreamer;
}
@ -952,8 +903,7 @@ public class OllamaAPI {
* @throws IOException if an I/O error occurs during the HTTP request
* @throws InterruptedException if the operation is interrupted
*/
public OllamaResult generateWithImageFiles(String model, String prompt, List<File> imageFiles, Options options,
OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
public OllamaResult generateWithImageFiles(String model, String prompt, List<File> imageFiles, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
List<String> images = new ArrayList<>();
for (File imageFile : imageFiles) {
images.add(encodeFileToBase64(imageFile));
@ -973,8 +923,7 @@ public class OllamaAPI {
* @throws IOException if an I/O error occurs during the HTTP request
* @throws InterruptedException if the operation is interrupted
*/
public OllamaResult generateWithImageFiles(String model, String prompt, List<File> imageFiles, Options options)
throws OllamaBaseException, IOException, InterruptedException {
public OllamaResult generateWithImageFiles(String model, String prompt, List<File> imageFiles, Options options) throws OllamaBaseException, IOException, InterruptedException {
return generateWithImageFiles(model, prompt, imageFiles, options, null);
}
@ -999,9 +948,7 @@ public class OllamaAPI {
* @throws InterruptedException if the operation is interrupted
* @throws URISyntaxException if the URI for the request is malformed
*/
public OllamaResult generateWithImageURLs(String model, String prompt, List<String> imageURLs, Options options,
OllamaStreamHandler streamHandler)
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
public OllamaResult generateWithImageURLs(String model, String prompt, List<String> imageURLs, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
List<String> images = new ArrayList<>();
for (String imageURL : imageURLs) {
images.add(encodeByteArrayToBase64(Utils.loadImageBytesFromUrl(imageURL)));
@ -1022,8 +969,7 @@ public class OllamaAPI {
* @throws InterruptedException if the operation is interrupted
* @throws URISyntaxException if the URI for the request is malformed
*/
public OllamaResult generateWithImageURLs(String model, String prompt, List<String> imageURLs, Options options)
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
public OllamaResult generateWithImageURLs(String model, String prompt, List<String> imageURLs, Options options) throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
return generateWithImageURLs(model, prompt, imageURLs, options, null);
}
@ -1047,8 +993,7 @@ public class OllamaAPI {
* @throws IOException if an I/O error occurs during the HTTP request
* @throws InterruptedException if the operation is interrupted
*/
public OllamaResult generateWithImages(String model, String prompt, List<byte[]> images, Options options,
OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
public OllamaResult generateWithImages(String model, String prompt, List<byte[]> images, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
List<String> encodedImages = new ArrayList<>();
for (byte[] image : images) {
encodedImages.add(encodeByteArrayToBase64(image));
@ -1069,8 +1014,7 @@ public class OllamaAPI {
* @throws IOException if an I/O error occurs during the HTTP request
* @throws InterruptedException if the operation is interrupted
*/
public OllamaResult generateWithImages(String model, String prompt, List<byte[]> images, Options options)
throws OllamaBaseException, IOException, InterruptedException {
public OllamaResult generateWithImages(String model, String prompt, List<byte[]> images, Options options) throws OllamaBaseException, IOException, InterruptedException {
return generateWithImages(model, prompt, images, options, null);
}
@ -1094,8 +1038,7 @@ public class OllamaAPI {
* @throws InterruptedException if the operation is interrupted
* @throws ToolInvocationException if the tool invocation fails
*/
public OllamaChatResult chat(String model, List<OllamaChatMessage> messages)
throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
public OllamaChatResult chat(String model, List<OllamaChatMessage> messages) throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(model);
return chat(builder.withMessages(messages).build());
}
@ -1119,8 +1062,7 @@ public class OllamaAPI {
* @throws InterruptedException if the operation is interrupted
* @throws ToolInvocationException if the tool invocation fails
*/
public OllamaChatResult chat(OllamaChatRequest request)
throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
public OllamaChatResult chat(OllamaChatRequest request) throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
return chat(request, null);
}
@ -1146,8 +1088,7 @@ public class OllamaAPI {
* @throws InterruptedException if the operation is interrupted
* @throws ToolInvocationException if the tool invocation fails
*/
public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler)
throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
return chatStreaming(request, new OllamaChatStreamObserver(streamHandler));
}
@ -1170,15 +1111,12 @@ public class OllamaAPI {
* @throws IOException if an I/O error occurs during the HTTP request
* @throws InterruptedException if the operation is interrupted
*/
public OllamaChatResult chatStreaming(OllamaChatRequest request, OllamaTokenHandler tokenHandler)
throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, auth, requestTimeoutSeconds,
verbose);
public OllamaChatResult chatStreaming(OllamaChatRequest request, OllamaTokenHandler tokenHandler) throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, auth, requestTimeoutSeconds, verbose);
OllamaChatResult result;
// add all registered tools to Request
request.setTools(toolRegistry.getRegisteredSpecs().stream().map(Tools.ToolSpecification::getToolPrompt)
.collect(Collectors.toList()));
request.setTools(toolRegistry.getRegisteredSpecs().stream().map(Tools.ToolSpecification::getToolPrompt).collect(Collectors.toList()));
if (tokenHandler != null) {
request.setStream(true);
@ -1199,8 +1137,7 @@ public class OllamaAPI {
}
Map<String, Object> arguments = toolCall.getFunction().getArguments();
Object res = toolFunction.apply(arguments);
request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL,
"[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() + ") : " + res + "[/TOOL_RESULTS]"));
request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL, "[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() + ") : " + res + "[/TOOL_RESULTS]"));
}
if (tokenHandler != null) {
@ -1276,8 +1213,8 @@ public class OllamaAPI {
for (Class<?> provider : providers) {
registerAnnotatedTools(provider.getDeclaredConstructor().newInstance());
}
} catch (InstantiationException | NoSuchMethodException | IllegalAccessException
| InvocationTargetException e) {
} catch (InstantiationException | NoSuchMethodException | IllegalAccessException |
InvocationTargetException e) {
throw new RuntimeException(e);
}
}
@ -1317,22 +1254,12 @@ public class OllamaAPI {
}
String propName = !toolPropertyAnn.name().isBlank() ? toolPropertyAnn.name() : parameter.getName();
methodParams.put(propName, propType);
propsBuilder.withProperty(propName, Tools.PromptFuncDefinition.Property.builder().type(propType)
.description(toolPropertyAnn.desc()).required(toolPropertyAnn.required()).build());
propsBuilder.withProperty(propName, Tools.PromptFuncDefinition.Property.builder().type(propType).description(toolPropertyAnn.desc()).required(toolPropertyAnn.required()).build());
}
final Map<String, Tools.PromptFuncDefinition.Property> params = propsBuilder.build();
List<String> reqProps = params.entrySet().stream().filter(e -> e.getValue().isRequired())
.map(Map.Entry::getKey).collect(Collectors.toList());
List<String> reqProps = params.entrySet().stream().filter(e -> e.getValue().isRequired()).map(Map.Entry::getKey).collect(Collectors.toList());
Tools.ToolSpecification toolSpecification = Tools.ToolSpecification.builder().functionName(operationName)
.functionDescription(operationDesc)
.toolPrompt(Tools.PromptFuncDefinition.builder().type("function")
.function(Tools.PromptFuncDefinition.PromptFuncSpec.builder().name(operationName)
.description(operationDesc).parameters(Tools.PromptFuncDefinition.Parameters
.builder().type("object").properties(params).required(reqProps).build())
.build())
.build())
.build();
Tools.ToolSpecification toolSpecification = Tools.ToolSpecification.builder().functionName(operationName).functionDescription(operationDesc).toolPrompt(Tools.PromptFuncDefinition.builder().type("function").function(Tools.PromptFuncDefinition.PromptFuncSpec.builder().name(operationName).description(operationDesc).parameters(Tools.PromptFuncDefinition.Parameters.builder().type("object").properties(params).required(reqProps).build()).build()).build()).build();
ReflectionalToolFunction reflectionalToolFunction = new ReflectionalToolFunction(object, m, methodParams);
toolSpecification.setToolFunction(reflectionalToolFunction);
@ -1413,10 +1340,8 @@ public class OllamaAPI {
* process.
* @throws InterruptedException if the thread is interrupted during the request.
*/
private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequest ollamaRequestModel,
OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, auth, requestTimeoutSeconds,
verbose);
private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequest ollamaRequestModel, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, auth, requestTimeoutSeconds, verbose);
OllamaResult result;
if (streamHandler != null) {
ollamaRequestModel.setStream(true);
@ -1434,8 +1359,7 @@ public class OllamaAPI {
* @return HttpRequest.Builder
*/
private HttpRequest.Builder getRequestBuilderDefault(URI uri) {
HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(uri).header("Content-Type", "application/json")
.timeout(Duration.ofSeconds(requestTimeoutSeconds));
HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(uri).header("Content-Type", "application/json").timeout(Duration.ofSeconds(requestTimeoutSeconds));
if (isBasicAuthCredentialsSet()) {
requestBuilder.header("Authorization", auth.getAuthHeaderValue());
}
@ -1460,8 +1384,7 @@ public class OllamaAPI {
logger.debug("Invoking function {} with arguments {}", methodName, arguments);
}
if (function == null) {
throw new ToolNotFoundException(
"No such tool: " + methodName + ". Please register the tool before invoking it.");
throw new ToolNotFoundException("No such tool: " + methodName + ". Please register the tool before invoking it.");
}
return function.apply(arguments);
} catch (Exception e) {

View File

@ -35,6 +35,8 @@ public class OllamaChatMessage {
@NonNull
private String content;
private String thinking;
private @JsonProperty("tool_calls") List<OllamaChatToolCalls> toolCalls;
@JsonSerialize(using = FileToBase64Serializer.class)

View File

@ -13,31 +13,35 @@ import lombok.Setter;
* Defines a Request to use against the ollama /api/chat endpoint.
*
* @see <a href=
* "https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion">Generate
* Chat Completion</a>
* "https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion">Generate
* Chat Completion</a>
*/
@Getter
@Setter
public class OllamaChatRequest extends OllamaCommonRequest implements OllamaRequestBody {
private List<OllamaChatMessage> messages;
private List<OllamaChatMessage> messages;
private List<Tools.PromptFuncDefinition> tools;
private List<Tools.PromptFuncDefinition> tools;
public OllamaChatRequest() {}
private boolean think;
public OllamaChatRequest(String model, List<OllamaChatMessage> messages) {
this.model = model;
this.messages = messages;
}
@Override
public boolean equals(Object o) {
if (!(o instanceof OllamaChatRequest)) {
return false;
public OllamaChatRequest() {
}
return this.toString().equals(o.toString());
}
public OllamaChatRequest(String model, boolean think, List<OllamaChatMessage> messages) {
this.model = model;
this.messages = messages;
this.think = think;
}
@Override
public boolean equals(Object o) {
if (!(o instanceof OllamaChatRequest)) {
return false;
}
return this.toString().equals(o.toString());
}
}

View File

@ -22,7 +22,7 @@ public class OllamaChatRequestBuilder {
private static final Logger LOG = LoggerFactory.getLogger(OllamaChatRequestBuilder.class);
private OllamaChatRequestBuilder(String model, List<OllamaChatMessage> messages) {
request = new OllamaChatRequest(model, messages);
request = new OllamaChatRequest(model, false, messages);
}
private OllamaChatRequest request;
@ -36,7 +36,7 @@ public class OllamaChatRequestBuilder {
}
public void reset() {
request = new OllamaChatRequest(request.getModel(), new ArrayList<>());
request = new OllamaChatRequest(request.getModel(), request.isThink(), new ArrayList<>());
}
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content) {
@ -45,7 +45,7 @@ public class OllamaChatRequestBuilder {
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List<OllamaChatToolCalls> toolCalls) {
List<OllamaChatMessage> messages = this.request.getMessages();
messages.add(new OllamaChatMessage(role, content, toolCalls, null));
messages.add(new OllamaChatMessage(role, content, null, toolCalls, null));
return this;
}
@ -61,7 +61,7 @@ public class OllamaChatRequestBuilder {
}
}).collect(Collectors.toList());
messages.add(new OllamaChatMessage(role, content, toolCalls, binaryImages));
messages.add(new OllamaChatMessage(role, content, null, toolCalls, binaryImages));
return this;
}
@ -81,7 +81,7 @@ public class OllamaChatRequestBuilder {
}
}
messages.add(new OllamaChatMessage(role, content, toolCalls, binaryImages));
messages.add(new OllamaChatMessage(role, content, null, toolCalls, binaryImages));
return this;
}
@ -114,4 +114,8 @@ public class OllamaChatRequestBuilder {
return this;
}
public OllamaChatRequestBuilder withThinking(boolean think) {
this.request.setThink(think);
return this;
}
}

View File

@ -11,9 +11,22 @@ public class OllamaChatStreamObserver implements OllamaTokenHandler {
@Override
public void accept(OllamaChatResponseModel token) {
if (streamHandler != null) {
message += token.getMessage().getContent();
streamHandler.accept(message);
if (streamHandler == null || token == null || token.getMessage() == null) {
return;
}
String content = token.getMessage().getContent();
String thinking = token.getMessage().getThinking();
boolean hasContent = !content.isEmpty();
boolean hasThinking = thinking != null && !thinking.isEmpty();
if (hasThinking && !hasContent) {
message += thinking;
} else {
message += content;
}
streamHandler.accept(message);
}
}

View File

@ -24,8 +24,8 @@ public class OllamaGenerateStreamObserver {
String response = currentResponsePart.getResponse();
String thinking = currentResponsePart.getThinking();
boolean hasResponse = response != null && !response.trim().isEmpty();
boolean hasThinking = thinking != null && !thinking.trim().isEmpty();
boolean hasResponse = response != null && !response.isEmpty();
boolean hasThinking = thinking != null && !thinking.isEmpty();
if (!hasResponse && hasThinking) {
message = message + thinking;

View File

@ -58,7 +58,12 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
// thus, we null check the message and hope that the next streamed response has some message content again
OllamaChatMessage message = ollamaResponseModel.getMessage();
if (message != null) {
responseBuffer.append(message.getContent());
if (message.getThinking() != null) {
thinkingBuffer.append(message.getThinking());
}
else {
responseBuffer.append(message.getContent());
}
if (tokenHandler != null) {
tokenHandler.accept(ollamaResponseModel);
}
@ -85,7 +90,7 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
.POST(
body.getBodyPublisher());
HttpRequest request = requestBuilder.build();
if (isVerbose()) LOG.info("Asking model: " + body);
if (isVerbose()) LOG.info("Asking model: {}", body);
HttpResponse<InputStream> response =
httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
@ -129,6 +134,7 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
}
if (finished && body.stream) {
ollamaChatResponseModel.getMessage().setContent(responseBuffer.toString());
ollamaChatResponseModel.getMessage().setThinking(thinkingBuffer.toString());
break;
}
}

View File

@ -125,7 +125,7 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
} else {
long endTime = System.currentTimeMillis();
OllamaResult ollamaResult =
new OllamaResult(responseBuffer.toString().trim(), thinkingBuffer.toString().trim(), endTime - startTime, statusCode);
new OllamaResult(responseBuffer.toString(), thinkingBuffer.toString(), endTime - startTime, statusCode);
if (isVerbose()) LOG.info("Model response: " + ollamaResult);
return ollamaResult;
}

View File

@ -53,6 +53,7 @@ public class OllamaAPIIntegrationTest {
private static final String CHAT_MODEL_LLAMA3 = "llama3";
private static final String IMAGE_MODEL_LLAVA = "llava";
private static final String THINKING_MODEL_GPT_OSS = "gpt-oss:20b";
private static final String THINKING_MODEL_QWEN = "qwen3:0.6b";
@BeforeAll
public static void setUp() {
@ -220,7 +221,7 @@ public class OllamaAPIIntegrationTest {
assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
assertEquals(sb.toString().trim(), result.getResponse().trim());
assertEquals(sb.toString(), result.getResponse());
}
@Test
@ -441,29 +442,51 @@ public class OllamaAPIIntegrationTest {
assertNotNull(chatResult.getResponseModel());
assertNotNull(chatResult.getResponseModel().getMessage());
assertNotNull(chatResult.getResponseModel().getMessage().getContent());
assertEquals(sb.toString().trim(), chatResult.getResponseModel().getMessage().getContent().trim());
assertEquals(sb.toString(), chatResult.getResponseModel().getMessage().getContent());
}
@Test
@Order(15)
void testChatWithStream() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException {
api.pullModel(CHAT_MODEL_QWEN_SMALL);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_QWEN_SMALL);
api.pullModel(THINKING_MODEL_QWEN);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(THINKING_MODEL_QWEN);
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What is the capital of France? And what's France's connection with Mona Lisa?").build();
StringBuffer sb = new StringBuffer();
OllamaChatResult chatResult = api.chat(requestModel, (s) -> {
LOG.info(s);
String substring = s.substring(sb.toString().length(), s.length());
LOG.info(substring);
String substring = s.substring(sb.toString().length());
sb.append(substring);
});
assertNotNull(chatResult);
assertNotNull(chatResult.getResponseModel());
assertNotNull(chatResult.getResponseModel().getMessage());
assertNotNull(chatResult.getResponseModel().getMessage().getContent());
assertEquals(sb.toString().trim(), chatResult.getResponseModel().getMessage().getContent().trim());
assertEquals(sb.toString(), chatResult.getResponseModel().getMessage().getContent());
}
@Test
@Order(15)
void testChatWithThinkingAndStream() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException {
api.pullModel(THINKING_MODEL_QWEN);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(THINKING_MODEL_QWEN);
OllamaChatRequest requestModel = builder
.withMessage(OllamaChatMessageRole.USER, "What is the capital of France? And what's France's connection with Mona Lisa?")
.withThinking(true)
.withKeepAlive("0m")
.build();
StringBuffer sb = new StringBuffer();
OllamaChatResult chatResult = api.chat(requestModel, (s) -> {
LOG.info(s);
String substring = s.substring(sb.toString().length());
sb.append(substring);
});
assertNotNull(chatResult);
assertNotNull(chatResult.getResponseModel());
assertNotNull(chatResult.getResponseModel().getMessage());
assertNotNull(chatResult.getResponseModel().getMessage().getContent());
assertEquals(sb.toString(), chatResult.getResponseModel().getMessage().getThinking() + chatResult.getResponseModel().getMessage().getContent());
}
@Test
@ -503,14 +526,14 @@ public class OllamaAPIIntegrationTest {
OllamaResult result = api.generateWithImageFiles(IMAGE_MODEL_LLAVA, "What is in this image?", List.of(imageFile), new OptionsBuilder().build(), (s) -> {
LOG.info(s);
String substring = s.substring(sb.toString().length(), s.length());
String substring = s.substring(sb.toString().length());
LOG.info(substring);
sb.append(substring);
});
assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
assertEquals(sb.toString().trim(), result.getResponse().trim());
assertEquals(sb.toString(), result.getResponse());
}
@Test
@ -532,13 +555,13 @@ public class OllamaAPIIntegrationTest {
@Test
@Order(20)
void testGenerateWithThinkingAndStreamHandler() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
api.pullModel(THINKING_MODEL_GPT_OSS);
api.pullModel(THINKING_MODEL_QWEN);
boolean raw = false;
boolean thinking = true;
StringBuffer sb = new StringBuffer();
OllamaResult result = api.generate(THINKING_MODEL_GPT_OSS, "Who are you?", raw, thinking, new OptionsBuilder().build(), (s) -> {
OllamaResult result = api.generate(THINKING_MODEL_QWEN, "Who are you?", raw, thinking, new OptionsBuilder().build(), (s) -> {
LOG.info(s);
String substring = s.substring(sb.toString().length());
sb.append(substring);
@ -548,7 +571,7 @@ public class OllamaAPIIntegrationTest {
assertFalse(result.getResponse().isEmpty());
assertNotNull(result.getThinking());
assertFalse(result.getThinking().isEmpty());
assertEquals(sb.toString().trim(), result.getThinking().trim() + result.getResponse().trim());
assertEquals(sb.toString(), result.getThinking() + result.getResponse());
}
private File getImageFileFromClasspath(String fileName) {