refactor: clean up and deprecate unused methods in OllamaAPI and related classes

- Removed deprecated methods and unused imports from `OllamaAPI`.
- Updated method signatures to improve clarity and consistency.
- Refactored embedding request handling to utilize `OllamaEmbedRequestModel`.
- Adjusted integration tests to reflect changes in method usage and removed obsolete tests.
- Enhanced code readability by standardizing formatting and comments across various classes.
This commit is contained in:
Amith Koujalgi 2025-09-16 00:27:11 +05:30
parent 44c6236243
commit 656802b343
31 changed files with 558 additions and 959 deletions

View File

@ -9,8 +9,6 @@ import io.github.ollama4j.exceptions.ToolNotFoundException;
import io.github.ollama4j.models.chat.*; import io.github.ollama4j.models.chat.*;
import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel; import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel;
import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel; import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel;
import io.github.ollama4j.models.embeddings.OllamaEmbeddingResponseModel;
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
import io.github.ollama4j.models.generate.OllamaGenerateRequest; import io.github.ollama4j.models.generate.OllamaGenerateRequest;
import io.github.ollama4j.models.generate.OllamaStreamHandler; import io.github.ollama4j.models.generate.OllamaStreamHandler;
import io.github.ollama4j.models.generate.OllamaTokenHandler; import io.github.ollama4j.models.generate.OllamaTokenHandler;
@ -25,10 +23,6 @@ import io.github.ollama4j.utils.Constants;
import io.github.ollama4j.utils.Options; import io.github.ollama4j.utils.Options;
import io.github.ollama4j.utils.Utils; import io.github.ollama4j.utils.Utils;
import lombok.Setter; import lombok.Setter;
import org.jsoup.Jsoup;
import org.jsoup.nodes.Document;
import org.jsoup.nodes.Element;
import org.jsoup.select.Elements;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@ -233,182 +227,6 @@ public class OllamaAPI {
} }
} }
/**
* Retrieves a list of models from the Ollama library. This method fetches the
* available models directly from Ollama
* library page, including model details such as the name, pull count, popular
* tags, tag count, and the time when model was updated.
*
* @return A list of {@link LibraryModel} objects representing the models
* available in the Ollama library.
* @throws OllamaBaseException If the HTTP request fails or the response is not
* successful (non-200 status code).
* @throws IOException If an I/O error occurs during the HTTP request
* or response processing.
* @throws InterruptedException If the thread executing the request is
* interrupted.
* @throws URISyntaxException If there is an error creating the URI for the
* HTTP request.
*/
public List<LibraryModel> listModelsFromLibrary()
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
String url = "https://ollama.com/library";
HttpClient httpClient = HttpClient.newHttpClient();
HttpRequest httpRequest = getRequestBuilderDefault(new URI(url))
.header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
.header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON).GET()
.build();
HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode();
String responseString = response.body();
List<LibraryModel> models = new ArrayList<>();
if (statusCode == 200) {
Document doc = Jsoup.parse(responseString);
Elements modelSections = doc.selectXpath("//*[@id='repo']/ul/li/a");
for (Element e : modelSections) {
LibraryModel model = new LibraryModel();
Elements names = e.select("div > h2 > div > span");
Elements desc = e.select("div > p");
Elements pullCounts = e.select("div:nth-of-type(2) > p > span:first-of-type > span:first-of-type");
Elements popularTags = e.select("div > div > span");
Elements totalTags = e.select("div:nth-of-type(2) > p > span:nth-of-type(2) > span:first-of-type");
Elements lastUpdatedTime = e
.select("div:nth-of-type(2) > p > span:nth-of-type(3) > span:nth-of-type(2)");
if (names.first() == null || names.isEmpty()) {
// if name cannot be extracted, skip.
continue;
}
Optional.ofNullable(names.first()).map(Element::text).ifPresent(model::setName);
model.setDescription(Optional.ofNullable(desc.first()).map(Element::text).orElse(""));
model.setPopularTags(Optional.of(popularTags)
.map(tags -> tags.stream().map(Element::text).collect(Collectors.toList()))
.orElse(new ArrayList<>()));
model.setPullCount(Optional.ofNullable(pullCounts.first()).map(Element::text).orElse(""));
model.setTotalTags(
Optional.ofNullable(totalTags.first()).map(Element::text).map(Integer::parseInt).orElse(0));
model.setLastUpdated(Optional.ofNullable(lastUpdatedTime.first()).map(Element::text).orElse(""));
models.add(model);
}
return models;
} else {
throw new OllamaBaseException(statusCode + " - " + responseString);
}
}
/**
* Fetches the tags associated with a specific model from Ollama library.
* This method fetches the available model tags directly from Ollama library
* model page, including model tag name, size and time when model was last
* updated
* into a list of {@link LibraryModelTag} objects.
*
* @param libraryModel the {@link LibraryModel} object which contains the name
* of the library model
* for which the tags need to be fetched.
* @return a list of {@link LibraryModelTag} objects containing the extracted
* tags and their associated metadata.
* @throws OllamaBaseException if the HTTP response status code indicates an
* error (i.e., not 200 OK),
* or if there is any other issue during the
* request or response processing.
* @throws IOException if an input/output exception occurs during the
* HTTP request or response handling.
* @throws InterruptedException if the thread is interrupted while waiting for
* the HTTP response.
* @throws URISyntaxException if the URI format is incorrect or invalid.
*/
public LibraryModelDetail getLibraryModelDetails(LibraryModel libraryModel)
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
String url = String.format("https://ollama.com/library/%s/tags", libraryModel.getName());
HttpClient httpClient = HttpClient.newHttpClient();
HttpRequest httpRequest = getRequestBuilderDefault(new URI(url))
.header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
.header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON).GET()
.build();
HttpResponse<String> response = httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode();
String responseString = response.body();
List<LibraryModelTag> libraryModelTags = new ArrayList<>();
if (statusCode == 200) {
Document doc = Jsoup.parse(responseString);
Elements tagSections = doc
.select("html > body > main > div > section > div > div > div:nth-child(n+2) > div");
for (Element e : tagSections) {
Elements tags = e.select("div > a > div");
Elements tagsMetas = e.select("div > span");
LibraryModelTag libraryModelTag = new LibraryModelTag();
if (tags.first() == null || tags.isEmpty()) {
// if tag cannot be extracted, skip.
continue;
}
libraryModelTag.setName(libraryModel.getName());
Optional.ofNullable(tags.first()).map(Element::text).ifPresent(libraryModelTag::setTag);
libraryModelTag.setSize(Optional.ofNullable(tagsMetas.first()).map(element -> element.text().split(""))
.filter(parts -> parts.length > 1).map(parts -> parts[1].trim()).orElse(""));
libraryModelTag
.setLastUpdated(Optional.ofNullable(tagsMetas.first()).map(element -> element.text().split(""))
.filter(parts -> parts.length > 1).map(parts -> parts[2].trim()).orElse(""));
libraryModelTags.add(libraryModelTag);
}
LibraryModelDetail libraryModelDetail = new LibraryModelDetail();
libraryModelDetail.setModel(libraryModel);
libraryModelDetail.setTags(libraryModelTags);
return libraryModelDetail;
} else {
throw new OllamaBaseException(statusCode + " - " + responseString);
}
}
/**
* Finds a specific model using model name and tag from Ollama library.
* <p>
* <b>Deprecated:</b> This method relies on the HTML structure of the Ollama
* website,
* which is subject to change at any time. As a result, it is difficult to keep
* this API
* method consistently updated and reliable. Therefore, this method is
* deprecated and
* may be removed in future releases.
* <p>
* This method retrieves the model from the Ollama library by its name, then
* fetches its tags.
* It searches through the tags of the model to find one that matches the
* specified tag name.
* If the model or the tag is not found, it throws a
* {@link NoSuchElementException}.
*
* @param modelName The name of the model to search for in the library.
* @param tag The tag name to search for within the specified model.
* @return The {@link LibraryModelTag} associated with the specified model and
* tag.
* @throws OllamaBaseException If there is a problem with the Ollama library
* operations.
* @throws IOException If an I/O error occurs during the operation.
* @throws URISyntaxException If there is an error with the URI syntax.
* @throws InterruptedException If the operation is interrupted.
* @throws NoSuchElementException If the model or the tag is not found.
* @deprecated This method relies on the HTML structure of the Ollama website,
* which can change at any time and break this API. It is deprecated
* and may be removed in the future.
*/
@Deprecated
public LibraryModelTag findModelTagFromLibrary(String modelName, String tag)
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
List<LibraryModel> libraryModels = this.listModelsFromLibrary();
LibraryModel libraryModel = libraryModels.stream().filter(model -> model.getName().equals(modelName))
.findFirst().orElseThrow(
() -> new NoSuchElementException(String.format("Model by name '%s' not found", modelName)));
LibraryModelDetail libraryModelDetail = this.getLibraryModelDetails(libraryModel);
return libraryModelDetail.getTags().stream().filter(tagName -> tagName.getTag().equals(tag)).findFirst()
.orElseThrow(() -> new NoSuchElementException(
String.format("Tag '%s' for model '%s' not found", tag, modelName)));
}
/** /**
* Pull a model on the Ollama server from the list of <a * Pull a model on the Ollama server from the list of <a
* href="https://ollama.ai/library">available models</a>. * href="https://ollama.ai/library">available models</a>.
@ -584,80 +402,6 @@ public class OllamaAPI {
} }
} }
/**
* Create a custom model from a model file. Read more about custom model file
* creation <a
* href=
* "https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md">here</a>.
*
* @param modelName the name of the custom model to be created.
* @param modelFilePath the path to model file that exists on the Ollama server.
* @throws OllamaBaseException if the response indicates an error status
* @throws IOException if an I/O error occurs during the HTTP request
* @throws InterruptedException if the operation is interrupted
* @throws URISyntaxException if the URI for the request is malformed
*/
@Deprecated
public void createModelWithFilePath(String modelName, String modelFilePath)
throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
String url = this.host + "/api/create";
String jsonData = new CustomModelFilePathRequest(modelName, modelFilePath).toString();
HttpRequest request = getRequestBuilderDefault(new URI(url))
.header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
.header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON)
.POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build();
HttpClient client = HttpClient.newHttpClient();
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode();
String responseString = response.body();
if (statusCode != 200) {
throw new OllamaBaseException(statusCode + " - " + responseString);
}
// FIXME: Ollama API returns HTTP status code 200 for model creation failure
// cases. Correct this
// if the issue is fixed in the Ollama API server.
if (responseString.contains("error")) {
throw new OllamaBaseException(responseString);
}
LOG.debug(responseString);
}
/**
* Create a custom model from a model file. Read more about custom model file
* creation <a
* href=
* "https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md">here</a>.
*
* @param modelName the name of the custom model to be created.
* @param modelFileContents the path to model file that exists on the Ollama
* server.
* @throws OllamaBaseException if the response indicates an error status
* @throws IOException if an I/O error occurs during the HTTP request
* @throws InterruptedException if the operation is interrupted
* @throws URISyntaxException if the URI for the request is malformed
*/
@Deprecated
public void createModelWithModelFileContents(String modelName, String modelFileContents)
throws IOException, InterruptedException, OllamaBaseException, URISyntaxException {
String url = this.host + "/api/create";
String jsonData = new CustomModelFileContentsRequest(modelName, modelFileContents).toString();
HttpRequest request = getRequestBuilderDefault(new URI(url))
.header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
.header(Constants.HttpConstants.HEADER_KEY_CONTENT_TYPE, Constants.HttpConstants.APPLICATION_JSON)
.POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)).build();
HttpClient client = HttpClient.newHttpClient();
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode();
String responseString = response.body();
if (statusCode != 200) {
throw new OllamaBaseException(statusCode + " - " + responseString);
}
if (responseString.contains("error")) {
throw new OllamaBaseException(responseString);
}
LOG.debug(responseString);
}
/** /**
* Create a custom model. Read more about custom model creation <a * Create a custom model. Read more about custom model creation <a
* href= * href=
@ -722,70 +466,6 @@ public class OllamaAPI {
} }
} }
/**
* Generate embeddings for a given text from a model
*
* @param model name of model to generate embeddings from
* @param prompt text to generate embeddings for
* @return embeddings
* @throws OllamaBaseException if the response indicates an error status
* @throws IOException if an I/O error occurs during the HTTP request
* @throws InterruptedException if the operation is interrupted
* @deprecated Use {@link #embed(String, List)} instead.
*/
@Deprecated
public List<Double> generateEmbeddings(String model, String prompt)
throws IOException, InterruptedException, OllamaBaseException {
return generateEmbeddings(new OllamaEmbeddingsRequestModel(model, prompt));
}
/**
* Generate embeddings using a {@link OllamaEmbeddingsRequestModel}.
*
* @param modelRequest request for '/api/embeddings' endpoint
* @return embeddings
* @throws OllamaBaseException if the response indicates an error status
* @throws IOException if an I/O error occurs during the HTTP request
* @throws InterruptedException if the operation is interrupted
* @deprecated Use {@link #embed(OllamaEmbedRequestModel)} instead.
*/
@Deprecated
public List<Double> generateEmbeddings(OllamaEmbeddingsRequestModel modelRequest)
throws IOException, InterruptedException, OllamaBaseException {
URI uri = URI.create(this.host + "/api/embeddings");
String jsonData = modelRequest.toString();
HttpClient httpClient = HttpClient.newHttpClient();
HttpRequest.Builder requestBuilder = getRequestBuilderDefault(uri)
.header(Constants.HttpConstants.HEADER_KEY_ACCEPT, Constants.HttpConstants.APPLICATION_JSON)
.POST(HttpRequest.BodyPublishers.ofString(jsonData));
HttpRequest request = requestBuilder.build();
HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
int statusCode = response.statusCode();
String responseBody = response.body();
if (statusCode == 200) {
OllamaEmbeddingResponseModel embeddingResponse = Utils.getObjectMapper().readValue(responseBody,
OllamaEmbeddingResponseModel.class);
return embeddingResponse.getEmbedding();
} else {
throw new OllamaBaseException(statusCode + " - " + responseBody);
}
}
/**
* Generate embeddings for a given text from a model
*
* @param model name of model to generate embeddings from
* @param inputs text/s to generate embeddings for
* @return embeddings
* @throws OllamaBaseException if the response indicates an error status
* @throws IOException if an I/O error occurs during the HTTP request
* @throws InterruptedException if the operation is interrupted
*/
public OllamaEmbedResponseModel embed(String model, List<String> inputs)
throws IOException, InterruptedException, OllamaBaseException {
return embed(new OllamaEmbedRequestModel(model, inputs));
}
/** /**
* Generate embeddings using a {@link OllamaEmbedRequestModel}. * Generate embeddings using a {@link OllamaEmbedRequestModel}.
* *
@ -1068,7 +748,7 @@ public class OllamaAPI {
* </p> * </p>
* *
* <pre>{@code * <pre>{@code
* OllamaAsyncResultStreamer resultStreamer = ollamaAPI.generateAsync("gpt-oss:20b", "Who are you", false, true); * OllamaAsyncResultStreamer resultStreamer = ollamaAPI.generate("gpt-oss:20b", "Who are you", false, true);
* int pollIntervalMilliseconds = 1000; * int pollIntervalMilliseconds = 1000;
* while (true) { * while (true) {
* String thinkingTokens = resultStreamer.getThinkingResponseStream().poll(); * String thinkingTokens = resultStreamer.getThinkingResponseStream().poll();
@ -1155,86 +835,7 @@ public class OllamaAPI {
} }
/** /**
* Ask a question to a model based on a given message stack (i.e. a chat * Ask a question to a model using an {@link OllamaChatRequest} and set up streaming response. This can be
* history). Creates a synchronous call to the api
* 'api/chat'.
*
* @param model the ollama model to ask the question to
* @param messages chat history / message stack to send to the model
* @return {@link OllamaChatResult} containing the api response and the message
* history including the newly acquired assistant response.
* @throws OllamaBaseException any response code than 200 has been returned
* @throws IOException in case the responseStream can not be read
* @throws InterruptedException in case the server is not reachable or
* network
* issues happen
* @throws OllamaBaseException if the response indicates an error status
* @throws IOException if an I/O error occurs during the HTTP
* request
* @throws InterruptedException if the operation is interrupted
* @throws ToolInvocationException if the tool invocation fails
*/
public OllamaChatResult chat(String model, List<OllamaChatMessage> messages)
throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(model);
return chat(builder.withMessages(messages).build());
}
/**
* Ask a question to a model using an {@link OllamaChatRequest}. This can be
* constructed using an {@link OllamaChatRequestBuilder}.
* <p>
* Hint: the OllamaChatRequestModel#getStream() property is not implemented.
*
* @param request request object to be sent to the server
* @return {@link OllamaChatResult}
* @throws OllamaBaseException any response code than 200 has been returned
* @throws IOException in case the responseStream can not be read
* @throws InterruptedException in case the server is not reachable or
* network
* issues happen
* @throws OllamaBaseException if the response indicates an error status
* @throws IOException if an I/O error occurs during the HTTP
* request
* @throws InterruptedException if the operation is interrupted
* @throws ToolInvocationException if the tool invocation fails
*/
public OllamaChatResult chat(OllamaChatRequest request)
throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
return chat(request, null, null);
}
/**
* Ask a question to a model using an {@link OllamaChatRequest}. This can be
* constructed using an {@link OllamaChatRequestBuilder}.
* <p>
* Hint: the OllamaChatRequestModel#getStream() property is not implemented.
*
* @param request request object to be sent to the server
* @param responseStreamHandler callback handler to handle the last message from
* stream
* @param thinkingStreamHandler callback handler to handle the last thinking
* message from stream
* @return {@link OllamaChatResult}
* @throws OllamaBaseException any response code than 200 has been returned
* @throws IOException in case the responseStream can not be read
* @throws InterruptedException in case the server is not reachable or
* network
* issues happen
* @throws OllamaBaseException if the response indicates an error status
* @throws IOException if an I/O error occurs during the HTTP
* request
* @throws InterruptedException if the operation is interrupted
* @throws ToolInvocationException if the tool invocation fails
*/
public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler thinkingStreamHandler,
OllamaStreamHandler responseStreamHandler)
throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
return chatStreaming(request, new OllamaChatStreamObserver(thinkingStreamHandler, responseStreamHandler));
}
/**
* Ask a question to a model using an {@link OllamaChatRequest}. This can be
* constructed using an {@link OllamaChatRequestBuilder}. * constructed using an {@link OllamaChatRequestBuilder}.
* <p> * <p>
* Hint: the OllamaChatRequestModel#getStream() property is not implemented. * Hint: the OllamaChatRequestModel#getStream() property is not implemented.
@ -1252,7 +853,7 @@ public class OllamaAPI {
* @throws IOException if an I/O error occurs during the HTTP request * @throws IOException if an I/O error occurs during the HTTP request
* @throws InterruptedException if the operation is interrupted * @throws InterruptedException if the operation is interrupted
*/ */
public OllamaChatResult chatStreaming(OllamaChatRequest request, OllamaTokenHandler tokenHandler) public OllamaChatResult chat(OllamaChatRequest request, OllamaTokenHandler tokenHandler)
throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException { throws OllamaBaseException, IOException, InterruptedException, ToolInvocationException {
OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, auth, requestTimeoutSeconds); OllamaChatEndpointCaller requestCaller = new OllamaChatEndpointCaller(host, auth, requestTimeoutSeconds);
OllamaChatResult result; OllamaChatResult result;

View File

@ -38,17 +38,17 @@ public class OllamaChatResult {
} }
@Deprecated @Deprecated
public String getResponse(){ public String getResponse() {
return responseModel != null ? responseModel.getMessage().getContent() : ""; return responseModel != null ? responseModel.getMessage().getContent() : "";
} }
@Deprecated @Deprecated
public int getHttpStatusCode(){ public int getHttpStatusCode() {
return 200; return 200;
} }
@Deprecated @Deprecated
public long getResponseTime(){ public long getResponseTime() {
return responseModel != null ? responseModel.getTotalDuration() : 0L; return responseModel != null ? responseModel.getTotalDuration() : 0L;
} }
} }

View File

@ -12,24 +12,24 @@ public class OllamaEmbedRequestBuilder {
private final OllamaEmbedRequestModel request; private final OllamaEmbedRequestModel request;
private OllamaEmbedRequestBuilder(String model, List<String> input) { private OllamaEmbedRequestBuilder(String model, List<String> input) {
this.request = new OllamaEmbedRequestModel(model,input); this.request = new OllamaEmbedRequestModel(model, input);
} }
public static OllamaEmbedRequestBuilder getInstance(String model, String... input){ public static OllamaEmbedRequestBuilder getInstance(String model, String... input) {
return new OllamaEmbedRequestBuilder(model, List.of(input)); return new OllamaEmbedRequestBuilder(model, List.of(input));
} }
public OllamaEmbedRequestBuilder withOptions(Options options){ public OllamaEmbedRequestBuilder withOptions(Options options) {
this.request.setOptions(options.getOptionsMap()); this.request.setOptions(options.getOptionsMap());
return this; return this;
} }
public OllamaEmbedRequestBuilder withKeepAlive(String keepAlive){ public OllamaEmbedRequestBuilder withKeepAlive(String keepAlive) {
this.request.setKeepAlive(keepAlive); this.request.setKeepAlive(keepAlive);
return this; return this;
} }
public OllamaEmbedRequestBuilder withoutTruncate(){ public OllamaEmbedRequestBuilder withoutTruncate() {
this.request.setTruncate(false); this.request.setTruncate(false);
return this; return this;
} }

View File

@ -7,7 +7,7 @@ import java.util.List;
@SuppressWarnings("unused") @SuppressWarnings("unused")
@Data @Data
@Deprecated(since="1.0.90") @Deprecated(since = "1.0.90")
public class OllamaEmbeddingResponseModel { public class OllamaEmbeddingResponseModel {
@JsonProperty("embedding") @JsonProperty("embedding")
private List<Double> embedding; private List<Double> embedding;

View File

@ -2,29 +2,29 @@ package io.github.ollama4j.models.embeddings;
import io.github.ollama4j.utils.Options; import io.github.ollama4j.utils.Options;
@Deprecated(since="1.0.90") @Deprecated(since = "1.0.90")
public class OllamaEmbeddingsRequestBuilder { public class OllamaEmbeddingsRequestBuilder {
private OllamaEmbeddingsRequestBuilder(String model, String prompt){ private OllamaEmbeddingsRequestBuilder(String model, String prompt) {
request = new OllamaEmbeddingsRequestModel(model, prompt); request = new OllamaEmbeddingsRequestModel(model, prompt);
} }
private OllamaEmbeddingsRequestModel request; private OllamaEmbeddingsRequestModel request;
public static OllamaEmbeddingsRequestBuilder getInstance(String model, String prompt){ public static OllamaEmbeddingsRequestBuilder getInstance(String model, String prompt) {
return new OllamaEmbeddingsRequestBuilder(model, prompt); return new OllamaEmbeddingsRequestBuilder(model, prompt);
} }
public OllamaEmbeddingsRequestModel build(){ public OllamaEmbeddingsRequestModel build() {
return request; return request;
} }
public OllamaEmbeddingsRequestBuilder withOptions(Options options){ public OllamaEmbeddingsRequestBuilder withOptions(Options options) {
this.request.setOptions(options.getOptionsMap()); this.request.setOptions(options.getOptionsMap());
return this; return this;
} }
public OllamaEmbeddingsRequestBuilder withKeepAlive(String keepAlive){ public OllamaEmbeddingsRequestBuilder withKeepAlive(String keepAlive) {
this.request.setKeepAlive(keepAlive); this.request.setKeepAlive(keepAlive);
return this; return this;
} }

View File

@ -14,7 +14,7 @@ import static io.github.ollama4j.utils.Utils.getObjectMapper;
@Data @Data
@RequiredArgsConstructor @RequiredArgsConstructor
@NoArgsConstructor @NoArgsConstructor
@Deprecated(since="1.0.90") @Deprecated(since = "1.0.90")
public class OllamaEmbeddingsRequestModel { public class OllamaEmbeddingsRequestModel {
@NonNull @NonNull
private String model; private String model;

View File

@ -7,7 +7,6 @@ import lombok.Getter;
import lombok.Setter; import lombok.Setter;
import java.util.List; import java.util.List;
import java.util.Map;
@Getter @Getter
@Setter @Setter

View File

@ -8,46 +8,46 @@ import io.github.ollama4j.utils.Options;
*/ */
public class OllamaGenerateRequestBuilder { public class OllamaGenerateRequestBuilder {
private OllamaGenerateRequestBuilder(String model, String prompt){ private OllamaGenerateRequestBuilder(String model, String prompt) {
request = new OllamaGenerateRequest(model, prompt); request = new OllamaGenerateRequest(model, prompt);
} }
private OllamaGenerateRequest request; private OllamaGenerateRequest request;
public static OllamaGenerateRequestBuilder getInstance(String model){ public static OllamaGenerateRequestBuilder getInstance(String model) {
return new OllamaGenerateRequestBuilder(model,""); return new OllamaGenerateRequestBuilder(model, "");
} }
public OllamaGenerateRequest build(){ public OllamaGenerateRequest build() {
return request; return request;
} }
public OllamaGenerateRequestBuilder withPrompt(String prompt){ public OllamaGenerateRequestBuilder withPrompt(String prompt) {
request.setPrompt(prompt); request.setPrompt(prompt);
return this; return this;
} }
public OllamaGenerateRequestBuilder withGetJsonResponse(){ public OllamaGenerateRequestBuilder withGetJsonResponse() {
this.request.setFormat("json"); this.request.setFormat("json");
return this; return this;
} }
public OllamaGenerateRequestBuilder withOptions(Options options){ public OllamaGenerateRequestBuilder withOptions(Options options) {
this.request.setOptions(options.getOptionsMap()); this.request.setOptions(options.getOptionsMap());
return this; return this;
} }
public OllamaGenerateRequestBuilder withTemplate(String template){ public OllamaGenerateRequestBuilder withTemplate(String template) {
this.request.setTemplate(template); this.request.setTemplate(template);
return this; return this;
} }
public OllamaGenerateRequestBuilder withStreaming(){ public OllamaGenerateRequestBuilder withStreaming() {
this.request.setStream(true); this.request.setStream(true);
return this; return this;
} }
public OllamaGenerateRequestBuilder withKeepAlive(String keepAlive){ public OllamaGenerateRequestBuilder withKeepAlive(String keepAlive) {
this.request.setKeepAlive(keepAlive); this.request.setKeepAlive(keepAlive);
return this; return this;
} }

View File

@ -61,8 +61,7 @@ public class OllamaChatEndpointCaller extends OllamaEndpointCaller {
if (message != null) { if (message != null) {
if (message.getThinking() != null) { if (message.getThinking() != null) {
thinkingBuffer.append(message.getThinking()); thinkingBuffer.append(message.getThinking());
} } else {
else {
responseBuffer.append(message.getContent()); responseBuffer.append(message.getContent());
} }
if (tokenHandler != null) { if (tokenHandler != null) {

View File

@ -3,8 +3,6 @@ package io.github.ollama4j.models.request;
import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import io.github.ollama4j.utils.BooleanToJsonFormatFlagSerializer;
import io.github.ollama4j.utils.Utils; import io.github.ollama4j.utils.Utils;
import lombok.Data; import lombok.Data;
@ -15,7 +13,7 @@ import java.util.Map;
public abstract class OllamaCommonRequest { public abstract class OllamaCommonRequest {
protected String model; protected String model;
// @JsonSerialize(using = BooleanToJsonFormatFlagSerializer.class) // @JsonSerialize(using = BooleanToJsonFormatFlagSerializer.class)
// this can either be set to format=json or format={"key1": "val1", "key2": "val2"} // this can either be set to format=json or format={"key1": "val1", "key2": "val2"}
@JsonProperty(value = "format", required = false, defaultValue = "json") @JsonProperty(value = "format", required = false, defaultValue = "json")
protected Object format; protected Object format;

View File

@ -11,8 +11,7 @@ import java.util.Map;
@NoArgsConstructor @NoArgsConstructor
@AllArgsConstructor @AllArgsConstructor
@JsonIgnoreProperties(ignoreUnknown = true) @JsonIgnoreProperties(ignoreUnknown = true)
public class OllamaToolCallsFunction public class OllamaToolCallsFunction {
{
private String name; private String name;
private Map<String,Object> arguments; private Map<String, Object> arguments;
} }

View File

@ -15,17 +15,17 @@ import java.util.Map;
@Setter @Setter
@Getter @Getter
@AllArgsConstructor @AllArgsConstructor
public class ReflectionalToolFunction implements ToolFunction{ public class ReflectionalToolFunction implements ToolFunction {
private Object functionHolder; private Object functionHolder;
private Method function; private Method function;
private LinkedHashMap<String,String> propertyDefinition; private LinkedHashMap<String, String> propertyDefinition;
@Override @Override
public Object apply(Map<String, Object> arguments) { public Object apply(Map<String, Object> arguments) {
LinkedHashMap<String, Object> argumentsCopy = new LinkedHashMap<>(this.propertyDefinition); LinkedHashMap<String, Object> argumentsCopy = new LinkedHashMap<>(this.propertyDefinition);
for (Map.Entry<String,String> param : this.propertyDefinition.entrySet()){ for (Map.Entry<String, String> param : this.propertyDefinition.entrySet()) {
argumentsCopy.replace(param.getKey(),typeCast(arguments.get(param.getKey()),param.getValue())); argumentsCopy.replace(param.getKey(), typeCast(arguments.get(param.getKey()), param.getValue()));
} }
try { try {
return function.invoke(functionHolder, argumentsCopy.values().toArray()); return function.invoke(functionHolder, argumentsCopy.values().toArray());
@ -35,7 +35,7 @@ public class ReflectionalToolFunction implements ToolFunction{
} }
private Object typeCast(Object inputValue, String className) { private Object typeCast(Object inputValue, String className) {
if(className == null || inputValue == null) { if (className == null || inputValue == null) {
return null; return null;
} }
String inputValueString = inputValue.toString(); String inputValueString = inputValue.toString();

View File

@ -17,12 +17,12 @@ public interface OllamaRequestBody {
* @return JSON representation of a OllamaRequest * @return JSON representation of a OllamaRequest
*/ */
@JsonIgnore @JsonIgnore
default BodyPublisher getBodyPublisher(){ default BodyPublisher getBodyPublisher() {
try { try {
return BodyPublishers.ofString( return BodyPublishers.ofString(
Utils.getObjectMapper().writeValueAsString(this)); Utils.getObjectMapper().writeValueAsString(this));
} catch (JsonProcessingException e) { } catch (JsonProcessingException e) {
throw new IllegalArgumentException("Request not Body convertible.",e); throw new IllegalArgumentException("Request not Body convertible.", e);
} }
} }
} }

View File

@ -4,7 +4,9 @@ import lombok.Data;
import java.util.Map; import java.util.Map;
/** Class for options for Ollama model. */ /**
* Class for options for Ollama model.
*/
@Data @Data
public class Options { public class Options {

View File

@ -2,12 +2,16 @@ package io.github.ollama4j.utils;
import java.util.HashMap; import java.util.HashMap;
/** Builder class for creating options for Ollama model. */ /**
* Builder class for creating options for Ollama model.
*/
public class OptionsBuilder { public class OptionsBuilder {
private final Options options; private final Options options;
/** Constructs a new OptionsBuilder with an empty options map. */ /**
* Constructs a new OptionsBuilder with an empty options map.
*/
public OptionsBuilder() { public OptionsBuilder() {
this.options = new Options(new HashMap<>()); this.options = new Options(new HashMap<>());
} }
@ -220,6 +224,7 @@ public class OptionsBuilder {
/** /**
* Allows passing an option not formally supported by the library * Allows passing an option not formally supported by the library
*
* @param name The option name for the parameter. * @param name The option name for the parameter.
* @param value The value for the "{name}" parameter. * @param value The value for the "{name}" parameter.
* @return The updated OptionsBuilder. * @return The updated OptionsBuilder.
@ -234,7 +239,6 @@ public class OptionsBuilder {
} }
/** /**
* Builds the options map. * Builds the options map.
* *

View File

@ -20,7 +20,9 @@ public class PromptBuilder {
private final StringBuilder prompt; private final StringBuilder prompt;
/** Constructs a new {@code PromptBuilder} with an empty prompt. */ /**
* Constructs a new {@code PromptBuilder} with an empty prompt.
*/
public PromptBuilder() { public PromptBuilder() {
this.prompt = new StringBuilder(); this.prompt = new StringBuilder();
} }

View File

@ -4,8 +4,8 @@ import io.github.ollama4j.OllamaAPI;
import io.github.ollama4j.exceptions.OllamaBaseException; import io.github.ollama4j.exceptions.OllamaBaseException;
import io.github.ollama4j.exceptions.ToolInvocationException; import io.github.ollama4j.exceptions.ToolInvocationException;
import io.github.ollama4j.models.chat.*; import io.github.ollama4j.models.chat.*;
import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel;
import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel; import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel;
import io.github.ollama4j.models.response.LibraryModel;
import io.github.ollama4j.models.response.Model; import io.github.ollama4j.models.response.Model;
import io.github.ollama4j.models.response.ModelDetail; import io.github.ollama4j.models.response.ModelDetail;
import io.github.ollama4j.models.response.OllamaResult; import io.github.ollama4j.models.response.OllamaResult;
@ -113,15 +113,6 @@ class OllamaAPIIntegrationTest {
assertTrue(models.size() >= 0, "Models list should not be empty"); assertTrue(models.size() >= 0, "Models list should not be empty");
} }
@Test
@Order(2)
void testListModelsFromLibrary()
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
List<LibraryModel> models = api.listModelsFromLibrary();
assertNotNull(models);
assertFalse(models.isEmpty());
}
@Test @Test
@Order(3) @Order(3)
void testPullModelAPI() throws URISyntaxException, IOException, OllamaBaseException, InterruptedException { void testPullModelAPI() throws URISyntaxException, IOException, OllamaBaseException, InterruptedException {
@ -144,8 +135,10 @@ class OllamaAPIIntegrationTest {
@Order(5) @Order(5)
void testEmbeddings() throws Exception { void testEmbeddings() throws Exception {
api.pullModel(EMBEDDING_MODEL); api.pullModel(EMBEDDING_MODEL);
OllamaEmbedResponseModel embeddings = api.embed(EMBEDDING_MODEL, OllamaEmbedRequestModel m = new OllamaEmbedRequestModel();
Arrays.asList("Why is the sky blue?", "Why is the grass green?")); m.setModel(EMBEDDING_MODEL);
m.setInput(Arrays.asList("Why is the sky blue?", "Why is the grass green?"));
OllamaEmbedResponseModel embeddings = api.embed(m);
assertNotNull(embeddings, "Embeddings should not be null"); assertNotNull(embeddings, "Embeddings should not be null");
assertFalse(embeddings.getEmbeddings().isEmpty(), "Embeddings should not be empty"); assertFalse(embeddings.getEmbeddings().isEmpty(), "Embeddings should not be empty");
} }
@ -228,7 +221,7 @@ class OllamaAPIIntegrationTest {
requestModel = builder.withMessages(requestModel.getMessages()) requestModel = builder.withMessages(requestModel.getMessages())
.withMessage(OllamaChatMessageRole.USER, "Give me a cool name") .withMessage(OllamaChatMessageRole.USER, "Give me a cool name")
.withOptions(new OptionsBuilder().setTemperature(0.5f).build()).build(); .withOptions(new OptionsBuilder().setTemperature(0.5f).build()).build();
OllamaChatResult chatResult = api.chat(requestModel); OllamaChatResult chatResult = api.chat(requestModel, null);
assertNotNull(chatResult); assertNotNull(chatResult);
assertNotNull(chatResult.getResponseModel()); assertNotNull(chatResult.getResponseModel());
@ -249,7 +242,7 @@ class OllamaAPIIntegrationTest {
expectedResponse)).withMessage(OllamaChatMessageRole.USER, "Who are you?") expectedResponse)).withMessage(OllamaChatMessageRole.USER, "Who are you?")
.withOptions(new OptionsBuilder().setTemperature(0.0f).build()).build(); .withOptions(new OptionsBuilder().setTemperature(0.0f).build()).build();
OllamaChatResult chatResult = api.chat(requestModel); OllamaChatResult chatResult = api.chat(requestModel, null);
assertNotNull(chatResult); assertNotNull(chatResult);
assertNotNull(chatResult.getResponseModel()); assertNotNull(chatResult.getResponseModel());
assertNotNull(chatResult.getResponseModel().getMessage()); assertNotNull(chatResult.getResponseModel().getMessage());
@ -270,7 +263,7 @@ class OllamaAPIIntegrationTest {
.build(); .build();
// Start conversation with model // Start conversation with model
OllamaChatResult chatResult = api.chat(requestModel); OllamaChatResult chatResult = api.chat(requestModel, null);
assertTrue(chatResult.getChatHistory().stream().anyMatch(chat -> chat.getContent().contains("2")), assertTrue(chatResult.getChatHistory().stream().anyMatch(chat -> chat.getContent().contains("2")),
"Expected chat history to contain '2'"); "Expected chat history to contain '2'");
@ -279,7 +272,7 @@ class OllamaAPIIntegrationTest {
.withMessage(OllamaChatMessageRole.USER, "And what is its squared value?").build(); .withMessage(OllamaChatMessageRole.USER, "And what is its squared value?").build();
// Continue conversation with model // Continue conversation with model
chatResult = api.chat(requestModel); chatResult = api.chat(requestModel, null);
assertTrue(chatResult.getChatHistory().stream().anyMatch(chat -> chat.getContent().contains("4")), assertTrue(chatResult.getChatHistory().stream().anyMatch(chat -> chat.getContent().contains("4")),
"Expected chat history to contain '4'"); "Expected chat history to contain '4'");
@ -289,7 +282,7 @@ class OllamaAPIIntegrationTest {
"What is the largest value between 2, 4 and 6?").build(); "What is the largest value between 2, 4 and 6?").build();
// Continue conversation with the model for the third question // Continue conversation with the model for the third question
chatResult = api.chat(requestModel); chatResult = api.chat(requestModel, null);
// verify the result // verify the result
assertNotNull(chatResult, "Chat result should not be null"); assertNotNull(chatResult, "Chat result should not be null");
@ -315,7 +308,7 @@ class OllamaAPIIntegrationTest {
"Give me the ID and address of the employee Rahul Kumar.").build(); "Give me the ID and address of the employee Rahul Kumar.").build();
requestModel.setOptions(new OptionsBuilder().setTemperature(0.9f).build().getOptionsMap()); requestModel.setOptions(new OptionsBuilder().setTemperature(0.9f).build().getOptionsMap());
OllamaChatResult chatResult = api.chat(requestModel); OllamaChatResult chatResult = api.chat(requestModel, null);
assertNotNull(chatResult, "chatResult should not be null"); assertNotNull(chatResult, "chatResult should not be null");
assertNotNull(chatResult.getResponseModel(), "Response model should not be null"); assertNotNull(chatResult.getResponseModel(), "Response model should not be null");
@ -357,7 +350,7 @@ class OllamaAPIIntegrationTest {
.build(); .build();
requestModel.setOptions(new OptionsBuilder().setTemperature(0.9f).build().getOptionsMap()); requestModel.setOptions(new OptionsBuilder().setTemperature(0.9f).build().getOptionsMap());
OllamaChatResult chatResult = api.chat(requestModel); OllamaChatResult chatResult = api.chat(requestModel, null);
assertNotNull(chatResult, "chatResult should not be null"); assertNotNull(chatResult, "chatResult should not be null");
assertNotNull(chatResult.getResponseModel(), "Response model should not be null"); assertNotNull(chatResult.getResponseModel(), "Response model should not be null");
@ -405,11 +398,11 @@ class OllamaAPIIntegrationTest {
.withKeepAlive("0m").withOptions(new OptionsBuilder().setTemperature(0.9f).build()) .withKeepAlive("0m").withOptions(new OptionsBuilder().setTemperature(0.9f).build())
.build(); .build();
OllamaChatResult chatResult = api.chat(requestModel, (s) -> { OllamaChatResult chatResult = api.chat(requestModel, new OllamaChatStreamObserver((s) -> {
LOG.info(s.toUpperCase()); LOG.info(s.toUpperCase());
}, (s) -> { }, (s) -> {
LOG.info(s.toLowerCase()); LOG.info(s.toLowerCase());
}); }));
assertNotNull(chatResult, "chatResult should not be null"); assertNotNull(chatResult, "chatResult should not be null");
assertNotNull(chatResult.getResponseModel(), "Response model should not be null"); assertNotNull(chatResult.getResponseModel(), "Response model should not be null");
@ -447,7 +440,7 @@ class OllamaAPIIntegrationTest {
"Compute the most important constant in the world using 5 digits") "Compute the most important constant in the world using 5 digits")
.build(); .build();
OllamaChatResult chatResult = api.chat(requestModel); OllamaChatResult chatResult = api.chat(requestModel, null);
assertNotNull(chatResult); assertNotNull(chatResult);
assertNotNull(chatResult.getResponseModel()); assertNotNull(chatResult.getResponseModel());
assertNotNull(chatResult.getResponseModel().getMessage()); assertNotNull(chatResult.getResponseModel().getMessage());
@ -480,7 +473,7 @@ class OllamaAPIIntegrationTest {
"Greet Rahul with a lot of hearts and respond to me with count of emojis that have been in used in the greeting") "Greet Rahul with a lot of hearts and respond to me with count of emojis that have been in used in the greeting")
.build(); .build();
OllamaChatResult chatResult = api.chat(requestModel); OllamaChatResult chatResult = api.chat(requestModel, null);
assertNotNull(chatResult); assertNotNull(chatResult);
assertNotNull(chatResult.getResponseModel()); assertNotNull(chatResult.getResponseModel());
assertNotNull(chatResult.getResponseModel().getMessage()); assertNotNull(chatResult.getResponseModel().getMessage());
@ -515,13 +508,11 @@ class OllamaAPIIntegrationTest {
requestModel.setThink(false); requestModel.setThink(false);
StringBuffer sb = new StringBuffer(); StringBuffer sb = new StringBuffer();
OllamaChatResult chatResult = api.chat(requestModel, (s) -> { OllamaChatResult chatResult = api.chat(requestModel, new OllamaChatStreamObserver((s) -> {
LOG.info(s.toUpperCase()); LOG.info(s.toUpperCase());
sb.append(s);
}, (s) -> { }, (s) -> {
LOG.info(s.toLowerCase()); LOG.info(s.toLowerCase());
sb.append(s); }));
});
assertNotNull(chatResult); assertNotNull(chatResult);
assertNotNull(chatResult.getResponseModel()); assertNotNull(chatResult.getResponseModel());
assertNotNull(chatResult.getResponseModel().getMessage()); assertNotNull(chatResult.getResponseModel().getMessage());
@ -540,13 +531,11 @@ class OllamaAPIIntegrationTest {
.withThinking(true).withKeepAlive("0m").build(); .withThinking(true).withKeepAlive("0m").build();
StringBuffer sb = new StringBuffer(); StringBuffer sb = new StringBuffer();
OllamaChatResult chatResult = api.chat(requestModel, (s) -> { OllamaChatResult chatResult = api.chat(requestModel, new OllamaChatStreamObserver((s) -> {
sb.append(s);
LOG.info(s.toUpperCase()); LOG.info(s.toUpperCase());
}, (s) -> { }, (s) -> {
sb.append(s);
LOG.info(s.toLowerCase()); LOG.info(s.toLowerCase());
}); }));
assertNotNull(chatResult); assertNotNull(chatResult);
assertNotNull(chatResult.getResponseModel()); assertNotNull(chatResult.getResponseModel());
@ -569,7 +558,7 @@ class OllamaAPIIntegrationTest {
.build(); .build();
api.registerAnnotatedTools(new OllamaAPIIntegrationTest()); api.registerAnnotatedTools(new OllamaAPIIntegrationTest());
OllamaChatResult chatResult = api.chat(requestModel); OllamaChatResult chatResult = api.chat(requestModel, null);
assertNotNull(chatResult); assertNotNull(chatResult);
} }
@ -583,7 +572,7 @@ class OllamaAPIIntegrationTest {
"What's in the picture?", Collections.emptyList(), "What's in the picture?", Collections.emptyList(),
List.of(getImageFileFromClasspath("emoji-smile.jpeg"))).build(); List.of(getImageFileFromClasspath("emoji-smile.jpeg"))).build();
OllamaChatResult chatResult = api.chat(requestModel); OllamaChatResult chatResult = api.chat(requestModel, null);
assertNotNull(chatResult); assertNotNull(chatResult);
assertNotNull(chatResult.getResponseModel()); assertNotNull(chatResult.getResponseModel());
builder.reset(); builder.reset();
@ -591,7 +580,7 @@ class OllamaAPIIntegrationTest {
requestModel = builder.withMessages(chatResult.getChatHistory()) requestModel = builder.withMessages(chatResult.getChatHistory())
.withMessage(OllamaChatMessageRole.USER, "What's the color?").build(); .withMessage(OllamaChatMessageRole.USER, "What's the color?").build();
chatResult = api.chat(requestModel); chatResult = api.chat(requestModel, null);
assertNotNull(chatResult); assertNotNull(chatResult);
assertNotNull(chatResult.getResponseModel()); assertNotNull(chatResult.getResponseModel());
} }

View File

@ -4,7 +4,8 @@ import io.github.ollama4j.models.request.BasicAuth;
import io.github.ollama4j.models.request.BearerAuth; import io.github.ollama4j.models.request.BearerAuth;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
class TestAuth { class TestAuth {

View File

@ -96,9 +96,12 @@ class TestMockedAPIs {
String model = OllamaModelType.LLAMA2; String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text"; String prompt = "some prompt text";
try { try {
when(ollamaAPI.generateEmbeddings(model, prompt)).thenReturn(new ArrayList<>()); OllamaEmbedRequestModel m = new OllamaEmbedRequestModel();
ollamaAPI.generateEmbeddings(model, prompt); m.setModel(model);
verify(ollamaAPI, times(1)).generateEmbeddings(model, prompt); m.setInput(List.of(prompt));
when(ollamaAPI.embed(m)).thenReturn(new OllamaEmbedResponseModel());
ollamaAPI.embed(m);
verify(ollamaAPI, times(1)).embed(m);
} catch (IOException | OllamaBaseException | InterruptedException e) { } catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
@ -110,9 +113,10 @@ class TestMockedAPIs {
String model = OllamaModelType.LLAMA2; String model = OllamaModelType.LLAMA2;
List<String> inputs = List.of("some prompt text"); List<String> inputs = List.of("some prompt text");
try { try {
when(ollamaAPI.embed(model, inputs)).thenReturn(new OllamaEmbedResponseModel()); OllamaEmbedRequestModel m = new OllamaEmbedRequestModel(model, inputs);
ollamaAPI.embed(model, inputs); when(ollamaAPI.embed(m)).thenReturn(new OllamaEmbedResponseModel());
verify(ollamaAPI, times(1)).embed(model, inputs); ollamaAPI.embed(m);
verify(ollamaAPI, times(1)).embed(m);
} catch (IOException | OllamaBaseException | InterruptedException e) { } catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }

View File

@ -5,7 +5,8 @@ import io.github.ollama4j.models.chat.OllamaChatMessageRole;
import org.json.JSONObject; import org.json.JSONObject;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
class TestOllamaChatMessage { class TestOllamaChatMessage {

View File

@ -7,7 +7,7 @@ import org.junit.jupiter.api.Test;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.assertTrue;
class TestToolsPromptBuilder { class TestToolsPromptBuilder {