mirror of
https://github.com/amithkoujalgi/ollama4j.git
synced 2025-05-15 03:47:13 +02:00
Enhance OllamaAPI and OllamaResult for improved model pulling and structured responses
- Added a retry mechanism in OllamaAPI for model pulling, allowing configurable retries. - Introduced new methods in OllamaResult for structured response handling, including parsing JSON responses into a Map or specific class types. - Updated integration tests to validate the new functionality and ensure robust testing of model interactions. - Improved code formatting and consistency across the OllamaAPI and integration test classes.
This commit is contained in:
parent
1bda78e35b
commit
bc2a931586
@ -51,7 +51,7 @@ import java.util.stream.Collectors;
|
||||
/**
|
||||
* The base Ollama API class.
|
||||
*/
|
||||
@SuppressWarnings({"DuplicatedCode", "resource"})
|
||||
@SuppressWarnings({ "DuplicatedCode", "resource" })
|
||||
public class OllamaAPI {
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(OllamaAPI.class);
|
||||
@ -74,6 +74,12 @@ public class OllamaAPI {
|
||||
|
||||
private Auth auth;
|
||||
|
||||
private int numberOfRetriesForModelPull = 0;
|
||||
|
||||
public void setNumberOfRetriesForModelPull(int numberOfRetriesForModelPull) {
|
||||
this.numberOfRetriesForModelPull = numberOfRetriesForModelPull;
|
||||
}
|
||||
|
||||
private final ToolRegistry toolRegistry = new ToolRegistry();
|
||||
|
||||
/**
|
||||
@ -209,7 +215,7 @@ public class OllamaAPI {
|
||||
* tags, tag count, and the time when model was updated.
|
||||
*
|
||||
* @return A list of {@link LibraryModel} objects representing the models
|
||||
* available in the Ollama library.
|
||||
* available in the Ollama library.
|
||||
* @throws OllamaBaseException If the HTTP request fails or the response is not
|
||||
* successful (non-200 status code).
|
||||
* @throws IOException If an I/O error occurs during the HTTP request
|
||||
@ -275,7 +281,7 @@ public class OllamaAPI {
|
||||
* of the library model
|
||||
* for which the tags need to be fetched.
|
||||
* @return a list of {@link LibraryModelTag} objects containing the extracted
|
||||
* tags and their associated metadata.
|
||||
* tags and their associated metadata.
|
||||
* @throws OllamaBaseException if the HTTP response status code indicates an
|
||||
* error (i.e., not 200 OK),
|
||||
* or if there is any other issue during the
|
||||
@ -342,7 +348,7 @@ public class OllamaAPI {
|
||||
* @param modelName The name of the model to search for in the library.
|
||||
* @param tag The tag name to search for within the specified model.
|
||||
* @return The {@link LibraryModelTag} associated with the specified model and
|
||||
* tag.
|
||||
* tag.
|
||||
* @throws OllamaBaseException If there is a problem with the Ollama library
|
||||
* operations.
|
||||
* @throws IOException If an I/O error occurs during the operation.
|
||||
@ -376,6 +382,26 @@ public class OllamaAPI {
|
||||
*/
|
||||
public void pullModel(String modelName)
|
||||
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
|
||||
if (numberOfRetriesForModelPull == 0) {
|
||||
this.doPullModel(modelName);
|
||||
} else {
|
||||
int numberOfRetries = 0;
|
||||
while (numberOfRetries < numberOfRetriesForModelPull) {
|
||||
try {
|
||||
this.doPullModel(modelName);
|
||||
return;
|
||||
} catch (OllamaBaseException e) {
|
||||
logger.error("Failed to pull model " + modelName + ", retrying...");
|
||||
numberOfRetries++;
|
||||
}
|
||||
}
|
||||
throw new OllamaBaseException(
|
||||
"Failed to pull model " + modelName + " after " + numberOfRetriesForModelPull + " retries");
|
||||
}
|
||||
}
|
||||
|
||||
private void doPullModel(String modelName)
|
||||
throws OllamaBaseException, IOException, URISyntaxException, InterruptedException {
|
||||
String url = this.host + "/api/pull";
|
||||
String jsonData = new ModelRequest(modelName).toString();
|
||||
HttpRequest request = getRequestBuilderDefault(new URI(url))
|
||||
@ -729,7 +755,7 @@ public class OllamaAPI {
|
||||
* @throws InterruptedException if the operation is interrupted
|
||||
*/
|
||||
public OllamaResult generate(String model, String prompt, boolean raw, Options options,
|
||||
OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
|
||||
OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
|
||||
OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt);
|
||||
ollamaRequestModel.setRaw(raw);
|
||||
ollamaRequestModel.setOptions(options.getOptionsMap());
|
||||
@ -742,8 +768,10 @@ public class OllamaAPI {
|
||||
* @param model The name or identifier of the AI model to use for generating
|
||||
* the response.
|
||||
* @param prompt The input text or prompt to provide to the AI model.
|
||||
* @param format A map containing the format specification for the structured output.
|
||||
* @return An instance of {@link OllamaResult} containing the structured response.
|
||||
* @param format A map containing the format specification for the structured
|
||||
* output.
|
||||
* @return An instance of {@link OllamaResult} containing the structured
|
||||
* response.
|
||||
* @throws OllamaBaseException if the response indicates an error status.
|
||||
* @throws IOException if an I/O error occurs during the HTTP request.
|
||||
* @throws InterruptedException if the operation is interrupted.
|
||||
@ -771,7 +799,11 @@ public class OllamaAPI {
|
||||
String responseBody = response.body();
|
||||
|
||||
if (statusCode == 200) {
|
||||
return Utils.getObjectMapper().readValue(responseBody, OllamaResult.class);
|
||||
OllamaStructuredResult structuredResult = Utils.getObjectMapper().readValue(responseBody,
|
||||
OllamaStructuredResult.class);
|
||||
OllamaResult ollamaResult = new OllamaResult(structuredResult.getResponse(),
|
||||
structuredResult.getResponseTime(), statusCode);
|
||||
return ollamaResult;
|
||||
} else {
|
||||
throw new OllamaBaseException(statusCode + " - " + responseBody);
|
||||
}
|
||||
@ -813,8 +845,8 @@ public class OllamaAPI {
|
||||
* @param options Additional options or configurations to use when generating
|
||||
* the response.
|
||||
* @return {@link OllamaToolsResult} An OllamaToolsResult object containing the
|
||||
* response from the AI model and the results of invoking the tools on
|
||||
* that output.
|
||||
* response from the AI model and the results of invoking the tools on
|
||||
* that output.
|
||||
* @throws OllamaBaseException if the response indicates an error status
|
||||
* @throws IOException if an I/O error occurs during the HTTP request
|
||||
* @throws InterruptedException if the operation is interrupted
|
||||
@ -906,7 +938,7 @@ public class OllamaAPI {
|
||||
* @throws InterruptedException if the operation is interrupted
|
||||
*/
|
||||
public OllamaResult generateWithImageFiles(String model, String prompt, List<File> imageFiles, Options options,
|
||||
OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
|
||||
OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
|
||||
List<String> images = new ArrayList<>();
|
||||
for (File imageFile : imageFiles) {
|
||||
images.add(encodeFileToBase64(imageFile));
|
||||
@ -953,7 +985,7 @@ public class OllamaAPI {
|
||||
* @throws URISyntaxException if the URI for the request is malformed
|
||||
*/
|
||||
public OllamaResult generateWithImageURLs(String model, String prompt, List<String> imageURLs, Options options,
|
||||
OllamaStreamHandler streamHandler)
|
||||
OllamaStreamHandler streamHandler)
|
||||
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
|
||||
List<String> images = new ArrayList<>();
|
||||
for (String imageURL : imageURLs) {
|
||||
@ -988,7 +1020,7 @@ public class OllamaAPI {
|
||||
* @param model the ollama model to ask the question to
|
||||
* @param messages chat history / message stack to send to the model
|
||||
* @return {@link OllamaChatResult} containing the api response and the message
|
||||
* history including the newly aqcuired assistant response.
|
||||
* history including the newly aqcuired assistant response.
|
||||
* @throws OllamaBaseException any response code than 200 has been returned
|
||||
* @throws IOException in case the responseStream can not be read
|
||||
* @throws InterruptedException in case the server is not reachable or network
|
||||
@ -1171,7 +1203,7 @@ public class OllamaAPI {
|
||||
registerAnnotatedTools(provider.getDeclaredConstructor().newInstance());
|
||||
}
|
||||
} catch (InstantiationException | NoSuchMethodException | IllegalAccessException
|
||||
| InvocationTargetException e) {
|
||||
| InvocationTargetException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
@ -1308,7 +1340,7 @@ public class OllamaAPI {
|
||||
* @throws InterruptedException if the thread is interrupted during the request.
|
||||
*/
|
||||
private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequest ollamaRequestModel,
|
||||
OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
|
||||
OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
|
||||
OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, auth, requestTimeoutSeconds,
|
||||
verbose);
|
||||
OllamaResult result;
|
||||
|
@ -1,19 +1,26 @@
|
||||
package io.github.ollama4j.models.response;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.core.type.TypeReference;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.Getter;
|
||||
|
||||
import static io.github.ollama4j.utils.Utils.getObjectMapper;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
/** The type Ollama result. */
|
||||
@Getter
|
||||
@SuppressWarnings("unused")
|
||||
@Data
|
||||
@JsonIgnoreProperties(ignoreUnknown = true)
|
||||
public class OllamaResult {
|
||||
/**
|
||||
* -- GETTER --
|
||||
* Get the completion/response text
|
||||
* Get the completion/response text
|
||||
*
|
||||
* @return String completion/response text
|
||||
*/
|
||||
@ -21,7 +28,7 @@ public class OllamaResult {
|
||||
|
||||
/**
|
||||
* -- GETTER --
|
||||
* Get the response status code.
|
||||
* Get the response status code.
|
||||
*
|
||||
* @return int - response status code
|
||||
*/
|
||||
@ -29,7 +36,7 @@ public class OllamaResult {
|
||||
|
||||
/**
|
||||
* -- GETTER --
|
||||
* Get the response time in milliseconds.
|
||||
* Get the response time in milliseconds.
|
||||
*
|
||||
* @return long - response time in milliseconds
|
||||
*/
|
||||
@ -44,9 +51,68 @@ public class OllamaResult {
|
||||
@Override
|
||||
public String toString() {
|
||||
try {
|
||||
return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
|
||||
Map<String, Object> responseMap = new HashMap<>();
|
||||
responseMap.put("response", this.response);
|
||||
responseMap.put("httpStatusCode", this.httpStatusCode);
|
||||
responseMap.put("responseTime", this.responseTime);
|
||||
return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(responseMap);
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the structured response if the response is a JSON object.
|
||||
*
|
||||
* @return Map - structured response
|
||||
* @throws IllegalArgumentException if the response is not a valid JSON object
|
||||
*/
|
||||
public Map<String, Object> getStructuredResponse() {
|
||||
String responseStr = this.getResponse();
|
||||
if (responseStr == null || responseStr.trim().isEmpty()) {
|
||||
throw new IllegalArgumentException("Response is empty or null");
|
||||
}
|
||||
|
||||
try {
|
||||
// Check if the response is a valid JSON
|
||||
if ((!responseStr.trim().startsWith("{") && !responseStr.trim().startsWith("[")) ||
|
||||
(!responseStr.trim().endsWith("}") && !responseStr.trim().endsWith("]"))) {
|
||||
throw new IllegalArgumentException("Response is not a valid JSON object");
|
||||
}
|
||||
|
||||
Map<String, Object> response = getObjectMapper().readValue(responseStr,
|
||||
new TypeReference<Map<String, Object>>() {
|
||||
});
|
||||
return response;
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new IllegalArgumentException("Failed to parse response as JSON: " + e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the structured response mapped to a specific class type.
|
||||
*
|
||||
* @param <T> The type of class to map the response to
|
||||
* @param clazz The class to map the response to
|
||||
* @return An instance of the specified class with the response data
|
||||
* @throws IllegalArgumentException if the response is not a valid JSON or is empty
|
||||
* @throws RuntimeException if there is an error mapping the response
|
||||
*/
|
||||
public <T> T as(Class<T> clazz) {
|
||||
String responseStr = this.getResponse();
|
||||
if (responseStr == null || responseStr.trim().isEmpty()) {
|
||||
throw new IllegalArgumentException("Response is empty or null");
|
||||
}
|
||||
|
||||
try {
|
||||
// Check if the response is a valid JSON
|
||||
if ((!responseStr.trim().startsWith("{") && !responseStr.trim().startsWith("[")) ||
|
||||
(!responseStr.trim().endsWith("}") && !responseStr.trim().endsWith("]"))) {
|
||||
throw new IllegalArgumentException("Response is not a valid JSON object");
|
||||
}
|
||||
return getObjectMapper().readValue(responseStr, clazz);
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new IllegalArgumentException("Failed to parse response as JSON: " + e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,77 @@
|
||||
package io.github.ollama4j.models.response;
|
||||
|
||||
import static io.github.ollama4j.utils.Utils.getObjectMapper;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonCreator;
|
||||
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.core.type.TypeReference;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.Getter;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
@Getter
|
||||
@SuppressWarnings("unused")
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@JsonIgnoreProperties(ignoreUnknown = true)
|
||||
public class OllamaStructuredResult {
|
||||
private String response;
|
||||
|
||||
private int httpStatusCode;
|
||||
|
||||
private long responseTime = 0;
|
||||
|
||||
private String model;
|
||||
|
||||
public OllamaStructuredResult(String response, long responseTime, int httpStatusCode) {
|
||||
this.response = response;
|
||||
this.responseTime = responseTime;
|
||||
this.httpStatusCode = httpStatusCode;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
try {
|
||||
return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the structured response if the response is a JSON object.
|
||||
*
|
||||
* @return Map - structured response
|
||||
*/
|
||||
public Map<String, Object> getStructuredResponse() {
|
||||
try {
|
||||
Map<String, Object> response = getObjectMapper().readValue(this.getResponse(),
|
||||
new TypeReference<Map<String, Object>>() {
|
||||
});
|
||||
return response;
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the structured response mapped to a specific class type.
|
||||
*
|
||||
* @param <T> The type of class to map the response to
|
||||
* @param clazz The class to map the response to
|
||||
* @return An instance of the specified class with the response data
|
||||
* @throws RuntimeException if there is an error mapping the response
|
||||
*/
|
||||
public <T> T getStructuredResponse(Class<T> clazz) {
|
||||
try {
|
||||
return getObjectMapper().readValue(this.getResponse(), clazz);
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user