diff --git a/src/main/java/io/github/ollama4j/OllamaAPI.java b/src/main/java/io/github/ollama4j/OllamaAPI.java index bfe07b5..82f1a5a 100644 --- a/src/main/java/io/github/ollama4j/OllamaAPI.java +++ b/src/main/java/io/github/ollama4j/OllamaAPI.java @@ -64,7 +64,8 @@ public class OllamaAPI { /** * The request timeout in seconds for API calls. *

- * Default is 10 seconds. This value determines how long the client will wait for a response + * Default is 10 seconds. This value determines how long the client will wait + * for a response * from the Ollama server before timing out. */ @Setter @@ -73,7 +74,8 @@ public class OllamaAPI { /** * Enables or disables verbose logging of responses. *

- * If set to {@code true}, the API will log detailed information about requests and responses. + * If set to {@code true}, the API will log detailed information about requests + * and responses. * Default is {@code true}. */ @Setter @@ -82,7 +84,8 @@ public class OllamaAPI { /** * The maximum number of retries for tool calls during chat interactions. *

- * This value controls how many times the API will attempt to call a tool in the event of a failure. + * This value controls how many times the API will attempt to call a tool in the + * event of a failure. * Default is 3. */ @Setter @@ -91,7 +94,8 @@ public class OllamaAPI { /** * The number of retries to attempt when pulling a model from the Ollama server. *

- * If set to 0, no retries will be performed. If greater than 0, the API will retry pulling the model + * If set to 0, no retries will be performed. If greater than 0, the API will + * retry pulling the model * up to the specified number of times in case of failure. *

* Default is 0 (no retries). @@ -189,7 +193,10 @@ public class OllamaAPI { HttpClient httpClient = HttpClient.newHttpClient(); HttpRequest httpRequest = null; try { - httpRequest = getRequestBuilderDefault(new URI(url)).header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON).header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON).GET().build(); + httpRequest = getRequestBuilderDefault(new URI(url)) + .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON) + .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON) + .GET().build(); } catch (URISyntaxException e) { throw new RuntimeException(e); } @@ -216,7 +223,10 @@ public class OllamaAPI { public List listModels() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { String url = this.host + "/api/tags"; HttpClient httpClient = HttpClient.newHttpClient(); - HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON).header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON).GET().build(); + HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)) + .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON) + .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON).GET() + .build(); HttpResponse response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); String responseString = response.body(); @@ -244,10 +254,14 @@ public class OllamaAPI { * @throws URISyntaxException If there is an error creating the URI for the * HTTP request. */ - public List listModelsFromLibrary() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { + public List listModelsFromLibrary() + throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { String url = "https://ollama.com/library"; HttpClient httpClient = HttpClient.newHttpClient(); - HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON).header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON).GET().build(); + HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)) + .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON) + .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON).GET() + .build(); HttpResponse response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); String responseString = response.body(); @@ -262,7 +276,8 @@ public class OllamaAPI { Elements pullCounts = e.select("div:nth-of-type(2) > p > span:first-of-type > span:first-of-type"); Elements popularTags = e.select("div > div > span"); Elements totalTags = e.select("div:nth-of-type(2) > p > span:nth-of-type(2) > span:first-of-type"); - Elements lastUpdatedTime = e.select("div:nth-of-type(2) > p > span:nth-of-type(3) > span:nth-of-type(2)"); + Elements lastUpdatedTime = e + .select("div:nth-of-type(2) > p > span:nth-of-type(3) > span:nth-of-type(2)"); if (names.first() == null || names.isEmpty()) { // if name cannot be extracted, skip. @@ -270,9 +285,12 @@ public class OllamaAPI { } Optional.ofNullable(names.first()).map(Element::text).ifPresent(model::setName); model.setDescription(Optional.ofNullable(desc.first()).map(Element::text).orElse("")); - model.setPopularTags(Optional.of(popularTags).map(tags -> tags.stream().map(Element::text).collect(Collectors.toList())).orElse(new ArrayList<>())); + model.setPopularTags(Optional.of(popularTags) + .map(tags -> tags.stream().map(Element::text).collect(Collectors.toList())) + .orElse(new ArrayList<>())); model.setPullCount(Optional.ofNullable(pullCounts.first()).map(Element::text).orElse("")); - model.setTotalTags(Optional.ofNullable(totalTags.first()).map(Element::text).map(Integer::parseInt).orElse(0)); + model.setTotalTags( + Optional.ofNullable(totalTags.first()).map(Element::text).map(Integer::parseInt).orElse(0)); model.setLastUpdated(Optional.ofNullable(lastUpdatedTime.first()).map(Element::text).orElse("")); models.add(model); @@ -305,10 +323,14 @@ public class OllamaAPI { * the HTTP response. * @throws URISyntaxException if the URI format is incorrect or invalid. */ - public LibraryModelDetail getLibraryModelDetails(LibraryModel libraryModel) throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { + public LibraryModelDetail getLibraryModelDetails(LibraryModel libraryModel) + throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { String url = String.format("https://ollama.com/library/%s/tags", libraryModel.getName()); HttpClient httpClient = HttpClient.newHttpClient(); - HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON).header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON).GET().build(); + HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)) + .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON) + .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON).GET() + .build(); HttpResponse response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); String responseString = response.body(); @@ -316,7 +338,8 @@ public class OllamaAPI { List libraryModelTags = new ArrayList<>(); if (statusCode == 200) { Document doc = Jsoup.parse(responseString); - Elements tagSections = doc.select("html > body > main > div > section > div > div > div:nth-child(n+2) > div"); + Elements tagSections = doc + .select("html > body > main > div > section > div > div > div:nth-child(n+2) > div"); for (Element e : tagSections) { Elements tags = e.select("div > a > div"); Elements tagsMetas = e.select("div > span"); @@ -329,8 +352,11 @@ public class OllamaAPI { } libraryModelTag.setName(libraryModel.getName()); Optional.ofNullable(tags.first()).map(Element::text).ifPresent(libraryModelTag::setTag); - libraryModelTag.setSize(Optional.ofNullable(tagsMetas.first()).map(element -> element.text().split("•")).filter(parts -> parts.length > 1).map(parts -> parts[1].trim()).orElse("")); - libraryModelTag.setLastUpdated(Optional.ofNullable(tagsMetas.first()).map(element -> element.text().split("•")).filter(parts -> parts.length > 1).map(parts -> parts[2].trim()).orElse("")); + libraryModelTag.setSize(Optional.ofNullable(tagsMetas.first()).map(element -> element.text().split("•")) + .filter(parts -> parts.length > 1).map(parts -> parts[1].trim()).orElse("")); + libraryModelTag + .setLastUpdated(Optional.ofNullable(tagsMetas.first()).map(element -> element.text().split("•")) + .filter(parts -> parts.length > 1).map(parts -> parts[2].trim()).orElse("")); libraryModelTags.add(libraryModelTag); } LibraryModelDetail libraryModelDetail = new LibraryModelDetail(); @@ -363,30 +389,41 @@ public class OllamaAPI { * @throws InterruptedException If the operation is interrupted. * @throws NoSuchElementException If the model or the tag is not found. */ - public LibraryModelTag findModelTagFromLibrary(String modelName, String tag) throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + public LibraryModelTag findModelTagFromLibrary(String modelName, String tag) + throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { List libraryModels = this.listModelsFromLibrary(); - LibraryModel libraryModel = libraryModels.stream().filter(model -> model.getName().equals(modelName)).findFirst().orElseThrow(() -> new NoSuchElementException(String.format("Model by name '%s' not found", modelName))); + LibraryModel libraryModel = libraryModels.stream().filter(model -> model.getName().equals(modelName)) + .findFirst().orElseThrow( + () -> new NoSuchElementException(String.format("Model by name '%s' not found", modelName))); LibraryModelDetail libraryModelDetail = this.getLibraryModelDetails(libraryModel); - return libraryModelDetail.getTags().stream().filter(tagName -> tagName.getTag().equals(tag)).findFirst().orElseThrow(() -> new NoSuchElementException(String.format("Tag '%s' for model '%s' not found", tag, modelName))); + return libraryModelDetail.getTags().stream().filter(tagName -> tagName.getTag().equals(tag)).findFirst() + .orElseThrow(() -> new NoSuchElementException( + String.format("Tag '%s' for model '%s' not found", tag, modelName))); } /** * Pull a model on the Ollama server from the list of available models. *

- * If {@code numberOfRetriesForModelPull} is greater than 0, this method will retry pulling the model - * up to the specified number of times if an {@link OllamaBaseException} occurs, using exponential backoff - * between retries (delay doubles after each failed attempt, starting at 1 second). + * If {@code numberOfRetriesForModelPull} is greater than 0, this method will + * retry pulling the model + * up to the specified number of times if an {@link OllamaBaseException} occurs, + * using exponential backoff + * between retries (delay doubles after each failed attempt, starting at 1 + * second). *

* The backoff is only applied between retries, not after the final attempt. * * @param modelName the name of the model - * @throws OllamaBaseException if the response indicates an error status or all retries fail + * @throws OllamaBaseException if the response indicates an error status or all + * retries fail * @throws IOException if an I/O error occurs during the HTTP request - * @throws InterruptedException if the operation is interrupted or the thread is interrupted during backoff + * @throws InterruptedException if the operation is interrupted or the thread is + * interrupted during backoff * @throws URISyntaxException if the URI for the request is malformed */ - public void pullModel(String modelName) throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + public void pullModel(String modelName) + throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { if (numberOfRetriesForModelPull == 0) { this.doPullModel(modelName); return; @@ -402,36 +439,47 @@ public class OllamaAPI { numberOfRetries++; } } - throw new OllamaBaseException("Failed to pull model " + modelName + " after " + numberOfRetriesForModelPull + " retries"); + throw new OllamaBaseException( + "Failed to pull model " + modelName + " after " + numberOfRetriesForModelPull + " retries"); } /** - * Handles retry logic for pullModel, including logging and backoff. + * Handles retry backoff for pullModel. */ private void handlePullRetry(String modelName, int currentRetry, int maxRetries, long baseDelayMillis) throws InterruptedException { - logger.error("Failed to pull model {}, retrying... (attempt {}/{})", modelName, currentRetry + 1, maxRetries); - if (currentRetry + 1 < maxRetries) { + int attempt = currentRetry + 1; + if (attempt < maxRetries) { long backoffMillis = baseDelayMillis * (1L << currentRetry); + logger.error("Failed to pull model {}, retrying in {} ms... (attempt {}/{})", + modelName, backoffMillis, attempt, maxRetries); try { Thread.sleep(backoffMillis); } catch (InterruptedException ie) { Thread.currentThread().interrupt(); throw ie; } + } else { + logger.error("Failed to pull model {} after {} attempts, no more retries.", modelName, maxRetries); } } - private void doPullModel(String modelName) throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + + private void doPullModel(String modelName) + throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { String url = this.host + "/api/pull"; String jsonData = new ModelRequest(modelName).toString(); - HttpRequest request = getRequestBuilderDefault(new URI(url)).POST(HttpRequest.BodyPublishers.ofString(jsonData)).header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON).header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON).build(); + HttpRequest request = getRequestBuilderDefault(new URI(url)).POST(HttpRequest.BodyPublishers.ofString(jsonData)) + .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON) + .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON) + .build(); HttpClient client = HttpClient.newHttpClient(); HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofInputStream()); int statusCode = response.statusCode(); InputStream responseBodyStream = response.body(); String responseString = ""; boolean success = false; // Flag to check the pull success. - try (BufferedReader reader = new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { + try (BufferedReader reader = new BufferedReader( + new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) { String line; while ((line = reader.readLine()) != null) { ModelPullResponse modelPullResponse = Utils.getObjectMapper().readValue(line, ModelPullResponse.class); @@ -467,7 +515,10 @@ public class OllamaAPI { public String getVersion() throws URISyntaxException, IOException, InterruptedException, OllamaBaseException { String url = this.host + "/api/version"; HttpClient httpClient = HttpClient.newHttpClient(); - HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON).header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON).GET().build(); + HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)) + .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON) + .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON).GET() + .build(); HttpResponse response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); String responseString = response.body(); @@ -492,7 +543,8 @@ public class OllamaAPI { * @throws InterruptedException if the operation is interrupted * @throws URISyntaxException if the URI for the request is malformed */ - public void pullModel(LibraryModelTag libraryModelTag) throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { + public void pullModel(LibraryModelTag libraryModelTag) + throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { String tagToPull = String.format("%s:%s", libraryModelTag.getName(), libraryModelTag.getTag()); pullModel(tagToPull); } @@ -507,10 +559,14 @@ public class OllamaAPI { * @throws InterruptedException if the operation is interrupted * @throws URISyntaxException if the URI for the request is malformed */ - public ModelDetail getModelDetails(String modelName) throws IOException, OllamaBaseException, InterruptedException, URISyntaxException { + public ModelDetail getModelDetails(String modelName) + throws IOException, OllamaBaseException, InterruptedException, URISyntaxException { String url = this.host + "/api/show"; String jsonData = new ModelRequest(modelName).toString(); - HttpRequest request = getRequestBuilderDefault(new URI(url)).header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON).header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON).POST(HttpRequest.BodyPublishers.ofString(jsonData)).build(); + HttpRequest request = getRequestBuilderDefault(new URI(url)) + .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON) + .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON) + .POST(HttpRequest.BodyPublishers.ofString(jsonData)).build(); HttpClient client = HttpClient.newHttpClient(); HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); @@ -536,10 +592,14 @@ public class OllamaAPI { * @throws URISyntaxException if the URI for the request is malformed */ @Deprecated - public void createModelWithFilePath(String modelName, String modelFilePath) throws IOException, InterruptedException, OllamaBaseException, URISyntaxException { + public void createModelWithFilePath(String modelName, String modelFilePath) + throws IOException, InterruptedException, OllamaBaseException, URISyntaxException { String url = this.host + "/api/create"; String jsonData = new CustomModelFilePathRequest(modelName, modelFilePath).toString(); - HttpRequest request = getRequestBuilderDefault(new URI(url)).header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON).header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON).POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build(); + HttpRequest request = getRequestBuilderDefault(new URI(url)) + .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON) + .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON) + .POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build(); HttpClient client = HttpClient.newHttpClient(); HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); @@ -573,10 +633,14 @@ public class OllamaAPI { * @throws URISyntaxException if the URI for the request is malformed */ @Deprecated - public void createModelWithModelFileContents(String modelName, String modelFileContents) throws IOException, InterruptedException, OllamaBaseException, URISyntaxException { + public void createModelWithModelFileContents(String modelName, String modelFileContents) + throws IOException, InterruptedException, OllamaBaseException, URISyntaxException { String url = this.host + "/api/create"; String jsonData = new CustomModelFileContentsRequest(modelName, modelFileContents).toString(); - HttpRequest request = getRequestBuilderDefault(new URI(url)).header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON).header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON).POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build(); + HttpRequest request = getRequestBuilderDefault(new URI(url)) + .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON) + .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON) + .POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build(); HttpClient client = HttpClient.newHttpClient(); HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); @@ -603,10 +667,14 @@ public class OllamaAPI { * @throws InterruptedException if the operation is interrupted * @throws URISyntaxException if the URI for the request is malformed */ - public void createModel(CustomModelRequest customModelRequest) throws IOException, InterruptedException, OllamaBaseException, URISyntaxException { + public void createModel(CustomModelRequest customModelRequest) + throws IOException, InterruptedException, OllamaBaseException, URISyntaxException { String url = this.host + "/api/create"; String jsonData = customModelRequest.toString(); - HttpRequest request = getRequestBuilderDefault(new URI(url)).header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON).header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON).POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build(); + HttpRequest request = getRequestBuilderDefault(new URI(url)) + .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON) + .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON) + .POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build(); HttpClient client = HttpClient.newHttpClient(); HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); @@ -633,10 +701,15 @@ public class OllamaAPI { * @throws InterruptedException if the operation is interrupted * @throws URISyntaxException if the URI for the request is malformed */ - public void deleteModel(String modelName, boolean ignoreIfNotPresent) throws IOException, InterruptedException, OllamaBaseException, URISyntaxException { + public void deleteModel(String modelName, boolean ignoreIfNotPresent) + throws IOException, InterruptedException, OllamaBaseException, URISyntaxException { String url = this.host + "/api/delete"; String jsonData = new ModelRequest(modelName).toString(); - HttpRequest request = getRequestBuilderDefault(new URI(url)).method("DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON).header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON).build(); + HttpRequest request = getRequestBuilderDefault(new URI(url)) + .method("DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)) + .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON) + .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON) + .build(); HttpClient client = HttpClient.newHttpClient(); HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); @@ -661,7 +734,8 @@ public class OllamaAPI { * @deprecated Use {@link #embed(String, List)} instead. */ @Deprecated - public List generateEmbeddings(String model, String prompt) throws IOException, InterruptedException, OllamaBaseException { + public List generateEmbeddings(String model, String prompt) + throws IOException, InterruptedException, OllamaBaseException { return generateEmbeddings(new OllamaEmbeddingsRequestModel(model, prompt)); } @@ -676,17 +750,21 @@ public class OllamaAPI { * @deprecated Use {@link #embed(OllamaEmbedRequestModel)} instead. */ @Deprecated - public List generateEmbeddings(OllamaEmbeddingsRequestModel modelRequest) throws IOException, InterruptedException, OllamaBaseException { + public List generateEmbeddings(OllamaEmbeddingsRequestModel modelRequest) + throws IOException, InterruptedException, OllamaBaseException { URI uri = URI.create(this.host + "/api/embeddings"); String jsonData = modelRequest.toString(); HttpClient httpClient = HttpClient.newHttpClient(); - HttpRequest.Builder requestBuilder = getRequestBuilderDefault(uri).header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON).POST(HttpRequest.BodyPublishers.ofString(jsonData)); + HttpRequest.Builder requestBuilder = getRequestBuilderDefault(uri) + .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON) + .POST(HttpRequest.BodyPublishers.ofString(jsonData)); HttpRequest request = requestBuilder.build(); HttpResponse response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); String responseBody = response.body(); if (statusCode == 200) { - OllamaEmbeddingResponseModel embeddingResponse = Utils.getObjectMapper().readValue(responseBody, OllamaEmbeddingResponseModel.class); + OllamaEmbeddingResponseModel embeddingResponse = Utils.getObjectMapper().readValue(responseBody, + OllamaEmbeddingResponseModel.class); return embeddingResponse.getEmbedding(); } else { throw new OllamaBaseException(statusCode + " - " + responseBody); @@ -703,7 +781,8 @@ public class OllamaAPI { * @throws IOException if an I/O error occurs during the HTTP request * @throws InterruptedException if the operation is interrupted */ - public OllamaEmbedResponseModel embed(String model, List inputs) throws IOException, InterruptedException, OllamaBaseException { + public OllamaEmbedResponseModel embed(String model, List inputs) + throws IOException, InterruptedException, OllamaBaseException { return embed(new OllamaEmbedRequestModel(model, inputs)); } @@ -716,12 +795,15 @@ public class OllamaAPI { * @throws IOException if an I/O error occurs during the HTTP request * @throws InterruptedException if the operation is interrupted */ - public OllamaEmbedResponseModel embed(OllamaEmbedRequestModel modelRequest) throws IOException, InterruptedException, OllamaBaseException { + public OllamaEmbedResponseModel embed(OllamaEmbedRequestModel modelRequest) + throws IOException, InterruptedException, OllamaBaseException { URI uri = URI.create(this.host + "/api/embed"); String jsonData = Utils.getObjectMapper().writeValueAsString(modelRequest); HttpClient httpClient = HttpClient.newHttpClient(); - HttpRequest request = HttpRequest.newBuilder(uri).header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON).POST(HttpRequest.BodyPublishers.ofString(jsonData)).build(); + HttpRequest request = HttpRequest.newBuilder(uri) + .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON) + .POST(HttpRequest.BodyPublishers.ofString(jsonData)).build(); HttpResponse response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); int statusCode = response.statusCode(); @@ -741,8 +823,12 @@ public class OllamaAPI { * * @param model the ollama model to ask the question to * @param prompt the prompt/question text - * @param raw if true no formatting will be applied to the prompt. You may choose to use the raw parameter if you are specifying a full templated prompt in your request to the API - * @param think if true the model will "think" step-by-step before generating the final response + * @param raw if true no formatting will be applied to the prompt. You + * may choose to use the raw parameter if you are + * specifying a full templated prompt in your request to + * the API + * @param think if true the model will "think" step-by-step before + * generating the final response * @param options the Options object - More @@ -755,7 +841,8 @@ public class OllamaAPI { * @throws IOException if an I/O error occurs during the HTTP request * @throws InterruptedException if the operation is interrupted */ - public OllamaResult generate(String model, String prompt, boolean raw, boolean think, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { + public OllamaResult generate(String model, String prompt, boolean raw, boolean think, Options options, + OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt); ollamaRequestModel.setRaw(raw); ollamaRequestModel.setThink(think); @@ -767,7 +854,8 @@ public class OllamaAPI { * Generates response using the specified AI model and prompt (in blocking * mode). *

- * Uses {@link #generate(String, String, boolean, boolean, Options, OllamaStreamHandler)} + * Uses + * {@link #generate(String, String, boolean, boolean, Options, OllamaStreamHandler)} * * @param model The name or identifier of the AI model to use for generating * the response. @@ -785,7 +873,8 @@ public class OllamaAPI { * @throws IOException if an I/O error occurs during the HTTP request * @throws InterruptedException if the operation is interrupted */ - public OllamaResult generate(String model, String prompt, boolean raw, boolean think, Options options) throws OllamaBaseException, IOException, InterruptedException { + public OllamaResult generate(String model, String prompt, boolean raw, boolean think, Options options) + throws OllamaBaseException, IOException, InterruptedException { return generate(model, prompt, raw, think, options, null); } @@ -806,7 +895,8 @@ public class OllamaAPI { * @throws InterruptedException if the operation is interrupted. */ @SuppressWarnings("LoggingSimilarMessage") - public OllamaResult generate(String model, String prompt, Map format) throws OllamaBaseException, IOException, InterruptedException { + public OllamaResult generate(String model, String prompt, Map format) + throws OllamaBaseException, IOException, InterruptedException { URI uri = URI.create(this.host + "/api/generate"); Map requestBody = new HashMap<>(); @@ -818,11 +908,15 @@ public class OllamaAPI { String jsonData = Utils.getObjectMapper().writeValueAsString(requestBody); HttpClient httpClient = HttpClient.newHttpClient(); - HttpRequest request = getRequestBuilderDefault(uri).header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON).header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON).POST(HttpRequest.BodyPublishers.ofString(jsonData)).build(); + HttpRequest request = getRequestBuilderDefault(uri) + .header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON) + .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON) + .POST(HttpRequest.BodyPublishers.ofString(jsonData)).build(); if (verbose) { try { - String prettyJson = Utils.getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(Utils.getObjectMapper().readValue(jsonData, Object.class)); + String prettyJson = Utils.getObjectMapper().writerWithDefaultPrettyPrinter() + .writeValueAsString(Utils.getObjectMapper().readValue(jsonData, Object.class)); logger.info("Asking model:\n{}", prettyJson); } catch (Exception e) { logger.info("Asking model: {}", jsonData); @@ -832,15 +926,18 @@ public class OllamaAPI { int statusCode = response.statusCode(); String responseBody = response.body(); if (statusCode == 200) { - OllamaStructuredResult structuredResult = Utils.getObjectMapper().readValue(responseBody, OllamaStructuredResult.class); - OllamaResult ollamaResult = new OllamaResult(structuredResult.getResponse(), structuredResult.getThinking(), structuredResult.getResponseTime(), statusCode); + OllamaStructuredResult structuredResult = Utils.getObjectMapper().readValue(responseBody, + OllamaStructuredResult.class); + OllamaResult ollamaResult = new OllamaResult(structuredResult.getResponse(), structuredResult.getThinking(), + structuredResult.getResponseTime(), statusCode); if (verbose) { logger.info("Model response:\n{}", ollamaResult); } return ollamaResult; } else { if (verbose) { - logger.info("Model response:\n{}", Utils.getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(responseBody)); + logger.info("Model response:\n{}", + Utils.getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(responseBody)); } throw new OllamaBaseException(statusCode + " - " + responseBody); } @@ -865,7 +962,8 @@ public class OllamaAPI { * @throws IOException if an I/O error occurs during the HTTP request * @throws InterruptedException if the operation is interrupted */ - public OllamaToolsResult generateWithTools(String model, String prompt, boolean think, Options options) throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException { + public OllamaToolsResult generateWithTools(String model, String prompt, boolean think, Options options) + throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException { boolean raw = true; OllamaToolsResult toolResult = new OllamaToolsResult(); Map toolResults = new HashMap<>(); @@ -898,7 +996,8 @@ public class OllamaAPI { logger.warn("Response from model does not contain any tool calls. Returning the response as is."); return toolResult; } - toolFunctionCallSpecs = objectMapper.readValue(toolsResponse, objectMapper.getTypeFactory().constructCollectionType(List.class, ToolFunctionCallSpec.class)); + toolFunctionCallSpecs = objectMapper.readValue(toolsResponse, + objectMapper.getTypeFactory().constructCollectionType(List.class, ToolFunctionCallSpec.class)); } for (ToolFunctionCallSpec toolFunctionCallSpec : toolFunctionCallSpecs) { toolResults.put(toolFunctionCallSpec, invokeTool(toolFunctionCallSpec)); @@ -922,7 +1021,8 @@ public class OllamaAPI { OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt); ollamaRequestModel.setRaw(raw); URI uri = URI.create(this.host + "/api/generate"); - OllamaAsyncResultStreamer ollamaAsyncResultStreamer = new OllamaAsyncResultStreamer(getRequestBuilderDefault(uri), ollamaRequestModel, requestTimeoutSeconds); + OllamaAsyncResultStreamer ollamaAsyncResultStreamer = new OllamaAsyncResultStreamer( + getRequestBuilderDefault(uri), ollamaRequestModel, requestTimeoutSeconds); ollamaAsyncResultStreamer.start(); return ollamaAsyncResultStreamer; } @@ -947,7 +1047,8 @@ public class OllamaAPI { * @throws IOException if an I/O error occurs during the HTTP request * @throws InterruptedException if the operation is interrupted */ - public OllamaResult generateWithImageFiles(String model, String prompt, List imageFiles, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { + public OllamaResult generateWithImageFiles(String model, String prompt, List imageFiles, Options options, + OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { List images = new ArrayList<>(); for (File imageFile : imageFiles) { images.add(encodeFileToBase64(imageFile)); @@ -967,7 +1068,8 @@ public class OllamaAPI { * @throws IOException if an I/O error occurs during the HTTP request * @throws InterruptedException if the operation is interrupted */ - public OllamaResult generateWithImageFiles(String model, String prompt, List imageFiles, Options options) throws OllamaBaseException, IOException, InterruptedException { + public OllamaResult generateWithImageFiles(String model, String prompt, List imageFiles, Options options) + throws OllamaBaseException, IOException, InterruptedException { return generateWithImageFiles(model, prompt, imageFiles, options, null); } @@ -992,7 +1094,9 @@ public class OllamaAPI { * @throws InterruptedException if the operation is interrupted * @throws URISyntaxException if the URI for the request is malformed */ - public OllamaResult generateWithImageURLs(String model, String prompt, List imageURLs, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { + public OllamaResult generateWithImageURLs(String model, String prompt, List imageURLs, Options options, + OllamaStreamHandler streamHandler) + throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { List images = new ArrayList<>(); for (String imageURL : imageURLs) { images.add(encodeByteArrayToBase64(Utils.loadImageBytesFromUrl(imageURL))); @@ -1013,7 +1117,8 @@ public class OllamaAPI { * @throws InterruptedException if the operation is interrupted * @throws URISyntaxException if the URI for the request is malformed */ - public OllamaResult generateWithImageURLs(String model, String prompt, List imageURLs, Options options) throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { + public OllamaResult generateWithImageURLs(String model, String prompt, List imageURLs, Options options) + throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { return generateWithImageURLs(model, prompt, imageURLs, options, null); } @@ -1037,7 +1142,8 @@ public class OllamaAPI { * @throws IOException if an I/O error occurs during the HTTP request * @throws InterruptedException if the operation is interrupted */ - public OllamaResult generateWithImages(String model, String prompt, List images, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { + public OllamaResult generateWithImages(String model, String prompt, List images, Options options, + OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { List encodedImages = new ArrayList<>(); for (byte[] image : images) { encodedImages.add(encodeByteArrayToBase64(image)); @@ -1058,7 +1164,8 @@ public class OllamaAPI { * @throws IOException if an I/O error occurs during the HTTP request * @throws InterruptedException if the operation is interrupted */ - public OllamaResult generateWithImages(String model, String prompt, List images, Options options) throws OllamaBaseException, IOException, InterruptedException { + public OllamaResult generateWithImages(String model, String prompt, List images, Options options) + throws OllamaBaseException, IOException, InterruptedException { return generateWithImages(model, prompt, images, options, null); } @@ -1082,7 +1189,8 @@ public class OllamaAPI { * @throws InterruptedException if the operation is interrupted * @throws ToolInvocationException if the tool invocation fails */ - public OllamaChatResult chat(String model, List messages) throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException { + public OllamaChatResult chat(String model, List messages) + throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException { OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(model); return chat(builder.withMessages(messages).build()); } @@ -1106,7 +1214,8 @@ public class OllamaAPI { * @throws InterruptedException if the operation is interrupted * @throws ToolInvocationException if the tool invocation fails */ - public OllamaChatResult chat(OllamaChatRequest request) throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException { + public OllamaChatResult chat(OllamaChatRequest request) + throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException { return chat(request, null); } @@ -1132,7 +1241,8 @@ public class OllamaAPI { * @throws InterruptedException if the operation is interrupted * @throws ToolInvocationException if the tool invocation fails */ - public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException { + public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler) + throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException { return chatStreaming(request, new OllamaChatStreamObserver(streamHandler)); } @@ -1155,12 +1265,15 @@ public class OllamaAPI { * @throws IOException if an I/O error occurs during the HTTP request * @throws InterruptedException if the operation is interrupted */ - public OllamaChatResult chatStreaming(OllamaChatRequest request, OllamaTokenHandler tokenHandler) throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException { - OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, auth, requestTimeoutSeconds, verbose); + public OllamaChatResult chatStreaming(OllamaChatRequest request, OllamaTokenHandler tokenHandler) + throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException { + OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, auth, requestTimeoutSeconds, + verbose); OllamaChatResult result; // add all registered tools to Request - request.setTools(toolRegistry.getRegisteredSpecs().stream().map(Tools.ToolSpecification::getToolPrompt).collect(Collectors.toList())); + request.setTools(toolRegistry.getRegisteredSpecs().stream().map(Tools.ToolSpecification::getToolPrompt) + .collect(Collectors.toList())); if (tokenHandler != null) { request.setStream(true); @@ -1181,7 +1294,8 @@ public class OllamaAPI { } Map arguments = toolCall.getFunction().getArguments(); Object res = toolFunction.apply(arguments); - request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL, "[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() + ") : " + res + "[/TOOL_RESULTS]")); + request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL, + "[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() + ") : " + res + "[/TOOL_RESULTS]")); } if (tokenHandler != null) { @@ -1268,8 +1382,8 @@ public class OllamaAPI { for (Class provider : providers) { registerAnnotatedTools(provider.getDeclaredConstructor().newInstance()); } - } catch (InstantiationException | NoSuchMethodException | IllegalAccessException | - InvocationTargetException e) { + } catch (InstantiationException | NoSuchMethodException | IllegalAccessException + | InvocationTargetException e) { throw new RuntimeException(e); } } @@ -1309,12 +1423,22 @@ public class OllamaAPI { } String propName = !toolPropertyAnn.name().isBlank() ? toolPropertyAnn.name() : parameter.getName(); methodParams.put(propName, propType); - propsBuilder.withProperty(propName, Tools.PromptFuncDefinition.Property.builder().type(propType).description(toolPropertyAnn.desc()).required(toolPropertyAnn.required()).build()); + propsBuilder.withProperty(propName, Tools.PromptFuncDefinition.Property.builder().type(propType) + .description(toolPropertyAnn.desc()).required(toolPropertyAnn.required()).build()); } final Map params = propsBuilder.build(); - List reqProps = params.entrySet().stream().filter(e -> e.getValue().isRequired()).map(Map.Entry::getKey).collect(Collectors.toList()); + List reqProps = params.entrySet().stream().filter(e -> e.getValue().isRequired()) + .map(Map.Entry::getKey).collect(Collectors.toList()); - Tools.ToolSpecification toolSpecification = Tools.ToolSpecification.builder().functionName(operationName).functionDescription(operationDesc).toolPrompt(Tools.PromptFuncDefinition.builder().type("function").function(Tools.PromptFuncDefinition.PromptFuncSpec.builder().name(operationName).description(operationDesc).parameters(Tools.PromptFuncDefinition.Parameters.builder().type("object").properties(params).required(reqProps).build()).build()).build()).build(); + Tools.ToolSpecification toolSpecification = Tools.ToolSpecification.builder().functionName(operationName) + .functionDescription(operationDesc) + .toolPrompt(Tools.PromptFuncDefinition.builder().type("function") + .function(Tools.PromptFuncDefinition.PromptFuncSpec.builder().name(operationName) + .description(operationDesc).parameters(Tools.PromptFuncDefinition.Parameters + .builder().type("object").properties(params).required(reqProps).build()) + .build()) + .build()) + .build(); ReflectionalToolFunction reflectionalToolFunction = new ReflectionalToolFunction(object, m, methodParams); toolSpecification.setToolFunction(reflectionalToolFunction); @@ -1395,8 +1519,10 @@ public class OllamaAPI { * process. * @throws InterruptedException if the thread is interrupted during the request. */ - private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequest ollamaRequestModel, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { - OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, auth, requestTimeoutSeconds, verbose); + private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequest ollamaRequestModel, + OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { + OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, auth, requestTimeoutSeconds, + verbose); OllamaResult result; if (streamHandler != null) { ollamaRequestModel.setStream(true); @@ -1414,7 +1540,9 @@ public class OllamaAPI { * @return HttpRequest.Builder */ private HttpRequest.Builder getRequestBuilderDefault(URI uri) { - HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(uri).header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON).timeout(Duration.ofSeconds(requestTimeoutSeconds)); + HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(uri) + .header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON) + .timeout(Duration.ofSeconds(requestTimeoutSeconds)); if (isBasicAuthCredentialsSet()) { requestBuilder.header("Authorization", auth.getAuthHeaderValue()); } @@ -1439,7 +1567,8 @@ public class OllamaAPI { logger.debug("Invoking function {} with arguments {}", methodName, arguments); } if (function == null) { - throw new ToolNotFoundException("No such tool: " + methodName + ". Please register the tool before invoking it."); + throw new ToolNotFoundException( + "No such tool: " + methodName + ". Please register the tool before invoking it."); } return function.apply(arguments); } catch (Exception e) { diff --git a/src/main/java/io/github/ollama4j/models/request/BearerAuth.java b/src/main/java/io/github/ollama4j/models/request/BearerAuth.java index 8236042..4d876f2 100644 --- a/src/main/java/io/github/ollama4j/models/request/BearerAuth.java +++ b/src/main/java/io/github/ollama4j/models/request/BearerAuth.java @@ -2,18 +2,20 @@ package io.github.ollama4j.models.request; import lombok.AllArgsConstructor; import lombok.Data; +import lombok.EqualsAndHashCode; @Data @AllArgsConstructor +@EqualsAndHashCode(callSuper = false) public class BearerAuth extends Auth { - private String bearerToken; + private String bearerToken; - /** - * Get authentication header value. - * - * @return authentication header value with bearer token - */ - public String getAuthHeaderValue() { - return "Bearer "+ bearerToken; - } + /** + * Get authentication header value. + * + * @return authentication header value with bearer token + */ + public String getAuthHeaderValue() { + return "Bearer " + bearerToken; + } }