fix: handle ollama error responses

fixes: #138

- Added error field to ModelPullResponse
  - Enhanced error handling in doPullModel to check for errors in response body and throw OllamaBaseException with the specific error message
This commit is contained in:
Travis Lyons 2025-08-20 11:51:49 -04:00
parent 1e17e258b6
commit bae903f8ca
No known key found for this signature in database
16 changed files with 298 additions and 45 deletions

View File

@ -2,7 +2,7 @@ repos:
# pre-commit hooks # pre-commit hooks
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
rev: "v5.0.0" rev: "v6.0.0"
hooks: hooks:
- id: no-commit-to-branch - id: no-commit-to-branch
args: ['--branch', 'main'] args: ['--branch', 'main']
@ -21,7 +21,7 @@ repos:
# for commit message formatting # for commit message formatting
- repo: https://github.com/commitizen-tools/commitizen - repo: https://github.com/commitizen-tools/commitizen
rev: v4.4.1 rev: v4.8.3
hooks: hooks:
- id: commitizen - id: commitizen
stages: [commit-msg] stages: [commit-msg]

View File

@ -1,8 +1,8 @@
{ {
"label": "APIs - Extras", "label": "APIs - Extras",
"position": 4, "position": 4,
"link": { "link": {
"type": "generated-index", "type": "generated-index",
"description": "Details of APIs to handle bunch of extra stuff." "description": "Details of APIs to handle bunch of extra stuff."
} }
} }

View File

@ -13,9 +13,9 @@ public class Main {
public static void main(String[] args) { public static void main(String[] args) {
String host = "http://localhost:11434/"; String host = "http://localhost:11434/";
OllamaAPI ollamaAPI = new OllamaAPI(host); OllamaAPI ollamaAPI = new OllamaAPI(host);
ollamaAPI.ping(); ollamaAPI.ping();
} }
} }

View File

@ -1,8 +1,8 @@
{ {
"label": "APIs - Generate", "label": "APIs - Generate",
"position": 3, "position": 3,
"link": { "link": {
"type": "generated-index", "type": "generated-index",
"description": "Details of APIs to interact with LLMs." "description": "Details of APIs to interact with LLMs."
} }
} }

View File

@ -1,8 +1,8 @@
{ {
"label": "APIs - Model Management", "label": "APIs - Model Management",
"position": 2, "position": 2,
"link": { "link": {
"type": "generated-index", "type": "generated-index",
"description": "Details of APIs to manage LLMs." "description": "Details of APIs to manage LLMs."
} }
} }

View File

@ -420,16 +420,23 @@ public class OllamaAPI {
String line; String line;
while ((line = reader.readLine()) != null) { while ((line = reader.readLine()) != null) {
ModelPullResponse modelPullResponse = Utils.getObjectMapper().readValue(line, ModelPullResponse.class); ModelPullResponse modelPullResponse = Utils.getObjectMapper().readValue(line, ModelPullResponse.class);
if (modelPullResponse != null && modelPullResponse.getStatus() != null) { if (modelPullResponse != null) {
if (verbose) { // Check for error in response body first
logger.info(modelName + ": " + modelPullResponse.getStatus()); if (modelPullResponse.getError() != null && !modelPullResponse.getError().trim().isEmpty()) {
throw new OllamaBaseException("Model pull failed: " + modelPullResponse.getError());
} }
// Check if status is "success" and set success flag to true.
if ("success".equalsIgnoreCase(modelPullResponse.getStatus())) { if (modelPullResponse.getStatus() != null) {
success = true; if (verbose) {
logger.info(modelName + ": " + modelPullResponse.getStatus());
}
// Check if status is "success" and set success flag to true.
if ("success".equalsIgnoreCase(modelPullResponse.getStatus())) {
success = true;
}
} }
} else { } else {
logger.error("Received null or invalid status for model pull."); logger.error("Received null response for model pull.");
} }
} }
} }
@ -755,7 +762,7 @@ public class OllamaAPI {
* @throws InterruptedException if the operation is interrupted * @throws InterruptedException if the operation is interrupted
*/ */
public OllamaResult generate(String model, String prompt, boolean raw, Options options, public OllamaResult generate(String model, String prompt, boolean raw, Options options,
OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt); OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt);
ollamaRequestModel.setRaw(raw); ollamaRequestModel.setRaw(raw);
ollamaRequestModel.setOptions(options.getOptionsMap()); ollamaRequestModel.setOptions(options.getOptionsMap());
@ -938,7 +945,7 @@ public class OllamaAPI {
* @throws InterruptedException if the operation is interrupted * @throws InterruptedException if the operation is interrupted
*/ */
public OllamaResult generateWithImageFiles(String model, String prompt, List<File> imageFiles, Options options, public OllamaResult generateWithImageFiles(String model, String prompt, List<File> imageFiles, Options options,
OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
List<String> images = new ArrayList<>(); List<String> images = new ArrayList<>();
for (File imageFile : imageFiles) { for (File imageFile : imageFiles) {
images.add(encodeFileToBase64(imageFile)); images.add(encodeFileToBase64(imageFile));
@ -985,7 +992,7 @@ public class OllamaAPI {
* @throws URISyntaxException if the URI for the request is malformed * @throws URISyntaxException if the URI for the request is malformed
*/ */
public OllamaResult generateWithImageURLs(String model, String prompt, List<String> imageURLs, Options options, public OllamaResult generateWithImageURLs(String model, String prompt, List<String> imageURLs, Options options,
OllamaStreamHandler streamHandler) OllamaStreamHandler streamHandler)
throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { throws OllamaBaseException, IOException, InterruptedException, URISyntaxException {
List<String> images = new ArrayList<>(); List<String> images = new ArrayList<>();
for (String imageURL : imageURLs) { for (String imageURL : imageURLs) {
@ -1247,7 +1254,7 @@ public class OllamaAPI {
registerAnnotatedTools(provider.getDeclaredConstructor().newInstance()); registerAnnotatedTools(provider.getDeclaredConstructor().newInstance());
} }
} catch (InstantiationException | NoSuchMethodException | IllegalAccessException } catch (InstantiationException | NoSuchMethodException | IllegalAccessException
| InvocationTargetException e) { | InvocationTargetException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
} }
@ -1384,7 +1391,7 @@ public class OllamaAPI {
* @throws InterruptedException if the thread is interrupted during the request. * @throws InterruptedException if the thread is interrupted during the request.
*/ */
private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequest ollamaRequestModel, private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequest ollamaRequestModel,
OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException {
OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, auth, requestTimeoutSeconds, OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, auth, requestTimeoutSeconds,
verbose); verbose);
OllamaResult result; OllamaResult result;

View File

@ -26,7 +26,7 @@ public class OllamaGenerateRequestBuilder {
request.setPrompt(prompt); request.setPrompt(prompt);
return this; return this;
} }
public OllamaGenerateRequestBuilder withGetJsonResponse(){ public OllamaGenerateRequestBuilder withGetJsonResponse(){
this.request.setReturnFormatJson(true); this.request.setReturnFormatJson(true);
return this; return this;

View File

@ -13,8 +13,8 @@ import lombok.Data;
@Data @Data
@JsonInclude(JsonInclude.Include.NON_NULL) @JsonInclude(JsonInclude.Include.NON_NULL)
public abstract class OllamaCommonRequest { public abstract class OllamaCommonRequest {
protected String model; protected String model;
@JsonSerialize(using = BooleanToJsonFormatFlagSerializer.class) @JsonSerialize(using = BooleanToJsonFormatFlagSerializer.class)
@JsonProperty(value = "format") @JsonProperty(value = "format")
protected Boolean returnFormatJson; protected Boolean returnFormatJson;
@ -24,7 +24,7 @@ public abstract class OllamaCommonRequest {
@JsonProperty(value = "keep_alive") @JsonProperty(value = "keep_alive")
protected String keepAlive; protected String keepAlive;
public String toString() { public String toString() {
try { try {
return Utils.getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this); return Utils.getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);

View File

@ -10,4 +10,5 @@ public class ModelPullResponse {
private String digest; private String digest;
private long total; private long total;
private long completed; private long completed;
private String error;
} }

View File

@ -120,4 +120,3 @@ public class OllamaAsyncResultStreamer extends Thread {
} }
} }

View File

@ -15,4 +15,4 @@ public class OllamaToolCallsFunction
{ {
private String name; private String name;
private Map<String,Object> arguments; private Map<String,Object> arguments;
} }

View File

@ -13,4 +13,3 @@ public class ToolFunctionCallSpec {
private String name; private String name;
private Map<String, Object> arguments; private Map<String, Object> arguments;
} }

View File

@ -18,4 +18,4 @@ public class FileToBase64Serializer extends JsonSerializer<Collection<byte[]>> {
} }
jsonGenerator.writeEndArray(); jsonGenerator.writeEndArray();
} }
} }

View File

@ -10,10 +10,10 @@ import com.fasterxml.jackson.core.JsonProcessingException;
* Interface to represent a OllamaRequest as HTTP-Request Body via {@link BodyPublishers}. * Interface to represent a OllamaRequest as HTTP-Request Body via {@link BodyPublishers}.
*/ */
public interface OllamaRequestBody { public interface OllamaRequestBody {
/** /**
* Transforms the OllamaRequest Object to a JSON Object via Jackson. * Transforms the OllamaRequest Object to a JSON Object via Jackson.
* *
* @return JSON representation of a OllamaRequest * @return JSON representation of a OllamaRequest
*/ */
@JsonIgnore @JsonIgnore

View File

@ -0,0 +1,247 @@
package io.github.ollama4j.unittests.jackson;
import io.github.ollama4j.models.response.ModelPullResponse;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
/**
* Test serialization and deserialization of ModelPullResponse,
* This test verifies that the ModelPullResponse class can properly parse
* error responses from Ollama server that return HTTP 200 with error messages
* in the JSON body.
*/
public class TestModelPullResponseSerialization extends AbstractSerializationTest<ModelPullResponse> {
/**
* Test the specific error case reported in GitHub issue #138.
* Ollama sometimes returns HTTP 200 with error details in JSON.
*/
@Test
public void testDeserializationWithErrorFromGitHubIssue138() {
// This is the exact error JSON from GitHub issue #138
String errorJson = "{\"error\":\"pull model manifest: 412: \\n\\nThe model you are attempting to pull requires a newer version of Ollama.\\n\\nPlease download the latest version at:\\n\\n\\thttps://ollama.com/download\\n\\n\"}";
ModelPullResponse response = deserialize(errorJson, ModelPullResponse.class);
assertNotNull(response);
assertNotNull(response.getError());
assertTrue(response.getError().contains("newer version of Ollama"));
assertTrue(response.getError().contains("https://ollama.com/download"));
assertNull(response.getStatus());
assertNull(response.getDigest());
assertEquals(0, response.getTotal());
assertEquals(0, response.getCompleted());
}
/**
* Test deserialization of ModelPullResponse with only status field present.
* Verifies that the response can handle minimal JSON with just status information.
*/
@Test
public void testDeserializationWithStatusField() {
String statusJson = "{\"status\":\"pulling manifest\"}";
ModelPullResponse response = deserialize(statusJson, ModelPullResponse.class);
assertNotNull(response);
assertEquals("pulling manifest", response.getStatus());
assertNull(response.getError());
assertNull(response.getDigest());
assertEquals(0, response.getTotal());
assertEquals(0, response.getCompleted());
}
/**
* Test deserialization of ModelPullResponse with progress tracking fields.
* Verifies that status, digest, total, and completed fields are properly parsed
* when downloading/pulling model data.
*/
@Test
public void testDeserializationWithProgressFields() {
String progressJson = "{\"status\":\"pulling digestname\",\"digest\":\"sha256:abc123\",\"total\":2142590208,\"completed\":241970}";
ModelPullResponse response = deserialize(progressJson, ModelPullResponse.class);
assertNotNull(response);
assertEquals("pulling digestname", response.getStatus());
assertEquals("sha256:abc123", response.getDigest());
assertEquals(2142590208L, response.getTotal());
assertEquals(241970L, response.getCompleted());
assertNull(response.getError());
}
/**
* Test deserialization of ModelPullResponse with success status.
* Verifies that successful completion responses are properly handled.
*/
@Test
public void testDeserializationWithSuccessStatus() {
String successJson = "{\"status\":\"success\"}";
ModelPullResponse response = deserialize(successJson, ModelPullResponse.class);
assertNotNull(response);
assertEquals("success", response.getStatus());
assertNull(response.getError());
assertNull(response.getDigest());
assertEquals(0, response.getTotal());
assertEquals(0, response.getCompleted());
}
/**
* Test deserialization of ModelPullResponse with all possible fields populated.
* Verifies that complete JSON responses with all fields are handled correctly.
*/
@Test
public void testDeserializationWithAllFields() {
String completeJson = "{\"status\":\"downloading\",\"digest\":\"sha256:def456\",\"total\":1000000,\"completed\":500000,\"error\":null}";
ModelPullResponse response = deserialize(completeJson, ModelPullResponse.class);
assertNotNull(response);
assertEquals("downloading", response.getStatus());
assertEquals("sha256:def456", response.getDigest());
assertEquals(1000000L, response.getTotal());
assertEquals(500000L, response.getCompleted());
assertNull(response.getError());
}
/**
* Test deserialization of ModelPullResponse with unknown JSON fields.
* Verifies that unknown fields are ignored due to @JsonIgnoreProperties(ignoreUnknown = true)
* annotation without causing deserialization errors.
*/
@Test
public void testDeserializationWithUnknownFields() {
// Test that unknown fields are ignored due to @JsonIgnoreProperties(ignoreUnknown = true)
String jsonWithUnknownFields = "{\"status\":\"pulling\",\"unknown_field\":\"should_be_ignored\",\"error\":\"test error\",\"another_unknown\":123,\"nested_unknown\":{\"key\":\"value\"}}";
ModelPullResponse response = deserialize(jsonWithUnknownFields, ModelPullResponse.class);
assertNotNull(response);
assertEquals("pulling", response.getStatus());
assertEquals("test error", response.getError());
assertNull(response.getDigest());
assertEquals(0, response.getTotal());
assertEquals(0, response.getCompleted());
// Unknown fields should be ignored
}
/**
* Test deserialization of ModelPullResponse with empty string error field.
* Verifies that empty error strings are preserved as empty strings, not converted to null.
*/
@Test
public void testEmptyErrorFieldIsNull() {
String emptyErrorJson = "{\"error\":\"\",\"status\":\"pulling manifest\"}";
ModelPullResponse response = deserialize(emptyErrorJson, ModelPullResponse.class);
assertNotNull(response);
assertEquals("pulling manifest", response.getStatus());
assertEquals("", response.getError()); // Empty string, not null
}
/**
* Test deserialization of ModelPullResponse with whitespace-only error field.
* Verifies that whitespace characters in error fields are preserved during JSON parsing.
*/
@Test
public void testWhitespaceOnlyErrorField() {
String whitespaceErrorJson = "{\"error\":\" \\n\\t \",\"status\":\"pulling manifest\"}";
ModelPullResponse response = deserialize(whitespaceErrorJson, ModelPullResponse.class);
assertNotNull(response);
assertEquals("pulling manifest", response.getStatus());
assertEquals(" \n\t ", response.getError()); // Whitespace preserved in JSON parsing
assertTrue(response.getError().trim().isEmpty()); // But trimmed version is empty
}
/**
* Test serialization of ModelPullResponse with error field to ensure round-trip compatibility.
* Verifies that objects can be properly serialized to JSON format.
*/
@Test
public void testSerializationWithErrorField() {
ModelPullResponse response = new ModelPullResponse();
response.setError("Test error message");
response.setStatus("failed");
String jsonString = serialize(response);
assertNotNull(jsonString);
assertTrue(jsonString.contains("\"error\":\"Test error message\""));
assertTrue(jsonString.contains("\"status\":\"failed\""));
}
/**
* Test round-trip serialization and deserialization of ModelPullResponse with error data.
* Verifies that objects maintain integrity through serialize -> deserialize cycle.
*/
@Test
public void testRoundTripSerializationWithError() {
ModelPullResponse original = new ModelPullResponse();
original.setError("Round trip test error");
original.setStatus("error");
String json = serialize(original);
ModelPullResponse deserialized = deserialize(json, ModelPullResponse.class);
assertEqualsAfterUnmarshalling(deserialized, original);
assertEquals("Round trip test error", deserialized.getError());
assertEquals("error", deserialized.getStatus());
}
/**
* Test round-trip serialization and deserialization of ModelPullResponse with progress data.
* Verifies that progress tracking information is preserved through serialize -> deserialize cycle.
*/
@Test
public void testRoundTripSerializationWithProgress() {
ModelPullResponse original = new ModelPullResponse();
original.setStatus("downloading");
original.setDigest("sha256:roundtrip");
original.setTotal(2000000L);
original.setCompleted(1500000L);
String json = serialize(original);
ModelPullResponse deserialized = deserialize(json, ModelPullResponse.class);
assertEqualsAfterUnmarshalling(deserialized, original);
assertEquals("downloading", deserialized.getStatus());
assertEquals("sha256:roundtrip", deserialized.getDigest());
assertEquals(2000000L, deserialized.getTotal());
assertEquals(1500000L, deserialized.getCompleted());
assertNull(deserialized.getError());
}
/**
* Test that verifies the error handling logic that would be used in doPullModel method.
* This simulates the actual error detection logic.
*/
@Test
public void testErrorHandlingLogic() {
// Error case - should trigger error handling
String errorJson = "{\"error\":\"test error\"}";
ModelPullResponse errorResponse = deserialize(errorJson, ModelPullResponse.class);
assertTrue(errorResponse.getError() != null && !errorResponse.getError().trim().isEmpty(),
"Error response should trigger error handling logic");
// Normal case - should not trigger error handling
String normalJson = "{\"status\":\"pulling\"}";
ModelPullResponse normalResponse = deserialize(normalJson, ModelPullResponse.class);
assertFalse(normalResponse.getError() != null && !normalResponse.getError().trim().isEmpty(),
"Normal response should not trigger error handling logic");
// Empty error case - should not trigger error handling
String emptyErrorJson = "{\"error\":\"\",\"status\":\"pulling\"}";
ModelPullResponse emptyErrorResponse = deserialize(emptyErrorJson, ModelPullResponse.class);
assertFalse(emptyErrorResponse.getError() != null && !emptyErrorResponse.getError().trim().isEmpty(),
"Empty error response should not trigger error handling logic");
}
}

View File

@ -12,4 +12,4 @@
<logger name="org.apache" level="WARN"/> <logger name="org.apache" level="WARN"/>
<logger name="httpclient" level="WARN"/> <logger name="httpclient" level="WARN"/>
</configuration> </configuration>