Enhance OllamaAPI and OllamaResult for improved model pulling and structured responses

- Added a retry mechanism in OllamaAPI for model pulling, allowing configurable retries.
- Introduced new methods in OllamaResult for structured response handling, including parsing JSON responses into a Map or specific class types.
- Updated integration tests to validate the new functionality and ensure robust testing of model interactions.
- Improved code formatting and consistency across the OllamaAPI and integration test classes.
This commit is contained in:
Amith Koujalgi 2025-03-24 21:40:20 +05:30
parent 1bda78e35b
commit bc2a931586
No known key found for this signature in database
GPG Key ID: 3F065E7150B71F9D
4 changed files with 813 additions and 617 deletions

View File

@ -51,7 +51,7 @@ import java.util.stream.Collectors;
/**
* The base Ollama API class.
*/
@SuppressWarnings({"DuplicatedCode", "resource"})
@SuppressWarnings({ "DuplicatedCode", "resource" })
public class OllamaAPI {
private static final Logger logger = LoggerFactory.getLogger(OllamaAPI.class);
@ -74,6 +74,12 @@ public class OllamaAPI {
private Auth auth;
private int numberOfRetriesForModelPull = 0;
/**
 * Sets the retry budget that {@code pullModel(String)} consults.
 * <p>
 * With the default of {@code 0}, a pull is attempted once and any failure
 * propagates unchanged. With a positive value {@code n}, the pull is attempted
 * up to {@code n} times; an {@code OllamaBaseException} from one attempt
 * triggers the next, and once all attempts have failed a new
 * {@code OllamaBaseException} is thrown.
 *
 * @param numberOfRetriesForModelPull the retry budget for model pulls
 */
public void setNumberOfRetriesForModelPull(int numberOfRetriesForModelPull) {
this.numberOfRetriesForModelPull = numberOfRetriesForModelPull;
}
private final ToolRegistry toolRegistry = new ToolRegistry();
/**
@ -209,7 +215,7 @@ public class OllamaAPI {
* tags, tag count, and the time when model was updated.
*
* @return A list of {@link LibraryModel} objects representing the models
* available in the Ollama library.
* available in the Ollama library.
* @throws OllamaBaseException If the HTTP request fails or the response is not
* successful (non-200 status code).
* @throws IOException If an I/O error occurs during the HTTP request
@ -275,7 +281,7 @@ public class OllamaAPI {
* of the library model
* for which the tags need to be fetched.
* @return a list of {@link LibraryModelTag} objects containing the extracted
* tags and their associated metadata.
* tags and their associated metadata.
* @throws OllamaBaseException if the HTTP response status code indicates an
* error (i.e., not 200 OK),
* or if there is any other issue during the
@ -342,7 +348,7 @@ public class OllamaAPI {
* @param modelName The name of the model to search for in the library.
* @param tag The tag name to search for within the specified model.
* @return The {@link LibraryModelTag} associated with the specified model and
* tag.
* tag.
* @throws OllamaBaseException If there is a problem with the Ollama library
* operations.
* @throws IOException If an I/O error occurs during the operation.
@ -376,6 +382,26 @@ public class OllamaAPI {
*/
public void pullModel(String modelName)
        throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
    // Read the budget once so a concurrent setter call cannot change the loop bound mid-pull.
    final int maxAttempts = numberOfRetriesForModelPull;

    // A non-positive budget means "no retry handling": attempt once and let any failure
    // propagate untouched. Treating negatives like 0 avoids reporting a failure without
    // ever having contacted the server.
    if (maxAttempts <= 0) {
        this.doPullModel(modelName);
        return;
    }

    OllamaBaseException lastFailure = null;
    for (int attempt = 1; attempt <= maxAttempts; attempt++) {
        try {
            this.doPullModel(modelName);
            return;
        } catch (OllamaBaseException e) {
            // Only OllamaBaseException is retried; I/O, URI and interruption errors
            // still propagate immediately, as before.
            lastFailure = e;
            if (attempt < maxAttempts) {
                logger.error("Failed to pull model {} (attempt {} of {}), retrying...",
                        modelName, attempt, maxAttempts, e);
            } else {
                logger.error("Failed to pull model {} (attempt {} of {}), giving up",
                        modelName, attempt, maxAttempts, e);
            }
        }
    }

    OllamaBaseException exhausted = new OllamaBaseException(
            "Failed to pull model " + modelName + " after " + maxAttempts + " retries");
    // Keep the underlying failure visible in stack traces without assuming anything
    // about which constructors OllamaBaseException offers.
    if (lastFailure != null) {
        exhausted.addSuppressed(lastFailure);
    }
    throw exhausted;
}
private void doPullModel(String modelName)
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
String url = this.host + "/api/pull";
String jsonData = new ModelRequest(modelName).toString();
HttpRequest request = getRequestBuilderDefault(new URI(url))
@ -729,7 +755,7 @@ public class OllamaAPI {
* @throws InterruptedException if the operation is interrupted
*/
public OllamaResult generate(String model, String prompt, boolean raw, Options options,
OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt);
ollamaRequestModel.setRaw(raw);
ollamaRequestModel.setOptions(options.getOptionsMap());
@ -742,8 +768,10 @@ public class OllamaAPI {
* @param model The name or identifier of the AI model to use for generating
* the response.
* @param prompt The input text or prompt to provide to the AI model.
* @param format A map containing the format specification for the structured output.
* @return An instance of {@link OllamaResult} containing the structured response.
* @param format A map containing the format specification for the structured
* output.
* @return An instance of {@link OllamaResult} containing the structured
* response.
* @throws OllamaBaseException if the response indicates an error status.
* @throws IOException if an I/O error occurs during the HTTP request.
* @throws InterruptedException if the operation is interrupted.
@ -771,7 +799,11 @@ public class OllamaAPI {
String responseBody = response.body();
if (statusCode == 200) {
return Utils.getObjectMapper().readValue(responseBody, OllamaResult.class);
OllamaStructuredResult structuredResult = Utils.getObjectMapper().readValue(responseBody,
OllamaStructuredResult.class);
OllamaResult ollamaResult = new OllamaResult(structuredResult.getResponse(),
structuredResult.getResponseTime(), statusCode);
return ollamaResult;
} else {
throw new OllamaBaseException(statusCode + " - " + responseBody);
}
@ -813,8 +845,8 @@ public class OllamaAPI {
* @param options Additional options or configurations to use when generating
* the response.
* @return {@link OllamaToolsResult} An OllamaToolsResult object containing the
* response from the AI model and the results of invoking the tools on
* that output.
* response from the AI model and the results of invoking the tools on
* that output.
* @throws OllamaBaseException if the response indicates an error status
* @throws IOException if an I/O error occurs during the HTTP request
* @throws InterruptedException if the operation is interrupted
@ -906,7 +938,7 @@ public class OllamaAPI {
* @throws InterruptedException if the operation is interrupted
*/
public OllamaResult generateWithImageFiles(String model, String prompt, List<File> imageFiles, Options options,
OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
List<String> images = new ArrayList<>();
for (File imageFile : imageFiles) {
images.add(encodeFileToBase64(imageFile));
@ -953,7 +985,7 @@ public class OllamaAPI {
* @throws URISyntaxException if the URI for the request is malformed
*/
public OllamaResult generateWithImageURLs(String model, String prompt, List<String> imageURLs, Options options,
OllamaStreamHandler streamHandler)
OllamaStreamHandler streamHandler)
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
List<String> images = new ArrayList<>();
for (String imageURL : imageURLs) {
@ -988,7 +1020,7 @@ public class OllamaAPI {
* @param model the ollama model to ask the question to
* @param messages chat history / message stack to send to the model
* @return {@link OllamaChatResult} containing the api response and the message
* history including the newly acquired assistant response.
* history including the newly acquired assistant response.
* @throws OllamaBaseException if any response code other than 200 has been returned
* @throws IOException in case the responseStream can not be read
* @throws InterruptedException in case the server is not reachable or network
@ -1171,7 +1203,7 @@ public class OllamaAPI {
registerAnnotatedTools(provider.getDeclaredConstructor().newInstance());
}
} catch (InstantiationException | NoSuchMethodException | IllegalAccessException
| InvocationTargetException e) {
| InvocationTargetException e) {
throw new RuntimeException(e);
}
}
@ -1308,7 +1340,7 @@ public class OllamaAPI {
* @throws InterruptedException if the thread is interrupted during the request.
*/
private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequest ollamaRequestModel,
OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, auth, requestTimeoutSeconds,
verbose);
OllamaResult result;

View File

@ -1,19 +1,26 @@
package io.github.ollama4j.models.response;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import lombok.Data;
import lombok.Getter;
import static io.github.ollama4j.utils.Utils.getObjectMapper;
import java.util.HashMap;
import java.util.Map;
/** The type Ollama result. */
@Getter
@SuppressWarnings("unused")
@Data
@JsonIgnoreProperties(ignoreUnknown = true)
public class OllamaResult {
/**
* -- GETTER --
* Get the completion/response text
* Get the completion/response text
*
* @return String completion/response text
*/
@ -21,7 +28,7 @@ public class OllamaResult {
/**
* -- GETTER --
* Get the response status code.
* Get the response status code.
*
* @return int - response status code
*/
@ -29,7 +36,7 @@ public class OllamaResult {
/**
* -- GETTER --
* Get the response time in milliseconds.
* Get the response time in milliseconds.
*
* @return long - response time in milliseconds
*/
@ -44,9 +51,68 @@ public class OllamaResult {
/**
 * Renders this result as pretty-printed JSON with the keys {@code response},
 * {@code httpStatusCode} and {@code responseTime}.
 *
 * @return the pretty-printed JSON representation
 * @throws RuntimeException if the values cannot be serialised to JSON
 */
@Override
public String toString() {
    // Serialise an explicit map rather than `this`: the class exposes a public
    // getStructuredResponse() getter, which a bean serialiser would presumably
    // invoke, and that call throws for responses that are not JSON objects.
    Map<String, Object> responseMap = new HashMap<>();
    responseMap.put("response", this.response);
    responseMap.put("httpStatusCode", this.httpStatusCode);
    responseMap.put("responseTime", this.responseTime);
    try {
        return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(responseMap);
    } catch (JsonProcessingException e) {
        throw new RuntimeException(e);
    }
}
/**
 * Parses the response text as a JSON object and returns its members as a map.
 * <p>
 * Only object-shaped payloads ({@code { ... }}) are accepted. A JSON array
 * cannot be represented as a {@code Map}, so it is rejected up front with the
 * same "not a valid JSON object" error instead of surfacing as a parser failure.
 *
 * @return the top-level members of the JSON object, keyed by name
 * @throws IllegalArgumentException if the response is null or blank, is not
 *                                  delimited as a JSON object, or cannot be
 *                                  parsed
 */
public Map<String, Object> getStructuredResponse() {
    String responseStr = this.getResponse();
    if (responseStr == null) {
        throw new IllegalArgumentException("Response is empty or null");
    }
    // Trim once; the delimiter checks below ignore surrounding whitespace.
    String trimmed = responseStr.trim();
    if (trimmed.isEmpty()) {
        throw new IllegalArgumentException("Response is empty or null");
    }
    if (!trimmed.startsWith("{") || !trimmed.endsWith("}")) {
        throw new IllegalArgumentException("Response is not a valid JSON object");
    }
    try {
        return getObjectMapper().readValue(responseStr, new TypeReference<Map<String, Object>>() {
        });
    } catch (JsonProcessingException e) {
        throw new IllegalArgumentException("Failed to parse response as JSON: " + e.getMessage(), e);
    }
}
/**
 * Deserialises the response text into an instance of the requested type.
 * <p>
 * The text must look like a JSON document: after trimming it has to start with
 * <code>{</code> or {@code [} and end with <code>}</code> or {@code ]}.
 *
 * @param <T>   the target type
 * @param clazz the class to bind the JSON response to
 * @return the response bound to an instance of {@code clazz}
 * @throws IllegalArgumentException if the response is null or blank, does not
 *                                  look like JSON, or cannot be parsed or
 *                                  mapped to {@code clazz}
 */
public <T> T as(Class<T> clazz) {
    final String raw = this.getResponse();
    if (raw == null) {
        throw new IllegalArgumentException("Response is empty or null");
    }
    final String trimmed = raw.trim();
    if (trimmed.isEmpty()) {
        throw new IllegalArgumentException("Response is empty or null");
    }

    // Cheap shape check before handing the text to the parser.
    final boolean opensLikeJson = trimmed.startsWith("{") || trimmed.startsWith("[");
    final boolean closesLikeJson = trimmed.endsWith("}") || trimmed.endsWith("]");
    if (!(opensLikeJson && closesLikeJson)) {
        throw new IllegalArgumentException("Response is not a valid JSON object");
    }

    try {
        return getObjectMapper().readValue(raw, clazz);
    } catch (JsonProcessingException parseFailure) {
        throw new IllegalArgumentException(
                "Failed to parse response as JSON: " + parseFailure.getMessage(), parseFailure);
    }
}
}

View File

@ -0,0 +1,77 @@
package io.github.ollama4j.models.response;
import static io.github.ollama4j.utils.Utils.getObjectMapper;
import java.util.Map;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import lombok.Data;
import lombok.Getter;
import lombok.NoArgsConstructor;
/**
 * Holder for a structured-output generate response.
 * <p>
 * The {@code response} field carries the model output as text; the
 * {@code getStructuredResponse} methods parse that text as JSON on demand.
 */
@Getter
@SuppressWarnings("unused")
@Data
@NoArgsConstructor
@JsonIgnoreProperties(ignoreUnknown = true)
public class OllamaStructuredResult {
    /** Response text; expected to contain JSON when structured output was requested. */
    private String response;

    /** HTTP status code associated with this result. */
    private int httpStatusCode;

    /** Response time; defaults to 0 when not supplied. */
    private long responseTime = 0;

    /** Model identifier, when supplied. */
    private String model;

    /**
     * Creates a result from its core values; {@code model} is left unset.
     *
     * @param response       the response text
     * @param responseTime   the response time
     * @param httpStatusCode the HTTP status code
     */
    public OllamaStructuredResult(String response, long responseTime, int httpStatusCode) {
        this.response = response;
        this.responseTime = responseTime;
        this.httpStatusCode = httpStatusCode;
    }

    /**
     * Renders the stored fields as pretty-printed JSON.
     *
     * @return the pretty-printed JSON representation
     * @throws RuntimeException if the values cannot be serialised to JSON
     */
    @Override
    public String toString() {
        // Serialise an explicit snapshot instead of `this`: bean serialisation would
        // pick up the public getStructuredResponse() getter and invoke it, which
        // throws whenever the response text is not a JSON object.
        Map<String, Object> fields = new java.util.LinkedHashMap<>();
        fields.put("response", this.response);
        fields.put("httpStatusCode", this.httpStatusCode);
        fields.put("responseTime", this.responseTime);
        fields.put("model", this.model);
        try {
            return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(fields);
        } catch (JsonProcessingException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Parses the response text as a JSON object.
     *
     * @return the top-level members of the JSON object, keyed by name
     * @throws IllegalArgumentException if the response is null or blank, or
     *                                  cannot be parsed as a JSON object
     */
    public Map<String, Object> getStructuredResponse() {
        String json = requireNonBlankResponse();
        try {
            return getObjectMapper().readValue(json, new TypeReference<Map<String, Object>>() {
            });
        } catch (JsonProcessingException e) {
            throw new IllegalArgumentException("Failed to parse response as JSON: " + e.getMessage(), e);
        }
    }

    /**
     * Parses the response text as JSON and binds it to the given type.
     *
     * @param <T>   the target type
     * @param clazz the class to bind the JSON response to
     * @return the response bound to an instance of {@code clazz}
     * @throws IllegalArgumentException if the response is null or blank, or
     *                                  cannot be parsed or mapped to
     *                                  {@code clazz}
     */
    public <T> T getStructuredResponse(Class<T> clazz) {
        String json = requireNonBlankResponse();
        try {
            return getObjectMapper().readValue(json, clazz);
        } catch (JsonProcessingException e) {
            throw new IllegalArgumentException("Failed to parse response as JSON: " + e.getMessage(), e);
        }
    }

    /** Returns the response text, rejecting null or blank values with a descriptive error. */
    private String requireNonBlankResponse() {
        String text = this.response;
        if (text == null || text.trim().isEmpty()) {
            throw new IllegalArgumentException("Response is empty or null");
        }
        return text;
    }
}