Refactor OllamaAPI and chat models to support 'thinking' responses

- Introduced a 'thinking' field in OllamaChatMessage to capture intermediate reasoning.
- Updated OllamaChatRequest to include a 'think' parameter for chat requests.
- Modified OllamaChatRequestBuilder to facilitate setting the 'think' parameter.
- Enhanced response handling in OllamaChatStreamObserver and OllamaGenerateStreamObserver to manage 'thinking' content.
- Updated integration tests to validate the new 'thinking' functionality in chat and generation methods.
This commit is contained in:
amithkoujalgi 2025-08-28 12:44:43 +05:30
parent 14642e9856
commit 8d9ee006ee
No known key found for this signature in database
GPG Key ID: E29A37746AF94B70
9 changed files with 182 additions and 207 deletions

View File

@ -137,8 +137,7 @@ public class OllamaAPI {
HttpClient httpClient = HttpClient.newHttpClient(); HttpClient httpClient = HttpClient.newHttpClient();
HttpRequest httpRequest = null; HttpRequest httpRequest = null;
try { try {
httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build();
.header("Content-type", "application/json").GET().build();
} catch (URISyntaxException e) { } catch (URISyntaxException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
@ -168,8 +167,7 @@ public class OllamaAPI {
HttpClient httpClient = HttpClient.newHttpClient(); HttpClient httpClient = HttpClient.newHttpClient();
HttpRequest httpRequest = null; HttpRequest httpRequest = null;
try { try {
httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build();
.header("Content-type", "application/json").GET().build();
} catch (URISyntaxException e) { } catch (URISyntaxException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
@ -196,8 +194,7 @@ public class OllamaAPI {
public List<Model> listModels() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { public List<Model> listModels() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
String url = this.host + "/api/tags"; String url = this.host + "/api/tags";
HttpClient httpClient = HttpClient.newHttpClient(); HttpClient httpClient = HttpClient.newHttpClient();
HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build();
.header("Content-type", "application/json").GET().build();
HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString()); HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode(); int statusCode = response.statusCode();
String responseString = response.body(); String responseString = response.body();
@ -225,12 +222,10 @@ public class OllamaAPI {
* @throws URISyntaxException If there is an error creating the URI for the * @throws URISyntaxException If there is an error creating the URI for the
* HTTP request. * HTTP request.
*/ */
public List<LibraryModel> listModelsFromLibrary() public List<LibraryModel> listModelsFromLibrary() throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
String url = "https://ollama.com/library"; String url = "https://ollama.com/library";
HttpClient httpClient = HttpClient.newHttpClient(); HttpClient httpClient = HttpClient.newHttpClient();
HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build();
.header("Content-type", "application/json").GET().build();
HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString()); HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode(); int statusCode = response.statusCode();
String responseString = response.body(); String responseString = response.body();
@ -245,8 +240,7 @@ public class OllamaAPI {
Elements pullCounts = e.select("div:nth-of-type(2) > p > span:first-of-type > span:first-of-type"); Elements pullCounts = e.select("div:nth-of-type(2) > p > span:first-of-type > span:first-of-type");
Elements popularTags = e.select("div > div > span"); Elements popularTags = e.select("div > div > span");
Elements totalTags = e.select("div:nth-of-type(2) > p > span:nth-of-type(2) > span:first-of-type"); Elements totalTags = e.select("div:nth-of-type(2) > p > span:nth-of-type(2) > span:first-of-type");
Elements lastUpdatedTime = e Elements lastUpdatedTime = e.select("div:nth-of-type(2) > p > span:nth-of-type(3) > span:nth-of-type(2)");
.select("div:nth-of-type(2) > p > span:nth-of-type(3) > span:nth-of-type(2)");
if (names.first() == null || names.isEmpty()) { if (names.first() == null || names.isEmpty()) {
// if name cannot be extracted, skip. // if name cannot be extracted, skip.
@ -254,12 +248,9 @@ public class OllamaAPI {
} }
Optional.ofNullable(names.first()).map(Element::text).ifPresent(model::setName); Optional.ofNullable(names.first()).map(Element::text).ifPresent(model::setName);
model.setDescription(Optional.ofNullable(desc.first()).map(Element::text).orElse("")); model.setDescription(Optional.ofNullable(desc.first()).map(Element::text).orElse(""));
model.setPopularTags(Optional.of(popularTags) model.setPopularTags(Optional.of(popularTags).map(tags -> tags.stream().map(Element::text).collect(Collectors.toList())).orElse(new ArrayList<>()));
.map(tags -> tags.stream().map(Element::text).collect(Collectors.toList()))
.orElse(new ArrayList<>()));
model.setPullCount(Optional.ofNullable(pullCounts.first()).map(Element::text).orElse("")); model.setPullCount(Optional.ofNullable(pullCounts.first()).map(Element::text).orElse(""));
model.setTotalTags( model.setTotalTags(Optional.ofNullable(totalTags.first()).map(Element::text).map(Integer::parseInt).orElse(0));
Optional.ofNullable(totalTags.first()).map(Element::text).map(Integer::parseInt).orElse(0));
model.setLastUpdated(Optional.ofNullable(lastUpdatedTime.first()).map(Element::text).orElse("")); model.setLastUpdated(Optional.ofNullable(lastUpdatedTime.first()).map(Element::text).orElse(""));
models.add(model); models.add(model);
@ -292,12 +283,10 @@ public class OllamaAPI {
* the HTTP response. * the HTTP response.
* @throws URISyntaxException if the URI format is incorrect or invalid. * @throws URISyntaxException if the URI format is incorrect or invalid.
*/ */
public LibraryModelDetail getLibraryModelDetails(LibraryModel libraryModel) public LibraryModelDetail getLibraryModelDetails(LibraryModel libraryModel) throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
String url = String.format("https://ollama.com/library/%s/tags", libraryModel.getName()); String url = String.format("https://ollama.com/library/%s/tags", libraryModel.getName());
HttpClient httpClient = HttpClient.newHttpClient(); HttpClient httpClient = HttpClient.newHttpClient();
HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build();
.header("Content-type", "application/json").GET().build();
HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString()); HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode(); int statusCode = response.statusCode();
String responseString = response.body(); String responseString = response.body();
@ -305,8 +294,7 @@ public class OllamaAPI {
List<LibraryModelTag> libraryModelTags = new ArrayList<>(); List<LibraryModelTag> libraryModelTags = new ArrayList<>();
if (statusCode == 200) { if (statusCode == 200) {
Document doc = Jsoup.parse(responseString); Document doc = Jsoup.parse(responseString);
Elements tagSections = doc Elements tagSections = doc.select("html > body > main > div > section > div > div > div:nth-child(n+2) > div");
.select("html > body > main > div > section > div > div > div:nth-child(n+2) > div");
for (Element e : tagSections) { for (Element e : tagSections) {
Elements tags = e.select("div > a > div"); Elements tags = e.select("div > a > div");
Elements tagsMetas = e.select("div > span"); Elements tagsMetas = e.select("div > span");
@ -319,11 +307,8 @@ public class OllamaAPI {
} }
libraryModelTag.setName(libraryModel.getName()); libraryModelTag.setName(libraryModel.getName());
Optional.ofNullable(tags.first()).map(Element::text).ifPresent(libraryModelTag::setTag); Optional.ofNullable(tags.first()).map(Element::text).ifPresent(libraryModelTag::setTag);
libraryModelTag.setSize(Optional.ofNullable(tagsMetas.first()).map(element -> element.text().split("")) libraryModelTag.setSize(Optional.ofNullable(tagsMetas.first()).map(element -> element.text().split("")).filter(parts -> parts.length > 1).map(parts -> parts[1].trim()).orElse(""));
.filter(parts -> parts.length > 1).map(parts -> parts[1].trim()).orElse("")); libraryModelTag.setLastUpdated(Optional.ofNullable(tagsMetas.first()).map(element -> element.text().split("")).filter(parts -> parts.length > 1).map(parts -> parts[2].trim()).orElse(""));
libraryModelTag
.setLastUpdated(Optional.ofNullable(tagsMetas.first()).map(element -> element.text().split(""))
.filter(parts -> parts.length > 1).map(parts -> parts[2].trim()).orElse(""));
libraryModelTags.add(libraryModelTag); libraryModelTags.add(libraryModelTag);
} }
LibraryModelDetail libraryModelDetail = new LibraryModelDetail(); LibraryModelDetail libraryModelDetail = new LibraryModelDetail();
@ -356,17 +341,11 @@ public class OllamaAPI {
* @throws InterruptedException If the operation is interrupted. * @throws InterruptedException If the operation is interrupted.
* @throws NoSuchElementException If the model or the tag is not found. * @throws NoSuchElementException If the model or the tag is not found.
*/ */
public LibraryModelTag findModelTagFromLibrary(String modelName, String tag) public LibraryModelTag findModelTagFromLibrary(String modelName, String tag) throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
List<LibraryModel> libraryModels = this.listModelsFromLibrary(); List<LibraryModel> libraryModels = this.listModelsFromLibrary();
LibraryModel libraryModel = libraryModels.stream().filter(model -> model.getName().equals(modelName)) LibraryModel libraryModel = libraryModels.stream().filter(model -> model.getName().equals(modelName)).findFirst().orElseThrow(() -> new NoSuchElementException(String.format("Model by name '%s' not found", modelName)));
.findFirst().orElseThrow(
() -> new NoSuchElementException(String.format("Model by name '%s' not found", modelName)));
LibraryModelDetail libraryModelDetail = this.getLibraryModelDetails(libraryModel); LibraryModelDetail libraryModelDetail = this.getLibraryModelDetails(libraryModel);
LibraryModelTag libraryModelTag = libraryModelDetail.getTags().stream() LibraryModelTag libraryModelTag = libraryModelDetail.getTags().stream().filter(tagName -> tagName.getTag().equals(tag)).findFirst().orElseThrow(() -> new NoSuchElementException(String.format("Tag '%s' for model '%s' not found", tag, modelName)));
.filter(tagName -> tagName.getTag().equals(tag)).findFirst()
.orElseThrow(() -> new NoSuchElementException(
String.format("Tag '%s' for model '%s' not found", tag, modelName)));
return libraryModelTag; return libraryModelTag;
} }
@ -380,8 +359,7 @@ public class OllamaAPI {
* @throws InterruptedException if the operation is interrupted * @throws InterruptedException if the operation is interrupted
* @throws URISyntaxException if the URI for the request is malformed * @throws URISyntaxException if the URI for the request is malformed
*/ */
public void pullModel(String modelName) public void pullModel(String modelName) throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
if (numberOfRetriesForModelPull == 0) { if (numberOfRetriesForModelPull == 0) {
this.doPullModel(modelName); this.doPullModel(modelName);
} else { } else {
@ -395,28 +373,21 @@ public class OllamaAPI {
numberOfRetries++; numberOfRetries++;
} }
} }
throw new OllamaBaseException( throw new OllamaBaseException("Failed to pull model " + modelName + " after " + numberOfRetriesForModelPull + " retries");
"Failed to pull model " + modelName + " after " + numberOfRetriesForModelPull + " retries");
} }
} }
private void doPullModel(String modelName) private void doPullModel(String modelName) throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
String url = this.host + "/api/pull"; String url = this.host + "/api/pull";
String jsonData = new ModelRequest(modelName).toString(); String jsonData = new ModelRequest(modelName).toString();
HttpRequest request = getRequestBuilderDefault(new URI(url)) HttpRequest request = getRequestBuilderDefault(new URI(url)).POST(HttpRequest.BodyPublishers.ofString(jsonData)).header("Accept", "application/json").header("Content-type", "application/json").build();
.POST(HttpRequest.BodyPublishers.ofString(jsonData))
.header("Accept", "application/json")
.header("Content-type", "application/json")
.build();
HttpClient client = HttpClient.newHttpClient(); HttpClient client = HttpClient.newHttpClient();
HttpResponse<InputStream> response = client.send(request, HttpResponse.BodyHandlers.ofInputStream()); HttpResponse<InputStream> response = client.send(request, HttpResponse.BodyHandlers.ofInputStream());
int statusCode = response.statusCode(); int statusCode = response.statusCode();
InputStream responseBodyStream = response.body(); InputStream responseBodyStream = response.body();
String responseString = ""; String responseString = "";
boolean success = false; // Flag to check the pull success. boolean success = false; // Flag to check the pull success.
try (BufferedReader reader = new BufferedReader( try (BufferedReader reader = new BufferedReader(new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
new InputStreamReader(responseBodyStream, StandardCharsets.UTF_8))) {
String line; String line;
while ((line = reader.readLine()) != null) { while ((line = reader.readLine()) != null) {
ModelPullResponse modelPullResponse = Utils.getObjectMapper().readValue(line, ModelPullResponse.class); ModelPullResponse modelPullResponse = Utils.getObjectMapper().readValue(line, ModelPullResponse.class);
@ -452,8 +423,7 @@ public class OllamaAPI {
public String getVersion() throws URISyntaxException, IOException, InterruptedException, OllamaBaseException { public String getVersion() throws URISyntaxException, IOException, InterruptedException, OllamaBaseException {
String url = this.host + "/api/version"; String url = this.host + "/api/version";
HttpClient httpClient = HttpClient.newHttpClient(); HttpClient httpClient = HttpClient.newHttpClient();
HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") HttpRequest httpRequest = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").GET().build();
.header("Content-type", "application/json").GET().build();
HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString()); HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode(); int statusCode = response.statusCode();
String responseString = response.body(); String responseString = response.body();
@ -478,8 +448,7 @@ public class OllamaAPI {
* @throws InterruptedException if the operation is interrupted * @throws InterruptedException if the operation is interrupted
* @throws URISyntaxException if the URI for the request is malformed * @throws URISyntaxException if the URI for the request is malformed
*/ */
public void pullModel(LibraryModelTag libraryModelTag) public void pullModel(LibraryModelTag libraryModelTag) throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
String tagToPull = String.format("%s:%s", libraryModelTag.getName(), libraryModelTag.getTag()); String tagToPull = String.format("%s:%s", libraryModelTag.getName(), libraryModelTag.getTag());
pullModel(tagToPull); pullModel(tagToPull);
} }
@ -494,12 +463,10 @@ public class OllamaAPI {
* @throws InterruptedException if the operation is interrupted * @throws InterruptedException if the operation is interrupted
* @throws URISyntaxException if the URI for the request is malformed * @throws URISyntaxException if the URI for the request is malformed
*/ */
public ModelDetail getModelDetails(String modelName) public ModelDetail getModelDetails(String modelName) throws IOException, OllamaBaseException, InterruptedException, URISyntaxException {
throws IOException, OllamaBaseException, InterruptedException, URISyntaxException {
String url = this.host + "/api/show"; String url = this.host + "/api/show";
String jsonData = new ModelRequest(modelName).toString(); String jsonData = new ModelRequest(modelName).toString();
HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
.header("Content-type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
HttpClient client = HttpClient.newHttpClient(); HttpClient client = HttpClient.newHttpClient();
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString()); HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode(); int statusCode = response.statusCode();
@ -525,13 +492,10 @@ public class OllamaAPI {
* @throws URISyntaxException if the URI for the request is malformed * @throws URISyntaxException if the URI for the request is malformed
*/ */
@Deprecated @Deprecated
public void createModelWithFilePath(String modelName, String modelFilePath) public void createModelWithFilePath(String modelName, String modelFilePath) throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
String url = this.host + "/api/create"; String url = this.host + "/api/create";
String jsonData = new CustomModelFilePathRequest(modelName, modelFilePath).toString(); String jsonData = new CustomModelFilePathRequest(modelName, modelFilePath).toString();
HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-Type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build();
.header("Content-Type", "application/json")
.POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build();
HttpClient client = HttpClient.newHttpClient(); HttpClient client = HttpClient.newHttpClient();
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString()); HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode(); int statusCode = response.statusCode();
@ -565,13 +529,10 @@ public class OllamaAPI {
* @throws URISyntaxException if the URI for the request is malformed * @throws URISyntaxException if the URI for the request is malformed
*/ */
@Deprecated @Deprecated
public void createModelWithModelFileContents(String modelName, String modelFileContents) public void createModelWithModelFileContents(String modelName, String modelFileContents) throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
String url = this.host + "/api/create"; String url = this.host + "/api/create";
String jsonData = new CustomModelFileContentsRequest(modelName, modelFileContents).toString(); String jsonData = new CustomModelFileContentsRequest(modelName, modelFileContents).toString();
HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-Type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build();
.header("Content-Type", "application/json")
.POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build();
HttpClient client = HttpClient.newHttpClient(); HttpClient client = HttpClient.newHttpClient();
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString()); HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode(); int statusCode = response.statusCode();
@ -598,13 +559,10 @@ public class OllamaAPI {
* @throws InterruptedException if the operation is interrupted * @throws InterruptedException if the operation is interrupted
* @throws URISyntaxException if the URI for the request is malformed * @throws URISyntaxException if the URI for the request is malformed
*/ */
public void createModel(CustomModelRequest customModelRequest) public void createModel(CustomModelRequest customModelRequest) throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
String url = this.host + "/api/create"; String url = this.host + "/api/create";
String jsonData = customModelRequest.toString(); String jsonData = customModelRequest.toString();
HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json") HttpRequest request = getRequestBuilderDefault(new URI(url)).header("Accept", "application/json").header("Content-Type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build();
.header("Content-Type", "application/json")
.POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build();
HttpClient client = HttpClient.newHttpClient(); HttpClient client = HttpClient.newHttpClient();
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString()); HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode(); int statusCode = response.statusCode();
@ -631,13 +589,10 @@ public class OllamaAPI {
* @throws InterruptedException if the operation is interrupted * @throws InterruptedException if the operation is interrupted
* @throws URISyntaxException if the URI for the request is malformed * @throws URISyntaxException if the URI for the request is malformed
*/ */
public void deleteModel(String modelName, boolean ignoreIfNotPresent) public void deleteModel(String modelName, boolean ignoreIfNotPresent) throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
String url = this.host + "/api/delete"; String url = this.host + "/api/delete";
String jsonData = new ModelRequest(modelName).toString(); String jsonData = new ModelRequest(modelName).toString();
HttpRequest request = getRequestBuilderDefault(new URI(url)) HttpRequest request = getRequestBuilderDefault(new URI(url)).method("DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).header("Accept", "application/json").header("Content-type", "application/json").build();
.method("DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8))
.header("Accept", "application/json").header("Content-type", "application/json").build();
HttpClient client = HttpClient.newHttpClient(); HttpClient client = HttpClient.newHttpClient();
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString()); HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode(); int statusCode = response.statusCode();
@ -662,8 +617,7 @@ public class OllamaAPI {
* @deprecated Use {@link #embed(String, List)} instead. * @deprecated Use {@link #embed(String, List)} instead.
*/ */
@Deprecated @Deprecated
public List<Double> generateEmbeddings(String model, String prompt) public List<Double> generateEmbeddings(String model, String prompt) throws IOException, InterruptedException, OllamaBaseException {
throws IOException, InterruptedException, OllamaBaseException {
return generateEmbeddings(new OllamaEmbeddingsRequestModel(model, prompt)); return generateEmbeddings(new OllamaEmbeddingsRequestModel(model, prompt));
} }
@ -678,20 +632,17 @@ public class OllamaAPI {
* @deprecated Use {@link #embed(OllamaEmbedRequestModel)} instead. * @deprecated Use {@link #embed(OllamaEmbedRequestModel)} instead.
*/ */
@Deprecated @Deprecated
public List<Double> generateEmbeddings(OllamaEmbeddingsRequestModel modelRequest) public List<Double> generateEmbeddings(OllamaEmbeddingsRequestModel modelRequest) throws IOException, InterruptedException, OllamaBaseException {
throws IOException, InterruptedException, OllamaBaseException {
URI uri = URI.create(this.host + "/api/embeddings"); URI uri = URI.create(this.host + "/api/embeddings");
String jsonData = modelRequest.toString(); String jsonData = modelRequest.toString();
HttpClient httpClient = HttpClient.newHttpClient(); HttpClient httpClient = HttpClient.newHttpClient();
HttpRequest.Builder requestBuilder = getRequestBuilderDefault(uri).header("Accept", "application/json") HttpRequest.Builder requestBuilder = getRequestBuilderDefault(uri).header("Accept", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData));
.POST(HttpRequest.BodyPublishers.ofString(jsonData));
HttpRequest request = requestBuilder.build(); HttpRequest request = requestBuilder.build();
HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode(); int statusCode = response.statusCode();
String responseBody = response.body(); String responseBody = response.body();
if (statusCode == 200) { if (statusCode == 200) {
OllamaEmbeddingResponseModel embeddingResponse = Utils.getObjectMapper().readValue(responseBody, OllamaEmbeddingResponseModel embeddingResponse = Utils.getObjectMapper().readValue(responseBody, OllamaEmbeddingResponseModel.class);
OllamaEmbeddingResponseModel.class);
return embeddingResponse.getEmbedding(); return embeddingResponse.getEmbedding();
} else { } else {
throw new OllamaBaseException(statusCode + " - " + responseBody); throw new OllamaBaseException(statusCode + " - " + responseBody);
@ -708,8 +659,7 @@ public class OllamaAPI {
* @throws IOException if an I/O error occurs during the HTTP request * @throws IOException if an I/O error occurs during the HTTP request
* @throws InterruptedException if the operation is interrupted * @throws InterruptedException if the operation is interrupted
*/ */
public OllamaEmbedResponseModel embed(String model, List<String> inputs) public OllamaEmbedResponseModel embed(String model, List<String> inputs) throws IOException, InterruptedException, OllamaBaseException {
throws IOException, InterruptedException, OllamaBaseException {
return embed(new OllamaEmbedRequestModel(model, inputs)); return embed(new OllamaEmbedRequestModel(model, inputs));
} }
@ -722,14 +672,12 @@ public class OllamaAPI {
* @throws IOException if an I/O error occurs during the HTTP request * @throws IOException if an I/O error occurs during the HTTP request
* @throws InterruptedException if the operation is interrupted * @throws InterruptedException if the operation is interrupted
*/ */
public OllamaEmbedResponseModel embed(OllamaEmbedRequestModel modelRequest) public OllamaEmbedResponseModel embed(OllamaEmbedRequestModel modelRequest) throws IOException, InterruptedException, OllamaBaseException {
throws IOException, InterruptedException, OllamaBaseException {
URI uri = URI.create(this.host + "/api/embed"); URI uri = URI.create(this.host + "/api/embed");
String jsonData = Utils.getObjectMapper().writeValueAsString(modelRequest); String jsonData = Utils.getObjectMapper().writeValueAsString(modelRequest);
HttpClient httpClient = HttpClient.newHttpClient(); HttpClient httpClient = HttpClient.newHttpClient();
HttpRequest request = HttpRequest.newBuilder(uri).header("Accept", "application/json") HttpRequest request = HttpRequest.newBuilder(uri).header("Accept", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
.POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode(); int statusCode = response.statusCode();
@ -763,8 +711,7 @@ public class OllamaAPI {
* @throws IOException if an I/O error occurs during the HTTP request * @throws IOException if an I/O error occurs during the HTTP request
* @throws InterruptedException if the operation is interrupted * @throws InterruptedException if the operation is interrupted
*/ */
public OllamaResult generate(String model, String prompt, boolean raw, boolean think, Options options, public OllamaResult generate(String model, String prompt, boolean raw, boolean think, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt); OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt);
ollamaRequestModel.setRaw(raw); ollamaRequestModel.setRaw(raw);
ollamaRequestModel.setThink(think); ollamaRequestModel.setThink(think);
@ -794,13 +741,14 @@ public class OllamaAPI {
* @throws IOException if an I/O error occurs during the HTTP request * @throws IOException if an I/O error occurs during the HTTP request
* @throws InterruptedException if the operation is interrupted * @throws InterruptedException if the operation is interrupted
*/ */
public OllamaResult generate(String model, String prompt, boolean raw, boolean think, Options options) public OllamaResult generate(String model, String prompt, boolean raw, boolean think, Options options) throws OllamaBaseException, IOException, InterruptedException {
throws OllamaBaseException, IOException, InterruptedException {
return generate(model, prompt, raw, think, options, null); return generate(model, prompt, raw, think, options, null);
} }
/** /**
* Generates structured output from the specified AI model and prompt. * Generates structured output from the specified AI model and prompt.
* <p>
* Note: When formatting is specified, the 'think' parameter is not allowed.
* *
* @param model The name or identifier of the AI model to use for generating * @param model The name or identifier of the AI model to use for generating
* the response. * the response.
@ -813,8 +761,8 @@ public class OllamaAPI {
* @throws IOException if an I/O error occurs during the HTTP request. * @throws IOException if an I/O error occurs during the HTTP request.
* @throws InterruptedException if the operation is interrupted. * @throws InterruptedException if the operation is interrupted.
*/ */
public OllamaResult generate(String model, String prompt, Map<String, Object> format) @SuppressWarnings("LoggingSimilarMessage")
throws OllamaBaseException, IOException, InterruptedException { public OllamaResult generate(String model, String prompt, Map<String, Object> format) throws OllamaBaseException, IOException, InterruptedException {
URI uri = URI.create(this.host + "/api/generate"); URI uri = URI.create(this.host + "/api/generate");
Map<String, Object> requestBody = new HashMap<>(); Map<String, Object> requestBody = new HashMap<>();
@ -826,23 +774,30 @@ public class OllamaAPI {
String jsonData = Utils.getObjectMapper().writeValueAsString(requestBody); String jsonData = Utils.getObjectMapper().writeValueAsString(requestBody);
HttpClient httpClient = HttpClient.newHttpClient(); HttpClient httpClient = HttpClient.newHttpClient();
HttpRequest request = getRequestBuilderDefault(uri) HttpRequest request = getRequestBuilderDefault(uri).header("Accept", "application/json").header("Content-type", "application/json").POST(HttpRequest.BodyPublishers.ofString(jsonData)).build();
.header("Accept", "application/json")
.header("Content-type", "application/json")
.POST(HttpRequest.BodyPublishers.ofString(jsonData))
.build();
if (verbose) {
try {
String prettyJson = Utils.getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(Utils.getObjectMapper().readValue(jsonData, Object.class));
logger.info("Asking model:\n{}", prettyJson);
} catch (Exception e) {
logger.info("Asking model: {}", jsonData);
}
}
HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode(); int statusCode = response.statusCode();
String responseBody = response.body(); String responseBody = response.body();
if (statusCode == 200) { if (statusCode == 200) {
OllamaStructuredResult structuredResult = Utils.getObjectMapper().readValue(responseBody, OllamaStructuredResult structuredResult = Utils.getObjectMapper().readValue(responseBody, OllamaStructuredResult.class);
OllamaStructuredResult.class); OllamaResult ollamaResult = new OllamaResult(structuredResult.getResponse(), structuredResult.getThinking(), structuredResult.getResponseTime(), statusCode);
OllamaResult ollamaResult = new OllamaResult(structuredResult.getResponse(), structuredResult.getThinking(), if (verbose) {
structuredResult.getResponseTime(), statusCode); logger.info("Model response:\n{}", ollamaResult);
}
return ollamaResult; return ollamaResult;
} else { } else {
if (verbose) {
logger.info("Model response:\n{}", Utils.getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(responseBody));
}
throw new OllamaBaseException(statusCode + " - " + responseBody); throw new OllamaBaseException(statusCode + " - " + responseBody);
} }
} }
@ -866,8 +821,7 @@ public class OllamaAPI {
* @throws IOException if an I/O error occurs during the HTTP request * @throws IOException if an I/O error occurs during the HTTP request
* @throws InterruptedException if the operation is interrupted * @throws InterruptedException if the operation is interrupted
*/ */
public OllamaToolsResult generateWithTools(String model, String prompt, boolean think, Options options) public OllamaToolsResult generateWithTools(String model, String prompt, boolean think, Options options) throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
boolean raw = true; boolean raw = true;
OllamaToolsResult toolResult = new OllamaToolsResult(); OllamaToolsResult toolResult = new OllamaToolsResult();
Map<ToolFunctionCallSpec, Object> toolResults = new HashMap<>(); Map<ToolFunctionCallSpec, Object> toolResults = new HashMap<>();
@ -900,9 +854,7 @@ public class OllamaAPI {
logger.warn("Response from model does not contain any tool calls. Returning the response as is."); logger.warn("Response from model does not contain any tool calls. Returning the response as is.");
return toolResult; return toolResult;
} }
toolFunctionCallSpecs = objectMapper.readValue( toolFunctionCallSpecs = objectMapper.readValue(toolsResponse, objectMapper.getTypeFactory().constructCollectionType(List.class, ToolFunctionCallSpec.class));
toolsResponse,
objectMapper.getTypeFactory().constructCollectionType(List.class, ToolFunctionCallSpec.class));
} }
for (ToolFunctionCallSpec toolFunctionCallSpec : toolFunctionCallSpecs) { for (ToolFunctionCallSpec toolFunctionCallSpec : toolFunctionCallSpecs) {
toolResults.put(toolFunctionCallSpec, invokeTool(toolFunctionCallSpec)); toolResults.put(toolFunctionCallSpec, invokeTool(toolFunctionCallSpec));
@ -926,8 +878,7 @@ public class OllamaAPI {
OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt); OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt);
ollamaRequestModel.setRaw(raw); ollamaRequestModel.setRaw(raw);
URI uri = URI.create(this.host + "/api/generate"); URI uri = URI.create(this.host + "/api/generate");
OllamaAsyncResultStreamer ollamaAsyncResultStreamer = new OllamaAsyncResultStreamer( OllamaAsyncResultStreamer ollamaAsyncResultStreamer = new OllamaAsyncResultStreamer(getRequestBuilderDefault(uri), ollamaRequestModel, requestTimeoutSeconds);
getRequestBuilderDefault(uri), ollamaRequestModel, requestTimeoutSeconds);
ollamaAsyncResultStreamer.start(); ollamaAsyncResultStreamer.start();
return ollamaAsyncResultStreamer; return ollamaAsyncResultStreamer;
} }
@ -952,8 +903,7 @@ public class OllamaAPI {
* @throws IOException if an I/O error occurs during the HTTP request * @throws IOException if an I/O error occurs during the HTTP request
* @throws InterruptedException if the operation is interrupted * @throws InterruptedException if the operation is interrupted
*/ */
public OllamaResult generateWithImageFiles(String model, String prompt, List<File> imageFiles, Options options, public OllamaResult generateWithImageFiles(String model, String prompt, List<File> imageFiles, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
List<String> images = new ArrayList<>(); List<String> images = new ArrayList<>();
for (File imageFile : imageFiles) { for (File imageFile : imageFiles) {
images.add(encodeFileToBase64(imageFile)); images.add(encodeFileToBase64(imageFile));
@ -973,8 +923,7 @@ public class OllamaAPI {
* @throws IOException if an I/O error occurs during the HTTP request * @throws IOException if an I/O error occurs during the HTTP request
* @throws InterruptedException if the operation is interrupted * @throws InterruptedException if the operation is interrupted
*/ */
public OllamaResult generateWithImageFiles(String model, String prompt, List<File> imageFiles, Options options) public OllamaResult generateWithImageFiles(String model, String prompt, List<File> imageFiles, Options options) throws OllamaBaseException, IOException, InterruptedException {
throws OllamaBaseException, IOException, InterruptedException {
return generateWithImageFiles(model, prompt, imageFiles, options, null); return generateWithImageFiles(model, prompt, imageFiles, options, null);
} }
@ -999,9 +948,7 @@ public class OllamaAPI {
* @throws InterruptedException if the operation is interrupted * @throws InterruptedException if the operation is interrupted
* @throws URISyntaxException if the URI for the request is malformed * @throws URISyntaxException if the URI for the request is malformed
*/ */
public OllamaResult generateWithImageURLs(String model, String prompt, List<String> imageURLs, Options options, public OllamaResult generateWithImageURLs(String model, String prompt, List<String> imageURLs, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
OllamaStreamHandler streamHandler)
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
List<String> images = new ArrayList<>(); List<String> images = new ArrayList<>();
for (String imageURL : imageURLs) { for (String imageURL : imageURLs) {
images.add(encodeByteArrayToBase64(Utils.loadImageBytesFromUrl(imageURL))); images.add(encodeByteArrayToBase64(Utils.loadImageBytesFromUrl(imageURL)));
@ -1022,8 +969,7 @@ public class OllamaAPI {
* @throws InterruptedException if the operation is interrupted * @throws InterruptedException if the operation is interrupted
* @throws URISyntaxException if the URI for the request is malformed * @throws URISyntaxException if the URI for the request is malformed
*/ */
public OllamaResult generateWithImageURLs(String model, String prompt, List<String> imageURLs, Options options) public OllamaResult generateWithImageURLs(String model, String prompt, List<String> imageURLs, Options options) throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
return generateWithImageURLs(model, prompt, imageURLs, options, null); return generateWithImageURLs(model, prompt, imageURLs, options, null);
} }
@ -1047,8 +993,7 @@ public class OllamaAPI {
* @throws IOException if an I/O error occurs during the HTTP request * @throws IOException if an I/O error occurs during the HTTP request
* @throws InterruptedException if the operation is interrupted * @throws InterruptedException if the operation is interrupted
*/ */
public OllamaResult generateWithImages(String model, String prompt, List<byte[]> images, Options options, public OllamaResult generateWithImages(String model, String prompt, List<byte[]> images, Options options, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
List<String> encodedImages = new ArrayList<>(); List<String> encodedImages = new ArrayList<>();
for (byte[] image : images) { for (byte[] image : images) {
encodedImages.add(encodeByteArrayToBase64(image)); encodedImages.add(encodeByteArrayToBase64(image));
@ -1069,8 +1014,7 @@ public class OllamaAPI {
* @throws IOException if an I/O error occurs during the HTTP request * @throws IOException if an I/O error occurs during the HTTP request
* @throws InterruptedException if the operation is interrupted * @throws InterruptedException if the operation is interrupted
*/ */
public OllamaResult generateWithImages(String model, String prompt, List<byte[]> images, Options options) public OllamaResult generateWithImages(String model, String prompt, List<byte[]> images, Options options) throws OllamaBaseException, IOException, InterruptedException {
throws OllamaBaseException, IOException, InterruptedException {
return generateWithImages(model, prompt, images, options, null); return generateWithImages(model, prompt, images, options, null);
} }
@ -1094,8 +1038,7 @@ public class OllamaAPI {
* @throws InterruptedException if the operation is interrupted * @throws InterruptedException if the operation is interrupted
* @throws ToolInvocationException if the tool invocation fails * @throws ToolInvocationException if the tool invocation fails
*/ */
public OllamaChatResult chat(String model, List<OllamaChatMessage> messages) public OllamaChatResult chat(String model, List<OllamaChatMessage> messages) throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(model); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(model);
return chat(builder.withMessages(messages).build()); return chat(builder.withMessages(messages).build());
} }
@ -1119,8 +1062,7 @@ public class OllamaAPI {
* @throws InterruptedException if the operation is interrupted * @throws InterruptedException if the operation is interrupted
* @throws ToolInvocationException if the tool invocation fails * @throws ToolInvocationException if the tool invocation fails
*/ */
public OllamaChatResult chat(OllamaChatRequest request) public OllamaChatResult chat(OllamaChatRequest request) throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
return chat(request, null); return chat(request, null);
} }
@ -1146,8 +1088,7 @@ public class OllamaAPI {
* @throws InterruptedException if the operation is interrupted * @throws InterruptedException if the operation is interrupted
* @throws ToolInvocationException if the tool invocation fails * @throws ToolInvocationException if the tool invocation fails
*/ */
public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler) public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
return chatStreaming(request, new OllamaChatStreamObserver(streamHandler)); return chatStreaming(request, new OllamaChatStreamObserver(streamHandler));
} }
@ -1170,15 +1111,12 @@ public class OllamaAPI {
* @throws IOException if an I/O error occurs during the HTTP request * @throws IOException if an I/O error occurs during the HTTP request
* @throws InterruptedException if the operation is interrupted * @throws InterruptedException if the operation is interrupted
*/ */
public OllamaChatResult chatStreaming(OllamaChatRequest request, OllamaTokenHandler tokenHandler) public OllamaChatResult chatStreaming(OllamaChatRequest request, OllamaTokenHandler tokenHandler) throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException { OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, auth, requestTimeoutSeconds, verbose);
OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, auth, requestTimeoutSeconds,
verbose);
OllamaChatResult result; OllamaChatResult result;
// add all registered tools to Request // add all registered tools to Request
request.setTools(toolRegistry.getRegisteredSpecs().stream().map(Tools.ToolSpecification::getToolPrompt) request.setTools(toolRegistry.getRegisteredSpecs().stream().map(Tools.ToolSpecification::getToolPrompt).collect(Collectors.toList()));
.collect(Collectors.toList()));
if (tokenHandler != null) { if (tokenHandler != null) {
request.setStream(true); request.setStream(true);
@ -1199,8 +1137,7 @@ public class OllamaAPI {
} }
Map<String, Object> arguments = toolCall.getFunction().getArguments(); Map<String, Object> arguments = toolCall.getFunction().getArguments();
Object res = toolFunction.apply(arguments); Object res = toolFunction.apply(arguments);
request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL, request.getMessages().add(new OllamaChatMessage(OllamaChatMessageRole.TOOL, "[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() + ") : " + res + "[/TOOL_RESULTS]"));
"[TOOL_RESULTS]" + toolName + "(" + arguments.keySet() + ") : " + res + "[/TOOL_RESULTS]"));
} }
if (tokenHandler != null) { if (tokenHandler != null) {
@ -1276,8 +1213,8 @@ public class OllamaAPI {
for (Class<?> provider : providers) { for (Class<?> provider : providers) {
registerAnnotatedTools(provider.getDeclaredConstructor().newInstance()); registerAnnotatedTools(provider.getDeclaredConstructor().newInstance());
} }
} catch (InstantiationException | NoSuchMethodException | IllegalAccessException } catch (InstantiationException | NoSuchMethodException | IllegalAccessException |
| InvocationTargetException e) { InvocationTargetException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
} }
@ -1317,22 +1254,12 @@ public class OllamaAPI {
} }
String propName = !toolPropertyAnn.name().isBlank() ? toolPropertyAnn.name() : parameter.getName(); String propName = !toolPropertyAnn.name().isBlank() ? toolPropertyAnn.name() : parameter.getName();
methodParams.put(propName, propType); methodParams.put(propName, propType);
propsBuilder.withProperty(propName, Tools.PromptFuncDefinition.Property.builder().type(propType) propsBuilder.withProperty(propName, Tools.PromptFuncDefinition.Property.builder().type(propType).description(toolPropertyAnn.desc()).required(toolPropertyAnn.required()).build());
.description(toolPropertyAnn.desc()).required(toolPropertyAnn.required()).build());
} }
final Map<String, Tools.PromptFuncDefinition.Property> params = propsBuilder.build(); final Map<String, Tools.PromptFuncDefinition.Property> params = propsBuilder.build();
List<String> reqProps = params.entrySet().stream().filter(e -> e.getValue().isRequired()) List<String> reqProps = params.entrySet().stream().filter(e -> e.getValue().isRequired()).map(Map.Entry::getKey).collect(Collectors.toList());
.map(Map.Entry::getKey).collect(Collectors.toList());
Tools.ToolSpecification toolSpecification = Tools.ToolSpecification.builder().functionName(operationName) Tools.ToolSpecification toolSpecification = Tools.ToolSpecification.builder().functionName(operationName).functionDescription(operationDesc).toolPrompt(Tools.PromptFuncDefinition.builder().type("function").function(Tools.PromptFuncDefinition.PromptFuncSpec.builder().name(operationName).description(operationDesc).parameters(Tools.PromptFuncDefinition.Parameters.builder().type("object").properties(params).required(reqProps).build()).build()).build()).build();
.functionDescription(operationDesc)
.toolPrompt(Tools.PromptFuncDefinition.builder().type("function")
.function(Tools.PromptFuncDefinition.PromptFuncSpec.builder().name(operationName)
.description(operationDesc).parameters(Tools.PromptFuncDefinition.Parameters
.builder().type("object").properties(params).required(reqProps).build())
.build())
.build())
.build();
ReflectionalToolFunction reflectionalToolFunction = new ReflectionalToolFunction(object, m, methodParams); ReflectionalToolFunction reflectionalToolFunction = new ReflectionalToolFunction(object, m, methodParams);
toolSpecification.setToolFunction(reflectionalToolFunction); toolSpecification.setToolFunction(reflectionalToolFunction);
@ -1413,10 +1340,8 @@ public class OllamaAPI {
* process. * process.
* @throws InterruptedException if the thread is interrupted during the request. * @throws InterruptedException if the thread is interrupted during the request.
*/ */
private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequest ollamaRequestModel, private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequest ollamaRequestModel, OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, auth, requestTimeoutSeconds, verbose);
OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, auth, requestTimeoutSeconds,
verbose);
OllamaResult result; OllamaResult result;
if (streamHandler != null) { if (streamHandler != null) {
ollamaRequestModel.setStream(true); ollamaRequestModel.setStream(true);
@ -1434,8 +1359,7 @@ public class OllamaAPI {
* @return HttpRequest.Builder * @return HttpRequest.Builder
*/ */
private HttpRequest.Builder getRequestBuilderDefault(URI uri) { private HttpRequest.Builder getRequestBuilderDefault(URI uri) {
HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(uri).header("Content-Type", "application/json") HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(uri).header("Content-Type", "application/json").timeout(Duration.ofSeconds(requestTimeoutSeconds));
.timeout(Duration.ofSeconds(requestTimeoutSeconds));
if (isBasicAuthCredentialsSet()) { if (isBasicAuthCredentialsSet()) {
requestBuilder.header("Authorization", auth.getAuthHeaderValue()); requestBuilder.header("Authorization", auth.getAuthHeaderValue());
} }
@ -1460,8 +1384,7 @@ public class OllamaAPI {
logger.debug("Invoking function {} with arguments {}", methodName, arguments); logger.debug("Invoking function {} with arguments {}", methodName, arguments);
} }
if (function == null) { if (function == null) {
throw new ToolNotFoundException( throw new ToolNotFoundException("No such tool: " + methodName + ". Please register the tool before invoking it.");
"No such tool: " + methodName + ". Please register the tool before invoking it.");
} }
return function.apply(arguments); return function.apply(arguments);
} catch (Exception e) { } catch (Exception e) {

View File

@ -35,6 +35,8 @@ public class OllamaChatMessage {
@NonNull @NonNull
private String content; private String content;
private String thinking;
private @JsonProperty("tool_calls") List<OllamaChatToolCalls> toolCalls; private @JsonProperty("tool_calls") List<OllamaChatToolCalls> toolCalls;
@JsonSerialize(using = FileToBase64Serializer.class) @JsonSerialize(using = FileToBase64Serializer.class)

View File

@ -13,31 +13,35 @@ import lombok.Setter;
* Defines a Request to use against the ollama /api/chat endpoint. * Defines a Request to use against the ollama /api/chat endpoint.
* *
* @see <a href= * @see <a href=
* "https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion">Generate * "https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion">Generate
* Chat Completion</a> * Chat Completion</a>
*/ */
@Getter @Getter
@Setter @Setter
public class OllamaChatRequest extends OllamaCommonRequest implements OllamaRequestBody { public class OllamaChatRequest extends OllamaCommonRequest implements OllamaRequestBody {
private List<OllamaChatMessage> messages; private List<OllamaChatMessage> messages;
private List<Tools.PromptFuncDefinition> tools; private List<Tools.PromptFuncDefinition> tools;
public OllamaChatRequest() {} private boolean think;
public OllamaChatRequest(String model, List<OllamaChatMessage> messages) { public OllamaChatRequest() {
this.model = model;
this.messages = messages;
}
@Override
public boolean equals(Object o) {
if (!(o instanceof OllamaChatRequest)) {
return false;
} }
return this.toString().equals(o.toString()); public OllamaChatRequest(String model, boolean think, List<OllamaChatMessage> messages) {
} this.model = model;
this.messages = messages;
this.think = think;
}
@Override
public boolean equals(Object o) {
if (!(o instanceof OllamaChatRequest)) {
return false;
}
return this.toString().equals(o.toString());
}
} }

View File

@ -22,7 +22,7 @@ public class OllamaChatRequestBuilder {
private static final Logger LOG = LoggerFactory.getLogger(OllamaChatRequestBuilder.class); private static final Logger LOG = LoggerFactory.getLogger(OllamaChatRequestBuilder.class);
private OllamaChatRequestBuilder(String model, List<OllamaChatMessage> messages) { private OllamaChatRequestBuilder(String model, List<OllamaChatMessage> messages) {
request = new OllamaChatRequest(model, messages); request = new OllamaChatRequest(model, false, messages);
} }
private OllamaChatRequest request; private OllamaChatRequest request;
@ -36,7 +36,7 @@ public class OllamaChatRequestBuilder {
} }
public void reset() { public void reset() {
request = new OllamaChatRequest(request.getModel(), new ArrayList<>()); request = new OllamaChatRequest(request.getModel(), request.isThink(), new ArrayList<>());
} }
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content) { public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content) {
@ -45,7 +45,7 @@ public class OllamaChatRequestBuilder {
public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List<OllamaChatToolCalls> toolCalls) { public OllamaChatRequestBuilder withMessage(OllamaChatMessageRole role, String content, List<OllamaChatToolCalls> toolCalls) {
List<OllamaChatMessage> messages = this.request.getMessages(); List<OllamaChatMessage> messages = this.request.getMessages();
messages.add(new OllamaChatMessage(role, content, toolCalls, null)); messages.add(new OllamaChatMessage(role, content, null, toolCalls, null));
return this; return this;
} }
@ -61,7 +61,7 @@ public class OllamaChatRequestBuilder {
} }
}).collect(Collectors.toList()); }).collect(Collectors.toList());
messages.add(new OllamaChatMessage(role, content, toolCalls, binaryImages)); messages.add(new OllamaChatMessage(role, content, null, toolCalls, binaryImages));
return this; return this;
} }
@ -81,7 +81,7 @@ public class OllamaChatRequestBuilder {
} }
} }
messages.add(new OllamaChatMessage(role, content, toolCalls, binaryImages)); messages.add(new OllamaChatMessage(role, content, null, toolCalls, binaryImages));
return this; return this;
} }
@ -114,4 +114,8 @@ public class OllamaChatRequestBuilder {
return this; return this;
} }
public OllamaChatRequestBuilder withThinking(boolean think) {
this.request.setThink(think);
return this;
}
} }

View File

@ -11,9 +11,22 @@ public class OllamaChatStreamObserver implements OllamaTokenHandler {
@Override @Override
public void accept(OllamaChatResponseModel token) { public void accept(OllamaChatResponseModel token) {
if (streamHandler != null) { if (streamHandler == null || token == null || token.getMessage() == null) {
message += token.getMessage().getContent(); return;
streamHandler.accept(message);
} }
String content = token.getMessage().getContent();
String thinking = token.getMessage().getThinking();
boolean hasContent = !content.isEmpty();
boolean hasThinking = thinking != null && !thinking.isEmpty();
if (hasThinking && !hasContent) {
message += thinking;
} else {
message += content;
}
streamHandler.accept(message);
} }
} }

View File

@ -24,8 +24,8 @@ public class OllamaGenerateStreamObserver {
String response = currentResponsePart.getResponse(); String response = currentResponsePart.getResponse();
String thinking = currentResponsePart.getThinking(); String thinking = currentResponsePart.getThinking();
boolean hasResponse = response != null && !response.trim().isEmpty(); boolean hasResponse = response != null && !response.isEmpty();
boolean hasThinking = thinking != null && !thinking.trim().isEmpty(); boolean hasThinking = thinking != null && !thinking.isEmpty();
if (!hasResponse && hasThinking) { if (!hasResponse && hasThinking) {
message = message + thinking; message = message + thinking;

View File

@ -58,7 +58,12 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
// thus, we null check the message and hope that the next streamed response has some message content again // thus, we null check the message and hope that the next streamed response has some message content again
OllamaChatMessage message = ollamaResponseModel.getMessage(); OllamaChatMessage message = ollamaResponseModel.getMessage();
if (message != null) { if (message != null) {
responseBuffer.append(message.getContent()); if (message.getThinking() != null) {
thinkingBuffer.append(message.getThinking());
}
else {
responseBuffer.append(message.getContent());
}
if (tokenHandler != null) { if (tokenHandler != null) {
tokenHandler.accept(ollamaResponseModel); tokenHandler.accept(ollamaResponseModel);
} }
@ -85,7 +90,7 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
.POST( .POST(
body.getBodyPublisher()); body.getBodyPublisher());
HttpRequest request = requestBuilder.build(); HttpRequest request = requestBuilder.build();
if (isVerbose()) LOG.info("Asking model: " + body); if (isVerbose()) LOG.info("Asking model: {}", body);
HttpResponse<InputStream> response = HttpResponse<InputStream> response =
httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream()); httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream());
@ -129,6 +134,7 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
} }
if (finished && body.stream) { if (finished && body.stream) {
ollamaChatResponseModel.getMessage().setContent(responseBuffer.toString()); ollamaChatResponseModel.getMessage().setContent(responseBuffer.toString());
ollamaChatResponseModel.getMessage().setThinking(thinkingBuffer.toString());
break; break;
} }
} }

View File

@ -125,7 +125,7 @@ public class OllamaGenerateEndpointCaller extends OllamaEndpointCaller {
} else { } else {
long endTime = System.currentTimeMillis(); long endTime = System.currentTimeMillis();
OllamaResult ollamaResult = OllamaResult ollamaResult =
new OllamaResult(responseBuffer.toString().trim(), thinkingBuffer.toString().trim(), endTime - startTime, statusCode); new OllamaResult(responseBuffer.toString(), thinkingBuffer.toString(), endTime - startTime, statusCode);
if (isVerbose()) LOG.info("Model response: " + ollamaResult); if (isVerbose()) LOG.info("Model response: " + ollamaResult);
return ollamaResult; return ollamaResult;
} }

View File

@ -53,6 +53,7 @@ public class OllamaAPIIntegrationTest {
private static final String CHAT_MODEL_LLAMA3 = "llama3"; private static final String CHAT_MODEL_LLAMA3 = "llama3";
private static final String IMAGE_MODEL_LLAVA = "llava"; private static final String IMAGE_MODEL_LLAVA = "llava";
private static final String THINKING_MODEL_GPT_OSS = "gpt-oss:20b"; private static final String THINKING_MODEL_GPT_OSS = "gpt-oss:20b";
private static final String THINKING_MODEL_QWEN = "qwen3:0.6b";
@BeforeAll @BeforeAll
public static void setUp() { public static void setUp() {
@ -220,7 +221,7 @@ public class OllamaAPIIntegrationTest {
assertNotNull(result); assertNotNull(result);
assertNotNull(result.getResponse()); assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty()); assertFalse(result.getResponse().isEmpty());
assertEquals(sb.toString().trim(), result.getResponse().trim()); assertEquals(sb.toString(), result.getResponse());
} }
@Test @Test
@ -441,29 +442,51 @@ public class OllamaAPIIntegrationTest {
assertNotNull(chatResult.getResponseModel()); assertNotNull(chatResult.getResponseModel());
assertNotNull(chatResult.getResponseModel().getMessage()); assertNotNull(chatResult.getResponseModel().getMessage());
assertNotNull(chatResult.getResponseModel().getMessage().getContent()); assertNotNull(chatResult.getResponseModel().getMessage().getContent());
assertEquals(sb.toString().trim(), chatResult.getResponseModel().getMessage().getContent().trim()); assertEquals(sb.toString(), chatResult.getResponseModel().getMessage().getContent());
} }
@Test @Test
@Order(15) @Order(15)
void testChatWithStream() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException { void testChatWithStream() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException {
api.pullModel(CHAT_MODEL_QWEN_SMALL); api.pullModel(THINKING_MODEL_QWEN);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(CHAT_MODEL_QWEN_SMALL); OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(THINKING_MODEL_QWEN);
OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What is the capital of France? And what's France's connection with Mona Lisa?").build(); OllamaChatRequest requestModel = builder.withMessage(OllamaChatMessageRole.USER, "What is the capital of France? And what's France's connection with Mona Lisa?").build();
StringBuffer sb = new StringBuffer(); StringBuffer sb = new StringBuffer();
OllamaChatResult chatResult = api.chat(requestModel, (s) -> { OllamaChatResult chatResult = api.chat(requestModel, (s) -> {
LOG.info(s); LOG.info(s);
String substring = s.substring(sb.toString().length(), s.length()); String substring = s.substring(sb.toString().length());
LOG.info(substring);
sb.append(substring); sb.append(substring);
}); });
assertNotNull(chatResult); assertNotNull(chatResult);
assertNotNull(chatResult.getResponseModel()); assertNotNull(chatResult.getResponseModel());
assertNotNull(chatResult.getResponseModel().getMessage()); assertNotNull(chatResult.getResponseModel().getMessage());
assertNotNull(chatResult.getResponseModel().getMessage().getContent()); assertNotNull(chatResult.getResponseModel().getMessage().getContent());
assertEquals(sb.toString().trim(), chatResult.getResponseModel().getMessage().getContent().trim()); assertEquals(sb.toString(), chatResult.getResponseModel().getMessage().getContent());
}
@Test
@Order(15)
void testChatWithThinkingAndStream() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException, ToolInvocationException {
api.pullModel(THINKING_MODEL_QWEN);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(THINKING_MODEL_QWEN);
OllamaChatRequest requestModel = builder
.withMessage(OllamaChatMessageRole.USER, "What is the capital of France? And what's France's connection with Mona Lisa?")
.withThinking(true)
.withKeepAlive("0m")
.build();
StringBuffer sb = new StringBuffer();
OllamaChatResult chatResult = api.chat(requestModel, (s) -> {
LOG.info(s);
String substring = s.substring(sb.toString().length());
sb.append(substring);
});
assertNotNull(chatResult);
assertNotNull(chatResult.getResponseModel());
assertNotNull(chatResult.getResponseModel().getMessage());
assertNotNull(chatResult.getResponseModel().getMessage().getContent());
assertEquals(sb.toString(), chatResult.getResponseModel().getMessage().getThinking() + chatResult.getResponseModel().getMessage().getContent());
} }
@Test @Test
@ -503,14 +526,14 @@ public class OllamaAPIIntegrationTest {
OllamaResult result = api.generateWithImageFiles(IMAGE_MODEL_LLAVA, "What is in this image?", List.of(imageFile), new OptionsBuilder().build(), (s) -> { OllamaResult result = api.generateWithImageFiles(IMAGE_MODEL_LLAVA, "What is in this image?", List.of(imageFile), new OptionsBuilder().build(), (s) -> {
LOG.info(s); LOG.info(s);
String substring = s.substring(sb.toString().length(), s.length()); String substring = s.substring(sb.toString().length());
LOG.info(substring); LOG.info(substring);
sb.append(substring); sb.append(substring);
}); });
assertNotNull(result); assertNotNull(result);
assertNotNull(result.getResponse()); assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty()); assertFalse(result.getResponse().isEmpty());
assertEquals(sb.toString().trim(), result.getResponse().trim()); assertEquals(sb.toString(), result.getResponse());
} }
@Test @Test
@ -532,13 +555,13 @@ public class OllamaAPIIntegrationTest {
@Test @Test
@Order(20) @Order(20)
void testGenerateWithThinkingAndStreamHandler() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException { void testGenerateWithThinkingAndStreamHandler() throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
api.pullModel(THINKING_MODEL_GPT_OSS); api.pullModel(THINKING_MODEL_QWEN);
boolean raw = false; boolean raw = false;
boolean thinking = true; boolean thinking = true;
StringBuffer sb = new StringBuffer(); StringBuffer sb = new StringBuffer();
OllamaResult result = api.generate(THINKING_MODEL_GPT_OSS, "Who are you?", raw, thinking, new OptionsBuilder().build(), (s) -> { OllamaResult result = api.generate(THINKING_MODEL_QWEN, "Who are you?", raw, thinking, new OptionsBuilder().build(), (s) -> {
LOG.info(s); LOG.info(s);
String substring = s.substring(sb.toString().length()); String substring = s.substring(sb.toString().length());
sb.append(substring); sb.append(substring);
@ -548,7 +571,7 @@ public class OllamaAPIIntegrationTest {
assertFalse(result.getResponse().isEmpty()); assertFalse(result.getResponse().isEmpty());
assertNotNull(result.getThinking()); assertNotNull(result.getThinking());
assertFalse(result.getThinking().isEmpty()); assertFalse(result.getThinking().isEmpty());
assertEquals(sb.toString().trim(), result.getThinking().trim() + result.getResponse().trim()); assertEquals(sb.toString(), result.getThinking() + result.getResponse());
} }
private File getImageFileFromClasspath(String fileName) { private File getImageFileFromClasspath(String fileName) {