From 8d9ee006ee1c902dcd89de762e55b360a139bb0b Mon Sep 17 00:00:00 2001 From: amithkoujalgi Date: Thu, 28 Aug 2025 12:44:43 +0530 Subject: [PATCH] Refactor OllamaAPI and chat models to support 'thinking' responses - Introduced a 'thinking' field in OllamaChatMessage to capture intermediate reasoning. - Updated OllamaChatRequest to include a 'think' parameter for chat requests. - Modified OllamaChatRequestBuilder to facilitate setting the 'think' parameter. - Enhanced response handling in OllamaChatStreamObserver and OllamaGenerateStreamObserver to manage 'thinking' content. - Updated integration tests to validate the new 'thinking' functionality in chat and generation methods. --- .../java/io/github/ollama4j/OllamaAPI.java | 253 ++++++------------ .../models/chat/OllamaChatMessage.java | 2 + .../models/chat/OllamaChatRequest.java | 36 +-- .../models/chat/OllamaChatRequestBuilder.java | 14 +- .../models/chat/OllamaChatStreamObserver.java | 19 +- .../OllamaGenerateStreamObserver.java | 4 +- .../request/OllamaChatEndpointCaller.java | 10 +- .../request/OllamaGenerateEndpointCaller.java | 2 +- .../OllamaAPIIntegrationTest.java | 49 +++- 9 files changed, 182 insertions(+), 207 deletions(-) diff --git a/src/main/java/io/github/ollama4j/OllamaAPI.java b/src/main/java/io/github/ollama4j/OllamaAPI.java index be91603..0fcc2a0 100644 --- a/src/main/java/io/github/ollama4j/OllamaAPI.java +++ b/src/main/java/io/github/ollama4j/OllamaAPI.java @@ -137,8 +137,7 @@ public class OllamaAPI { HttpClient httpClient = HttpClient.newHttpClient(); HttpRequest httpRequest = null; try { - httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") - .header("Content-type", "application/json").GET().build(); + httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build(); } catch (URISyntaxException e) { throw new RuntimeException(e); } @@ -168,8 +167,7 @@ public class OllamaAPI { HttpClient httpClient = HttpClient.newHttpClient(); HttpRequest httpRequest = null; try { - httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") - .header("Content-type", "application/json").GET().build(); + httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build(); } catch (URISyntaxException e) { throw new RuntimeException(e); } @@ -196,8 +194,7 @@ public class OllamaAPI { public List listModels() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { String url = this.host + "/api/tags"; HttpClient httpClient = HttpClient.newHttpClient(); - HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") - .header("Content-type", "application/json").GET().build(); + HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build(); HttpResponse response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); String responseString = response.body(); @@ -225,12 +222,10 @@ public class OllamaAPI { * @throws URISyntaxException If there is an error creating the URI for the * HTTP request. 
*/ - public List listModelsFromLibrary() - throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { + public List listModelsFromLibrary() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { String url = "https://ollama.com/library"; HttpClient httpClient = HttpClient.newHttpClient(); - HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") - .header("Content-type", "application/json").GET().build(); + HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build(); HttpResponse response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); String responseString = response.body(); @@ -245,8 +240,7 @@ public class OllamaAPI { Elements pullCounts = e.select("div:nth-of-type(2) > p > span:first-of-type > span:first-of-type"); Elements popularTags = e.select("div > div > span"); Elements totalTags = e.select("div:nth-of-type(2) > p > span:nth-of-type(2) > span:first-of-type"); - Elements lastUpdatedTime = e - .select("div:nth-of-type(2) > p > span:nth-of-type(3) > span:nth-of-type(2)"); + Elements lastUpdatedTime = e.select("div:nth-of-type(2) > p > span:nth-of-type(3) > span:nth-of-type(2)"); if (names.first() == null || names.isEmpty()) { // if name cannot be extracted, skip. @@ -254,12 +248,9 @@ public class OllamaAPI { } Optional.ofNullable(names.first()).map(Element::text).ifPresent(model::setName); model.setDescription(Optional.ofNullable(desc.first()).map(Element::text).orElse("")); - model.setPopularTags(Optional.of(popularTags) - .map(tags -> tags.stream().map(Element::text).collect(Collectors.toList())) - .orElse(new ArrayList<>())); + model.setPopularTags(Optional.of(popularTags).map(tags -> tags.stream().map(Element::text).collect(Collectors.toList())).orElse(new ArrayList<>())); model.setPullCount(Optional.ofNullable(pullCounts.first()).map(Element::text).orElse("")); - model.setTotalTags( - Optional.ofNullable(totalTags.first()).map(Element::text).map(Integer::parseInt).orElse(0)); + model.setTotalTags(Optional.ofNullable(totalTags.first()).map(Element::text).map(Integer::parseInt).orElse(0)); model.setLastUpdated(Optional.ofNullable(lastUpdatedTime.first()).map(Element::text).orElse("")); models.add(model); @@ -292,12 +283,10 @@ public class OllamaAPI { * the HTTP response. * @throws URISyntaxException if the URI format is incorrect or invalid. 
*/ - public LibraryModelDetail getLibraryModelDetails(LibraryModel libraryModel) - throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { + public LibraryModelDetail getLibraryModelDetails(LibraryModel libraryModel) throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { String url = String.format("https://ollama.com/library/%s/tags", libraryModel.getName()); HttpClient httpClient = HttpClient.newHttpClient(); - HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") - .header("Content-type", "application/json").GET().build(); + HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build(); HttpResponse response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); String responseString = response.body(); @@ -305,8 +294,7 @@ public class OllamaAPI { List libraryModelTags = new ArrayList<>(); if (statusCode == 200) { Document doc = Jsoup.parse(responseString); - Elements tagSections = doc - .select("html > body > main > div > section > div > div > div:nth-child(n+2) > div"); + Elements tagSections = doc.select("html > body > main > div > section > div > div > div:nth-child(n+2) > div"); for (Element e : tagSections) { Elements tags = e.select("div > a > div"); Elements tagsMetas = e.select("div > span"); @@ -319,11 +307,8 @@ public class OllamaAPI { } libraryModelTag.setName(libraryModel.getName()); Optional.ofNullable(tags.first()).map(Element::text).ifPresent(libraryModelTag::setTag); - libraryModelTag.setSize(Optional.ofNullable(tagsMetas.first()).map(element -> element.text().split("•")) - .filter(parts -> parts.length > 1).map(parts -> parts[1].trim()).orElse("")); - libraryModelTag - .setLastUpdated(Optional.ofNullable(tagsMetas.first()).map(element -> element.text().split("•")) - .filter(parts -> parts.length > 1).map(parts -> parts[2].trim()).orElse("")); + libraryModelTag.setSize(Optional.ofNullable(tagsMetas.first()).map(element -> element.text().split("•")).filter(parts -> parts.length > 1).map(parts -> parts[1].trim()).orElse("")); + libraryModelTag.setLastUpdated(Optional.ofNullable(tagsMetas.first()).map(element -> element.text().split("•")).filter(parts -> parts.length > 1).map(parts -> parts[2].trim()).orElse("")); libraryModelTags.add(libraryModelTag); } LibraryModelDetail libraryModelDetail = new LibraryModelDetail(); @@ -356,17 +341,11 @@ public class OllamaAPI { * @throws InterruptedException If the operation is interrupted. * @throws NoSuchElementException If the model or the tag is not found. 
*/ - public LibraryModelTag findModelTagFromLibrary(String modelName, String tag) - throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + public LibraryModelTag findModelTagFromLibrary(String modelName, String tag) throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { List libraryModels = this.listModelsFromLibrary(); - LibraryModel libraryModel = libraryModels.stream().filter(model -> model.getName().equals(modelName)) - .findFirst().orElseThrow( - () -> new NoSuchElementException(String.format("Model by name '%s' not found", modelName))); + LibraryModel libraryModel = libraryModels.stream().filter(model -> model.getName().equals(modelName)).findFirst().orElseThrow(() -> new NoSuchElementException(String.format("Model by name '%s' not found", modelName))); LibraryModelDetail libraryModelDetail = this.getLibraryModelDetails(libraryModel); - LibraryModelTag libraryModelTag = libraryModelDetail.getTags().stream() - .filter(tagName -> tagName.getTag().equals(tag)).findFirst() - .orElseThrow(() -> new NoSuchElementException( - String.format("Tag '%s' for model '%s' not found", tag, modelName))); + LibraryModelTag libraryModelTag = libraryModelDetail.getTags().stream().filter(tagName -> tagName.getTag().equals(tag)).findFirst().orElseThrow(() -> new NoSuchElementException(String.format("Tag '%s' for model '%s' not found", tag, modelName))); return libraryModelTag; } @@ -380,8 +359,7 @@ public class OllamaAPI { * @throws InterruptedException if the operation is interrupted * @throws URISyntaxException if the URI for the request is malformed */ - public void pullModel(String modelName) - throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + public void pullModel(String modelName) throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { if (numberOfRetriesForModelPull == 0) { this.doPullModel(modelName); } else { @@ -395,28 +373,21 @@ public class OllamaAPI { numberOfRetries++; } } - throw new OllamaBaseException( - "Failed to pull model " + modelName + " after " + numberOfRetriesForModelPull + " retries"); + throw new OllamaBaseException("Failed to pull model " + modelName + " after " + numberOfRetriesForModelPull + " retries"); } } - private void doPullModel(String modelName) - throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + private void doPullModel(String modelName) throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { String url = this.host + "/api/pull"; String jsonData = new ModelRequest(modelName).toString(); - HttpRequest request = getRequestBuilderDefault(new URI(url)) - .POST(HttpRequest.BodyPublishers.ofString(jsonData)) - .header("Accept", "application/json") - .header("Content-type", "application/json") - .build(); + HttpRequest request = getRequestBuilderDefault(new URI(url)).POST(HttpRequest.BodyPublishers.ofString(jsonData)).header("Accept", "application/json").header("Content-type", "application/json").build(); HttpClient client = HttpClient.newHttpClient(); HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofInputStream()); int statusCode = response.statusCode(); InputStream responseBodyStream = response.body(); String responseString = ""; boolean success = false; // Flag to check the pull success. 
- try (BufferedReader reader = new BufferedReader( - new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { + try (BufferedReader reader = new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { String line; while ((line = reader.readLine()) != null) { ModelPullResponse modelPullResponse = Utils.getObjectMapper().readValue(line, ModelPullResponse.class); @@ -452,8 +423,7 @@ public class OllamaAPI { public String getVersion() throws URISyntaxException, IOException, InterruptedException, OllamaBaseException { String url = this.host + "/api/version"; HttpClient httpClient = HttpClient.newHttpClient(); - HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") - .header("Content-type", "application/json").GET().build(); + HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build(); HttpResponse response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); String responseString = response.body(); @@ -478,8 +448,7 @@ public class OllamaAPI { * @throws InterruptedException if the operation is interrupted * @throws URISyntaxException if the URI for the request is malformed */ - public void pullModel(LibraryModelTag libraryModelTag) - throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + public void pullModel(LibraryModelTag libraryModelTag) throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { String tagToPull = String.format("%s:%s", libraryModelTag.getName(), libraryModelTag.getTag()); pullModel(tagToPull); } @@ -494,12 +463,10 @@ public class OllamaAPI { * @throws InterruptedException if the operation is interrupted * @throws URISyntaxException if the URI for the request is malformed */ - public ModelDetail getModelDetails(String modelName) - throws IOException, OllamaBaseException, InterruptedException, URISyntaxException { + public ModelDetail getModelDetails(String modelName) throws IOException, OllamaBaseException, InterruptedException, URISyntaxException { String url = this.host + "/api/show"; String jsonData = new ModelRequest(modelName).toString(); - HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") - .header("Content-type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build(); + HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build(); HttpClient client = HttpClient.newHttpClient(); HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); @@ -525,13 +492,10 @@ public class OllamaAPI { * @throws URISyntaxException if the URI for the request is malformed */ @Deprecated - public void createModelWithFilePath(String modelName, String modelFilePath) - throws IOException, InterruptedException, OllamaBaseException, URISyntaxException { + public void createModelWithFilePath(String modelName, String modelFilePath) throws IOException, InterruptedException, OllamaBaseException, URISyntaxException { String url = this.host + "/api/create"; String jsonData = new CustomModelFilePathRequest(modelName, modelFilePath).toString(); - HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", 
"application/json") - .header("Content-Type", "application/json") - .POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build(); + HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-Type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build(); HttpClient client = HttpClient.newHttpClient(); HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); @@ -565,13 +529,10 @@ public class OllamaAPI { * @throws URISyntaxException if the URI for the request is malformed */ @Deprecated - public void createModelWithModelFileContents(String modelName, String modelFileContents) - throws IOException, InterruptedException, OllamaBaseException, URISyntaxException { + public void createModelWithModelFileContents(String modelName, String modelFileContents) throws IOException, InterruptedException, OllamaBaseException, URISyntaxException { String url = this.host + "/api/create"; String jsonData = new CustomModelFileContentsRequest(modelName, modelFileContents).toString(); - HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") - .header("Content-Type", "application/json") - .POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build(); + HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-Type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build(); HttpClient client = HttpClient.newHttpClient(); HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); @@ -598,13 +559,10 @@ public class OllamaAPI { * @throws InterruptedException if the operation is interrupted * @throws URISyntaxException if the URI for the request is malformed */ - public void createModel(CustomModelRequest customModelRequest) - throws IOException, InterruptedException, OllamaBaseException, URISyntaxException { + public void createModel(CustomModelRequest customModelRequest) throws IOException, InterruptedException, OllamaBaseException, URISyntaxException { String url = this.host + "/api/create"; String jsonData = customModelRequest.toString(); - HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") - .header("Content-Type", "application/json") - .POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build(); + HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-Type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build(); HttpClient client = HttpClient.newHttpClient(); HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); @@ -631,13 +589,10 @@ public class OllamaAPI { * @throws InterruptedException if the operation is interrupted * @throws URISyntaxException if the URI for the request is malformed */ - public void deleteModel(String modelName, boolean ignoreIfNotPresent) - throws IOException, InterruptedException, OllamaBaseException, URISyntaxException { + public void deleteModel(String modelName, boolean ignoreIfNotPresent) throws IOException, InterruptedException, OllamaBaseException, URISyntaxException { String url = this.host + "/api/delete"; String 
jsonData = new ModelRequest(modelName).toString(); - HttpRequest request = getRequestBuilderDefault(new URI(url)) - .method("DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)) - .header("Accept", "application/json").header("Content-type", "application/json").build(); + HttpRequest request = getRequestBuilderDefault(new URI(url)).method("DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).header("Accept", "application/json").header("Content-type", "application/json").build(); HttpClient client = HttpClient.newHttpClient(); HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); @@ -662,8 +617,7 @@ public class OllamaAPI { * @deprecated Use {@link #embed(String, List)} instead. */ @Deprecated - public List generateEmbeddings(String model, String prompt) - throws IOException, InterruptedException, OllamaBaseException { + public List generateEmbeddings(String model, String prompt) throws IOException, InterruptedException, OllamaBaseException { return generateEmbeddings(new OllamaEmbeddingsRequestModel(model, prompt)); } @@ -678,20 +632,17 @@ public class OllamaAPI { * @deprecated Use {@link #embed(OllamaEmbedRequestModel)} instead. */ @Deprecated - public List generateEmbeddings(OllamaEmbeddingsRequestModel modelRequest) - throws IOException, InterruptedException, OllamaBaseException { + public List generateEmbeddings(OllamaEmbeddingsRequestModel modelRequest) throws IOException, InterruptedException, OllamaBaseException { URI uri = URI.create(this.host + "/api/embeddings"); String jsonData = modelRequest.toString(); HttpClient httpClient = HttpClient.newHttpClient(); - HttpRequest.Builder requestBuilder = getRequestBuilderDefault(uri).header("Accept", "application/json") - .POST(HttpRequest.BodyPublishers.ofString(jsonData)); + HttpRequest.Builder requestBuilder = getRequestBuilderDefault(uri).header("Accept", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)); HttpRequest request = requestBuilder.build(); HttpResponse response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); String responseBody = response.body(); if (statusCode == 200) { - OllamaEmbeddingResponseModel embeddingResponse = Utils.getObjectMapper().readValue(responseBody, - OllamaEmbeddingResponseModel.class); + OllamaEmbeddingResponseModel embeddingResponse = Utils.getObjectMapper().readValue(responseBody, OllamaEmbeddingResponseModel.class); return embeddingResponse.getEmbedding(); } else { throw new OllamaBaseException(statusCode + " - " + responseBody); @@ -708,8 +659,7 @@ public class OllamaAPI { * @throws IOException if an I/O error occurs during the HTTP request * @throws InterruptedException if the operation is interrupted */ - public OllamaEmbedResponseModel embed(String model, List inputs) - throws IOException, InterruptedException, OllamaBaseException { + public OllamaEmbedResponseModel embed(String model, List inputs) throws IOException, InterruptedException, OllamaBaseException { return embed(new OllamaEmbedRequestModel(model, inputs)); } @@ -722,14 +672,12 @@ public class OllamaAPI { * @throws IOException if an I/O error occurs during the HTTP request * @throws InterruptedException if the operation is interrupted */ - public OllamaEmbedResponseModel embed(OllamaEmbedRequestModel modelRequest) - throws IOException, InterruptedException, OllamaBaseException { + public OllamaEmbedResponseModel 
embed(OllamaEmbedRequestModel modelRequest) throws IOException, InterruptedException, OllamaBaseException { URI uri = URI.create(this.host + "/api/embed"); String jsonData = Utils.getObjectMapper().writeValueAsString(modelRequest); HttpClient httpClient = HttpClient.newHttpClient(); - HttpRequest request = HttpRequest.newBuilder(uri).header("Accept", "application/json") - .POST(HttpRequest.BodyPublishers.ofString(jsonData)).build(); + HttpRequest request = HttpRequest.newBuilder(uri).header("Accept", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build(); HttpResponse response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); @@ -763,8 +711,7 @@ public class OllamaAPI { * @throws IOException if an I/O error occurs during the HTTP request * @throws InterruptedException if the operation is interrupted */ - public OllamaResult generate(String model, String prompt, boolean raw, boolean think, Options options, - OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { + public OllamaResult generate(String model, String prompt, boolean raw, boolean think, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt); ollamaRequestModel.setRaw(raw); ollamaRequestModel.setThink(think); @@ -794,13 +741,14 @@ public class OllamaAPI { * @throws IOException if an I/O error occurs during the HTTP request * @throws InterruptedException if the operation is interrupted */ - public OllamaResult generate(String model, String prompt, boolean raw, boolean think, Options options) - throws OllamaBaseException, IOException, InterruptedException { + public OllamaResult generate(String model, String prompt, boolean raw, boolean think, Options options) throws OllamaBaseException, IOException, InterruptedException { return generate(model, prompt, raw, think, options, null); } /** * Generates structured output from the specified AI model and prompt. + *

+ * Note: When formatting is specified, the 'think' parameter is not allowed. * * @param model The name or identifier of the AI model to use for generating * the response. @@ -813,8 +761,8 @@ public class OllamaAPI { * @throws IOException if an I/O error occurs during the HTTP request. * @throws InterruptedException if the operation is interrupted. */ - public OllamaResult generate(String model, String prompt, Map format) - throws OllamaBaseException, IOException, InterruptedException { + @SuppressWarnings("LoggingSimilarMessage") + public OllamaResult generate(String model, String prompt, Map format) throws OllamaBaseException, IOException, InterruptedException { URI uri = URI.create(this.host + "/api/generate"); Map requestBody = new HashMap<>(); @@ -826,23 +774,30 @@ public class OllamaAPI { String jsonData = Utils.getObjectMapper().writeValueAsString(requestBody); HttpClient httpClient = HttpClient.newHttpClient(); - HttpRequest request = getRequestBuilderDefault(uri) - .header("Accept", "application/json") - .header("Content-type", "application/json") - .POST(HttpRequest.BodyPublishers.ofString(jsonData)) - .build(); + HttpRequest request = getRequestBuilderDefault(uri).header("Accept", "application/json").header("Content-type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build(); + if (verbose) { + try { + String prettyJson = Utils.getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(Utils.getObjectMapper().readValue(jsonData, Object.class)); + logger.info("Asking model:\n{}", prettyJson); + } catch (Exception e) { + logger.info("Asking model: {}", jsonData); + } + } HttpResponse response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); String responseBody = response.body(); - if (statusCode == 200) { - OllamaStructuredResult structuredResult = Utils.getObjectMapper().readValue(responseBody, - OllamaStructuredResult.class); - OllamaResult ollamaResult = new OllamaResult(structuredResult.getResponse(), structuredResult.getThinking(), - structuredResult.getResponseTime(), statusCode); + OllamaStructuredResult structuredResult = Utils.getObjectMapper().readValue(responseBody, OllamaStructuredResult.class); + OllamaResult ollamaResult = new OllamaResult(structuredResult.getResponse(), structuredResult.getThinking(), structuredResult.getResponseTime(), statusCode); + if (verbose) { + logger.info("Model response:\n{}", ollamaResult); + } return ollamaResult; } else { + if (verbose) { + logger.info("Model response:\n{}", Utils.getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(responseBody)); + } throw new OllamaBaseException(statusCode + " - " + responseBody); } } @@ -866,8 +821,7 @@ public class OllamaAPI { * @throws IOException if an I/O error occurs during the HTTP request * @throws InterruptedException if the operation is interrupted */ - public OllamaToolsResult generateWithTools(String model, String prompt, boolean think, Options options) - throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException { + public OllamaToolsResult generateWithTools(String model, String prompt, boolean think, Options options) throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException { boolean raw = true; OllamaToolsResult toolResult = new OllamaToolsResult(); Map toolResults = new HashMap<>(); @@ -900,9 +854,7 @@ public class OllamaAPI { logger.warn("Response from model does not contain any tool calls. 
Returning the response as is."); return toolResult; } - toolFunctionCallSpecs = objectMapper.readValue( - toolsResponse, - objectMapper.getTypeFactory().constructCollectionType(List.class, ToolFunctionCallSpec.class)); + toolFunctionCallSpecs = objectMapper.readValue(toolsResponse, objectMapper.getTypeFactory().constructCollectionType(List.class, ToolFunctionCallSpec.class)); } for (ToolFunctionCallSpec toolFunctionCallSpec : toolFunctionCallSpecs) { toolResults.put(toolFunctionCallSpec, invokeTool(toolFunctionCallSpec)); @@ -926,8 +878,7 @@ public class OllamaAPI { OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt); ollamaRequestModel.setRaw(raw); URI uri = URI.create(this.host + "/api/generate"); - OllamaAsyncResultStreamer ollamaAsyncResultStreamer = new OllamaAsyncResultStreamer( - getRequestBuilderDefault(uri), ollamaRequestModel, requestTimeoutSeconds); + OllamaAsyncResultStreamer ollamaAsyncResultStreamer = new OllamaAsyncResultStreamer(getRequestBuilderDefault(uri), ollamaRequestModel, requestTimeoutSeconds); ollamaAsyncResultStreamer.start(); return ollamaAsyncResultStreamer; } @@ -952,8 +903,7 @@ public class OllamaAPI { * @throws IOException if an I/O error occurs during the HTTP request * @throws InterruptedException if the operation is interrupted */ - public OllamaResult generateWithImageFiles(String model, String prompt, List imageFiles, Options options, - OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { + public OllamaResult generateWithImageFiles(String model, String prompt, List imageFiles, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { List images = new ArrayList<>(); for (File imageFile : imageFiles) { images.add(encodeFileToBase64(imageFile)); @@ -973,8 +923,7 @@ public class OllamaAPI { * @throws IOException if an I/O error occurs during the HTTP request * @throws InterruptedException if the operation is interrupted */ - public OllamaResult generateWithImageFiles(String model, String prompt, List imageFiles, Options options) - throws OllamaBaseException, IOException, InterruptedException { + public OllamaResult generateWithImageFiles(String model, String prompt, List imageFiles, Options options) throws OllamaBaseException, IOException, InterruptedException { return generateWithImageFiles(model, prompt, imageFiles, options, null); } @@ -999,9 +948,7 @@ public class OllamaAPI { * @throws InterruptedException if the operation is interrupted * @throws URISyntaxException if the URI for the request is malformed */ - public OllamaResult generateWithImageURLs(String model, String prompt, List imageURLs, Options options, - OllamaStreamHandler streamHandler) - throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { + public OllamaResult generateWithImageURLs(String model, String prompt, List imageURLs, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { List images = new ArrayList<>(); for (String imageURL : imageURLs) { images.add(encodeByteArrayToBase64(Utils.loadImageBytesFromUrl(imageURL))); @@ -1022,8 +969,7 @@ public class OllamaAPI { * @throws InterruptedException if the operation is interrupted * @throws URISyntaxException if the URI for the request is malformed */ - public OllamaResult generateWithImageURLs(String model, String prompt, List imageURLs, Options options) - throws OllamaBaseException, IOException, 
InterruptedException, URISyntaxException { + public OllamaResult generateWithImageURLs(String model, String prompt, List imageURLs, Options options) throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { return generateWithImageURLs(model, prompt, imageURLs, options, null); } @@ -1047,8 +993,7 @@ public class OllamaAPI { * @throws IOException if an I/O error occurs during the HTTP request * @throws InterruptedException if the operation is interrupted */ - public OllamaResult generateWithImages(String model, String prompt, List images, Options options, - OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { + public OllamaResult generateWithImages(String model, String prompt, List images, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { List encodedImages = new ArrayList<>(); for (byte[] image : images) { encodedImages.add(encodeByteArrayToBase64(image)); @@ -1069,8 +1014,7 @@ public class OllamaAPI { * @throws IOException if an I/O error occurs during the HTTP request * @throws InterruptedException if the operation is interrupted */ - public OllamaResult generateWithImages(String model, String prompt, List images, Options options) - throws OllamaBaseException, IOException, InterruptedException { + public OllamaResult generateWithImages(String model, String prompt, List images, Options options) throws OllamaBaseException, IOException, InterruptedException { return generateWithImages(model, prompt, images, options, null); } @@ -1094,8 +1038,7 @@ public class OllamaAPI { * @throws InterruptedException if the operation is interrupted * @throws ToolInvocationException if the tool invocation fails */ - public OllamaChatResult chat(String model, List messages) - throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException { + public OllamaChatResult chat(String model, List messages) throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException { OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(model); return chat(builder.withMessages(messages).build()); } @@ -1119,8 +1062,7 @@ public class OllamaAPI { * @throws InterruptedException if the operation is interrupted * @throws ToolInvocationException if the tool invocation fails */ - public OllamaChatResult chat(OllamaChatRequest request) - throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException { + public OllamaChatResult chat(OllamaChatRequest request) throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException { return chat(request, null); } @@ -1146,8 +1088,7 @@ public class OllamaAPI { * @throws InterruptedException if the operation is interrupted * @throws ToolInvocationException if the tool invocation fails */ - public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler) - throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException { + public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException { return chatStreaming(request, new OllamaChatStreamObserver(streamHandler)); } @@ -1170,15 +1111,12 @@ public class OllamaAPI { * @throws IOException if an I/O error occurs during the HTTP request * @throws InterruptedException if the operation is interrupted */ - public OllamaChatResult 
chatStreaming(OllamaChatRequest request, OllamaTokenHandler tokenHandler) - throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException { - OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, auth, requestTimeoutSeconds, - verbose); + public OllamaChatResult chatStreaming(OllamaChatRequest request, OllamaTokenHandler tokenHandler) throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException { + OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, auth, requestTimeoutSeconds, verbose); OllamaChatResult result; // add all registered tools to Request - request.setTools(toolRegistry.getRegisteredSpecs().stream().map(Tools.ToolSpecification::getToolPrompt) - .collect(Collectors.toList())); + request.setTools(toolRegistry.getRegisteredSpecs().stream().map(Tools.ToolSpecification::getToolPrompt).collect(Collectors.toList())); if (tokenHandler != null) { request.setStream(true); @@ -1199,8 +1137,7 @@ public class OllamaAPI { } Map arguments = toolCall.getFunction().getArguments(); Object res = toolFunction.apply(arguments); - request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL, - "[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() + ") : " + res + "[/TOOL_RESULTS]")); + request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL, "[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() + ") : " + res + "[/TOOL_RESULTS]")); } if (tokenHandler != null) { @@ -1276,8 +1213,8 @@ public class OllamaAPI { for (Class provider : providers) { registerAnnotatedTools(provider.getDeclaredConstructor().newInstance()); } - } catch (InstantiationException | NoSuchMethodException | IllegalAccessException - | InvocationTargetException e) { + } catch (InstantiationException | NoSuchMethodException | IllegalAccessException | + InvocationTargetException e) { throw new RuntimeException(e); } } @@ -1317,22 +1254,12 @@ public class OllamaAPI { } String propName = !toolPropertyAnn.name().isBlank() ? 
toolPropertyAnn.name() : parameter.getName(); methodParams.put(propName, propType); - propsBuilder.withProperty(propName, Tools.PromptFuncDefinition.Property.builder().type(propType) - .description(toolPropertyAnn.desc()).required(toolPropertyAnn.required()).build()); + propsBuilder.withProperty(propName, Tools.PromptFuncDefinition.Property.builder().type(propType).description(toolPropertyAnn.desc()).required(toolPropertyAnn.required()).build()); } final Map params = propsBuilder.build(); - List reqProps = params.entrySet().stream().filter(e -> e.getValue().isRequired()) - .map(Map.Entry::getKey).collect(Collectors.toList()); + List reqProps = params.entrySet().stream().filter(e -> e.getValue().isRequired()).map(Map.Entry::getKey).collect(Collectors.toList()); - Tools.ToolSpecification toolSpecification = Tools.ToolSpecification.builder().functionName(operationName) - .functionDescription(operationDesc) - .toolPrompt(Tools.PromptFuncDefinition.builder().type("function") - .function(Tools.PromptFuncDefinition.PromptFuncSpec.builder().name(operationName) - .description(operationDesc).parameters(Tools.PromptFuncDefinition.Parameters - .builder().type("object").properties(params).required(reqProps).build()) - .build()) - .build()) - .build(); + Tools.ToolSpecification toolSpecification = Tools.ToolSpecification.builder().functionName(operationName).functionDescription(operationDesc).toolPrompt(Tools.PromptFuncDefinition.builder().type("function").function(Tools.PromptFuncDefinition.PromptFuncSpec.builder().name(operationName).description(operationDesc).parameters(Tools.PromptFuncDefinition.Parameters.builder().type("object").properties(params).required(reqProps).build()).build()).build()).build(); ReflectionalToolFunction reflectionalToolFunction = new ReflectionalToolFunction(object, m, methodParams); toolSpecification.setToolFunction(reflectionalToolFunction); @@ -1413,10 +1340,8 @@ public class OllamaAPI { * process. * @throws InterruptedException if the thread is interrupted during the request. 
*/ - private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequest ollamaRequestModel, - OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { - OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, auth, requestTimeoutSeconds, - verbose); + private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequest ollamaRequestModel, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { + OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, auth, requestTimeoutSeconds, verbose); OllamaResult result; if (streamHandler != null) { ollamaRequestModel.setStream(true); @@ -1434,8 +1359,7 @@ public class OllamaAPI { * @return HttpRequest.Builder */ private HttpRequest.Builder getRequestBuilderDefault(URI uri) { - HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(uri).header("Content-Type", "application/json") - .timeout(Duration.ofSeconds(requestTimeoutSeconds)); + HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(uri).header("Content-Type", "application/json").timeout(Duration.ofSeconds(requestTimeoutSeconds)); if (isBasicAuthCredentialsSet()) { requestBuilder.header("Authorization", auth.getAuthHeaderValue()); } @@ -1460,8 +1384,7 @@ public class OllamaAPI { logger.debug("Invoking function {} with arguments {}", methodName, arguments); } if (function == null) { - throw new ToolNotFoundException( - "No such tool: " + methodName + ". Please register the tool before invoking it."); + throw new ToolNotFoundException("No such tool: " + methodName + ". Please register the tool before invoking it."); } return function.apply(arguments); } catch (Exception e) { diff --git a/src/main/java/io/github/ollama4j/models/chat/OllamaChatMessage.java b/src/main/java/io/github/ollama4j/models/chat/OllamaChatMessage.java index 86b7726..d8e72de 100644 --- a/src/main/java/io/github/ollama4j/models/chat/OllamaChatMessage.java +++ b/src/main/java/io/github/ollama4j/models/chat/OllamaChatMessage.java @@ -35,6 +35,8 @@ public class OllamaChatMessage { @NonNull private String content; + private String thinking; + private @JsonProperty("tool_calls") List toolCalls; @JsonSerialize(using = FileToBase64Serializer.class) diff --git a/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequest.java b/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequest.java index 5d19703..cf3c0ab 100644 --- a/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequest.java +++ b/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequest.java @@ -13,31 +13,35 @@ import lombok.Setter; * Defines a Request to use against the ollama /api/chat endpoint. 
* * @see Generate - * Chat Completion + * "https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion">Generate + * Chat Completion */ @Getter @Setter public class OllamaChatRequest extends OllamaCommonRequest implements OllamaRequestBody { - private List messages; + private List messages; - private List tools; + private List tools; - public OllamaChatRequest() {} + private boolean think; - public OllamaChatRequest(String model, List messages) { - this.model = model; - this.messages = messages; - } - - @Override - public boolean equals(Object o) { - if (!(o instanceof OllamaChatRequest)) { - return false; + public OllamaChatRequest() { } - return this.toString().equals(o.toString()); - } + public OllamaChatRequest(String model, boolean think, List messages) { + this.model = model; + this.messages = messages; + this.think = think; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof OllamaChatRequest)) { + return false; + } + + return this.toString().equals(o.toString()); + } } diff --git a/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequestBuilder.java b/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequestBuilder.java index 47d6eb5..4a9caf9 100644 --- a/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequestBuilder.java +++ b/src/main/java/io/github/ollama4j/models/chat/OllamaChatRequestBuilder.java @@ -22,7 +22,7 @@ public class OllamaChatRequestBuilder { private static final Logger LOG = LoggerFactory.getLogger(OllamaChatRequestBuilder.class); private OllamaChatRequestBuilder(String model, List messages) { - request = new OllamaChatRequest(model, messages); + request = new OllamaChatRequest(model, false, messages); } private OllamaChatRequest request; @@ -36,7 +36,7 @@ public class OllamaChatRequestBuilder { } public void reset() { - request = new OllamaChatRequest(request.getModel(), new ArrayList<>()); + request = new OllamaChatRequest(request.getModel(), request.isThink(), new ArrayList<>()); } public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content) { @@ -45,7 +45,7 @@ public class OllamaChatRequestBuilder { public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List toolCalls) { List messages = this.request.getMessages(); - messages.add(new OllamaChatMessage(role, content, toolCalls, null)); + messages.add(new OllamaChatMessage(role, content, null, toolCalls, null)); return this; } @@ -61,7 +61,7 @@ public class OllamaChatRequestBuilder { } }).collect(Collectors.toList()); - messages.add(new OllamaChatMessage(role, content, toolCalls, binaryImages)); + messages.add(new OllamaChatMessage(role, content, null, toolCalls, binaryImages)); return this; } @@ -81,7 +81,7 @@ public class OllamaChatRequestBuilder { } } - messages.add(new OllamaChatMessage(role, content, toolCalls, binaryImages)); + messages.add(new OllamaChatMessage(role, content, null, toolCalls, binaryImages)); return this; } @@ -114,4 +114,8 @@ public class OllamaChatRequestBuilder { return this; } + public OllamaChatRequestBuilder withThinking(boolean think) { + this.request.setThink(think); + return this; + } } diff --git a/src/main/java/io/github/ollama4j/models/chat/OllamaChatStreamObserver.java b/src/main/java/io/github/ollama4j/models/chat/OllamaChatStreamObserver.java index af181da..52291b9 100644 --- a/src/main/java/io/github/ollama4j/models/chat/OllamaChatStreamObserver.java +++ b/src/main/java/io/github/ollama4j/models/chat/OllamaChatStreamObserver.java @@ -11,9 +11,22 @@ public 
class OllamaChatStreamObserver implements OllamaTokenHandler { @Override public void accept(OllamaChatResponseModel token) { - if (streamHandler != null) { - message += token.getMessage().getContent(); - streamHandler.accept(message); + if (streamHandler == null || token == null || token.getMessage() == null) { + return; } + + String content = token.getMessage().getContent(); + String thinking = token.getMessage().getThinking(); + + boolean hasContent = !content.isEmpty(); + boolean hasThinking = thinking != null && !thinking.isEmpty(); + + if (hasThinking && !hasContent) { + message += thinking; + } else { + message += content; + } + + streamHandler.accept(message); } } diff --git a/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateStreamObserver.java b/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateStreamObserver.java index a13a0a0..a449894 100644 --- a/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateStreamObserver.java +++ b/src/main/java/io/github/ollama4j/models/generate/OllamaGenerateStreamObserver.java @@ -24,8 +24,8 @@ public class OllamaGenerateStreamObserver { String response = currentResponsePart.getResponse(); String thinking = currentResponsePart.getThinking(); - boolean hasResponse = response != null && !response.trim().isEmpty(); - boolean hasThinking = thinking != null && !thinking.trim().isEmpty(); + boolean hasResponse = response != null && !response.isEmpty(); + boolean hasThinking = thinking != null && !thinking.isEmpty(); if (!hasResponse && hasThinking) { message = message + thinking; diff --git a/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java b/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java index 94db829..65db860 100644 --- a/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java +++ b/src/main/java/io/github/ollama4j/models/request/OllamaChatEndpointCaller.java @@ -58,7 +58,12 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller { // thus, we null check the message and hope that the next streamed response has some message content again OllamaChatMessage message = ollamaResponseModel.getMessage(); if (message != null) { - responseBuffer.append(message.getContent()); + if (message.getThinking() != null) { + thinkingBuffer.append(message.getThinking()); + } + else { + responseBuffer.append(message.getContent()); + } if (tokenHandler != null) { tokenHandler.accept(ollamaResponseModel); } @@ -85,7 +90,7 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller { .POST( body.getBodyPublisher()); HttpRequest request = requestBuilder.build(); - if (isVerbose()) LOG.info("Asking model: " + body); + if (isVerbose()) LOG.info("Asking model: {}", body); HttpResponse response = httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream()); @@ -129,6 +134,7 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller { } if (finished && body.stream) { ollamaChatResponseModel.getMessage().setContent(responseBuffer.toString()); + ollamaChatResponseModel.getMessage().setThinking(thinkingBuffer.toString()); break; } } diff --git a/src/main/java/io/github/ollama4j/models/request/OllamaGenerateEndpointCaller.java b/src/main/java/io/github/ollama4j/models/request/OllamaGenerateEndpointCaller.java index 5e7c1f4..55d6fdf 100644 --- a/src/main/java/io/github/ollama4j/models/request/OllamaGenerateEndpointCaller.java +++ b/src/main/java/io/github/ollama4j/models/request/OllamaGenerateEndpointCaller.java @@ -125,7 +125,7 
@@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller { } else { long endTime = System.currentTimeMillis(); OllamaResult ollamaResult = - new OllamaResult(responseBuffer.toString().trim(), thinkingBuffer.toString().trim(), endTime - startTime, statusCode); + new OllamaResult(responseBuffer.toString(), thinkingBuffer.toString(), endTime - startTime, statusCode); if (isVerbose()) LOG.info("Model response: " + ollamaResult); return ollamaResult; } diff --git a/src/test/java/io/github/ollama4j/integrationtests/OllamaAPIIntegrationTest.java b/src/test/java/io/github/ollama4j/integrationtests/OllamaAPIIntegrationTest.java index f81b45b..6186099 100644 --- a/src/test/java/io/github/ollama4j/integrationtests/OllamaAPIIntegrationTest.java +++ b/src/test/java/io/github/ollama4j/integrationtests/OllamaAPIIntegrationTest.java @@ -53,6 +53,7 @@ public class OllamaAPIIntegrationTest { private static final String CHAT_MODEL_LLAMA3 = "llama3"; private static final String IMAGE_MODEL_LLAVA = "llava"; private static final String THINKING_MODEL_GPT_OSS = "gpt-oss:20b"; + private static final String THINKING_MODEL_QWEN = "qwen3:0.6b"; @BeforeAll public static void setUp() { @@ -220,7 +221,7 @@ public class OllamaAPIIntegrationTest { assertNotNull(result); assertNotNull(result.getResponse()); assertFalse(result.getResponse().isEmpty()); - assertEquals(sb.toString().trim(), result.getResponse().trim()); + assertEquals(sb.toString(), result.getResponse()); } @Test @@ -441,29 +442,51 @@ public class OllamaAPIIntegrationTest { assertNotNull(chatResult.getResponseModel()); assertNotNull(chatResult.getResponseModel().getMessage()); assertNotNull(chatResult.getResponseModel().getMessage().getContent()); - assertEquals(sb.toString().trim(), chatResult.getResponseModel().getMessage().getContent().trim()); + assertEquals(sb.toString(), chatResult.getResponseModel().getMessage().getContent()); } @Test @Order(15) void testChatWithStream() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException { - api.pullModel(CHAT_MODEL_QWEN_SMALL); - OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_QWEN_SMALL); + api.pullModel(THINKING_MODEL_QWEN); + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(THINKING_MODEL_QWEN); OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What is the capital of France? 
And what's France's connection with Mona Lisa?").build(); - StringBuffer sb = new StringBuffer(); OllamaChatResult chatResult = api.chat(requestModel, (s) -> { LOG.info(s); - String substring = s.substring(sb.toString().length(), s.length()); - LOG.info(substring); + String substring = s.substring(sb.toString().length()); sb.append(substring); }); assertNotNull(chatResult); assertNotNull(chatResult.getResponseModel()); assertNotNull(chatResult.getResponseModel().getMessage()); assertNotNull(chatResult.getResponseModel().getMessage().getContent()); - assertEquals(sb.toString().trim(), chatResult.getResponseModel().getMessage().getContent().trim()); + assertEquals(sb.toString(), chatResult.getResponseModel().getMessage().getContent()); + } + + @Test + @Order(15) + void testChatWithThinkingAndStream() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException { + api.pullModel(THINKING_MODEL_QWEN); + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(THINKING_MODEL_QWEN); + OllamaChatRequest requestModel = builder + .withMessage(OllamaChatMessageRole.USER, "What is the capital of France? And what's France's connection with Mona Lisa?") + .withThinking(true) + .withKeepAlive("0m") + .build(); + StringBuffer sb = new StringBuffer(); + + OllamaChatResult chatResult = api.chat(requestModel, (s) -> { + LOG.info(s); + String substring = s.substring(sb.toString().length()); + sb.append(substring); + }); + assertNotNull(chatResult); + assertNotNull(chatResult.getResponseModel()); + assertNotNull(chatResult.getResponseModel().getMessage()); + assertNotNull(chatResult.getResponseModel().getMessage().getContent()); + assertEquals(sb.toString(), chatResult.getResponseModel().getMessage().getThinking() + chatResult.getResponseModel().getMessage().getContent()); } @Test @@ -503,14 +526,14 @@ public class OllamaAPIIntegrationTest { OllamaResult result = api.generateWithImageFiles(IMAGE_MODEL_LLAVA, "What is in this image?", List.of(imageFile), new OptionsBuilder().build(), (s) -> { LOG.info(s); - String substring = s.substring(sb.toString().length(), s.length()); + String substring = s.substring(sb.toString().length()); LOG.info(substring); sb.append(substring); }); assertNotNull(result); assertNotNull(result.getResponse()); assertFalse(result.getResponse().isEmpty()); - assertEquals(sb.toString().trim(), result.getResponse().trim()); + assertEquals(sb.toString(), result.getResponse()); } @Test @@ -532,13 +555,13 @@ public class OllamaAPIIntegrationTest { @Test @Order(20) void testGenerateWithThinkingAndStreamHandler() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { - api.pullModel(THINKING_MODEL_GPT_OSS); + api.pullModel(THINKING_MODEL_QWEN); boolean raw = false; boolean thinking = true; StringBuffer sb = new StringBuffer(); - OllamaResult result = api.generate(THINKING_MODEL_GPT_OSS, "Who are you?", raw, thinking, new OptionsBuilder().build(), (s) -> { + OllamaResult result = api.generate(THINKING_MODEL_QWEN, "Who are you?", raw, thinking, new OptionsBuilder().build(), (s) -> { LOG.info(s); String substring = s.substring(sb.toString().length()); sb.append(substring); @@ -548,7 +571,7 @@ public class OllamaAPIIntegrationTest { assertFalse(result.getResponse().isEmpty()); assertNotNull(result.getThinking()); assertFalse(result.getThinking().isEmpty()); - assertEquals(sb.toString().trim(), result.getThinking().trim() + result.getResponse().trim()); + assertEquals(sb.toString(), result.getThinking() + 
result.getResponse()); } private File getImageFileFromClasspath(String fileName) {
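
Usage sketch (illustrative, not part of the diff above): a minimal example of how a caller might exercise the new chat-side 'thinking' support, assuming a local Ollama server at http://localhost:11434 and the qwen3:0.6b model used by the integration tests above. The withThinking(true) builder call and the getThinking() accessor on the chat message are the pieces this patch introduces; the ThinkingChatSketch class name and the import path of OllamaChatResult are assumptions made for the example.

import io.github.ollama4j.OllamaAPI;
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
import io.github.ollama4j.models.chat.OllamaChatRequest;
import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
import io.github.ollama4j.models.chat.OllamaChatResult;

public class ThinkingChatSketch {
    public static void main(String[] args) throws Exception {
        // Assumes an Ollama server on localhost:11434 with qwen3:0.6b already pulled.
        OllamaAPI api = new OllamaAPI("http://localhost:11434");

        OllamaChatRequest request = OllamaChatRequestBuilder.getInstance("qwen3:0.6b")
                .withMessage(OllamaChatMessageRole.USER, "What is the capital of France?")
                .withThinking(true) // new builder method added by this patch
                .build();

        // Streaming variant: per OllamaChatStreamObserver, the handler receives the
        // accumulated thinking text first and the accumulated answer afterwards.
        OllamaChatResult result = api.chat(request, System.out::println);

        // The new 'thinking' field is exposed on the message alongside the content.
        System.out.println("Thinking: " + result.getResponseModel().getMessage().getThinking());
        System.out.println("Answer  : " + result.getResponseModel().getMessage().getContent());
    }
}

Leaving think unset keeps the previous behaviour, since the builder now constructs requests with the flag defaulted to false.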
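
A similar sketch for the generate path, again assuming a local server and the same model: the boolean 'think' flag on generate(...) together with getThinking() on OllamaResult mirrors what testGenerateWithThinkingAndStreamHandler validates. The import paths for OllamaResult and OptionsBuilder are assumed from the project layout.

import io.github.ollama4j.OllamaAPI;
import io.github.ollama4j.models.response.OllamaResult;
import io.github.ollama4j.utils.OptionsBuilder;

public class ThinkingGenerateSketch {
    public static void main(String[] args) throws Exception {
        OllamaAPI api = new OllamaAPI("http://localhost:11434"); // assumed local endpoint

        boolean raw = false;
        boolean think = true; // ask the model to emit intermediate reasoning

        // With a stream handler the request is sent with stream=true; thinking text is
        // accumulated before the response text, and the final OllamaResult keeps the
        // two apart via getThinking() and getResponse().
        OllamaResult result = api.generate("qwen3:0.6b", "Who are you?", raw, think,
                new OptionsBuilder().build(), System.out::println);

        System.out.println("Thinking: " + result.getThinking());
        System.out.println("Response: " + result.getResponse());
    }
}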