mirror of
				https://github.com/amithkoujalgi/ollama4j.git
				synced 2025-10-31 16:40:41 +01:00 
			
		
		
		
	fix: handle ollama error responses
fixes: #138 - Added error field to ModelPullResponse - Enhanced error handling in doPullModel to check for errors in response body and throw OllamaBaseException with the specific error message
This commit is contained in:
		| @@ -2,7 +2,7 @@ repos: | ||||
|  | ||||
|   # pre-commit hooks | ||||
|   - repo: https://github.com/pre-commit/pre-commit-hooks | ||||
|     rev: "v5.0.0" | ||||
|     rev: "v6.0.0" | ||||
|     hooks: | ||||
|       - id: no-commit-to-branch | ||||
|         args: ['--branch', 'main'] | ||||
| @@ -21,7 +21,7 @@ repos: | ||||
|  | ||||
|   # for commit message formatting | ||||
|   - repo: https://github.com/commitizen-tools/commitizen | ||||
|     rev: v4.4.1 | ||||
|     rev: v4.8.3 | ||||
|     hooks: | ||||
|       - id: commitizen | ||||
|         stages: [commit-msg] | ||||
|   | ||||
| @@ -1,8 +1,8 @@ | ||||
| { | ||||
|   "label": "APIs - Extras", | ||||
|   "position": 4, | ||||
|   "link": { | ||||
|     "type": "generated-index", | ||||
|     "description": "Details of APIs to handle bunch of extra stuff." | ||||
|   } | ||||
|     "label": "APIs - Extras", | ||||
|     "position": 4, | ||||
|     "link": { | ||||
|         "type": "generated-index", | ||||
|         "description": "Details of APIs to handle bunch of extra stuff." | ||||
|     } | ||||
| } | ||||
|   | ||||
| @@ -13,9 +13,9 @@ public class Main { | ||||
|  | ||||
|     public static void main(String[] args) { | ||||
|         String host = "http://localhost:11434/"; | ||||
|          | ||||
|  | ||||
|         OllamaAPI ollamaAPI = new OllamaAPI(host); | ||||
|          | ||||
|  | ||||
|         ollamaAPI.ping(); | ||||
|     } | ||||
| } | ||||
|   | ||||
| @@ -1,8 +1,8 @@ | ||||
| { | ||||
|   "label": "APIs - Generate", | ||||
|   "position": 3, | ||||
|   "link": { | ||||
|     "type": "generated-index", | ||||
|     "description": "Details of APIs to interact with LLMs." | ||||
|   } | ||||
|     "label": "APIs - Generate", | ||||
|     "position": 3, | ||||
|     "link": { | ||||
|         "type": "generated-index", | ||||
|         "description": "Details of APIs to interact with LLMs." | ||||
|     } | ||||
| } | ||||
|   | ||||
| @@ -1,8 +1,8 @@ | ||||
| { | ||||
|   "label": "APIs - Model Management", | ||||
|   "position": 2, | ||||
|   "link": { | ||||
|     "type": "generated-index", | ||||
|     "description": "Details of APIs to manage LLMs." | ||||
|   } | ||||
|     "label": "APIs - Model Management", | ||||
|     "position": 2, | ||||
|     "link": { | ||||
|         "type": "generated-index", | ||||
|         "description": "Details of APIs to manage LLMs." | ||||
|     } | ||||
| } | ||||
|   | ||||
| @@ -420,16 +420,23 @@ public class OllamaAPI { | ||||
|             String line; | ||||
|             while ((line = reader.readLine()) != null) { | ||||
|                 ModelPullResponse modelPullResponse = Utils.getObjectMapper().readValue(line, ModelPullResponse.class); | ||||
|                 if (modelPullResponse != null && modelPullResponse.getStatus() != null) { | ||||
|                     if (verbose) { | ||||
|                         logger.info(modelName + ": " + modelPullResponse.getStatus()); | ||||
|                 if (modelPullResponse != null) { | ||||
|                     // Check for error in response body first | ||||
|                     if (modelPullResponse.getError() != null && !modelPullResponse.getError().trim().isEmpty()) { | ||||
|                         throw new OllamaBaseException("Model pull failed: " + modelPullResponse.getError()); | ||||
|                     } | ||||
|                     // Check if status is "success" and set success flag to true. | ||||
|                     if ("success".equalsIgnoreCase(modelPullResponse.getStatus())) { | ||||
|                         success = true; | ||||
|  | ||||
|                     if (modelPullResponse.getStatus() != null) { | ||||
|                         if (verbose) { | ||||
|                             logger.info(modelName + ": " + modelPullResponse.getStatus()); | ||||
|                         } | ||||
|                         // Check if status is "success" and set success flag to true. | ||||
|                         if ("success".equalsIgnoreCase(modelPullResponse.getStatus())) { | ||||
|                             success = true; | ||||
|                         } | ||||
|                     } | ||||
|                 } else { | ||||
|                     logger.error("Received null or invalid status for model pull."); | ||||
|                     logger.error("Received null response for model pull."); | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
| @@ -755,7 +762,7 @@ public class OllamaAPI { | ||||
|      * @throws InterruptedException if the operation is interrupted | ||||
|      */ | ||||
|     public OllamaResult generate(String model, String prompt, boolean raw, Options options, | ||||
|             OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { | ||||
|                                  OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { | ||||
|         OllamaGenerateRequest ollamaRequestModel = new OllamaGenerateRequest(model, prompt); | ||||
|         ollamaRequestModel.setRaw(raw); | ||||
|         ollamaRequestModel.setOptions(options.getOptionsMap()); | ||||
| @@ -938,7 +945,7 @@ public class OllamaAPI { | ||||
|      * @throws InterruptedException if the operation is interrupted | ||||
|      */ | ||||
|     public OllamaResult generateWithImageFiles(String model, String prompt, List<File> imageFiles, Options options, | ||||
|             OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { | ||||
|                                                OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { | ||||
|         List<String> images = new ArrayList<>(); | ||||
|         for (File imageFile : imageFiles) { | ||||
|             images.add(encodeFileToBase64(imageFile)); | ||||
| @@ -985,7 +992,7 @@ public class OllamaAPI { | ||||
|      * @throws URISyntaxException   if the URI for the request is malformed | ||||
|      */ | ||||
|     public OllamaResult generateWithImageURLs(String model, String prompt, List<String> imageURLs, Options options, | ||||
|             OllamaStreamHandler streamHandler) | ||||
|                                               OllamaStreamHandler streamHandler) | ||||
|             throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { | ||||
|         List<String> images = new ArrayList<>(); | ||||
|         for (String imageURL : imageURLs) { | ||||
| @@ -1247,7 +1254,7 @@ public class OllamaAPI { | ||||
|                 registerAnnotatedTools(provider.getDeclaredConstructor().newInstance()); | ||||
|             } | ||||
|         } catch (InstantiationException | NoSuchMethodException | IllegalAccessException | ||||
|                 | InvocationTargetException e) { | ||||
|                  | InvocationTargetException e) { | ||||
|             throw new RuntimeException(e); | ||||
|         } | ||||
|     } | ||||
| @@ -1384,7 +1391,7 @@ public class OllamaAPI { | ||||
|      * @throws InterruptedException if the thread is interrupted during the request. | ||||
|      */ | ||||
|     private OllamaResult generateSyncForOllamaRequestModel(OllamaGenerateRequest ollamaRequestModel, | ||||
|             OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { | ||||
|                                                            OllamaStreamHandler streamHandler) throws OllamaBaseException, IOException, InterruptedException { | ||||
|         OllamaGenerateEndpointCaller requestCaller = new OllamaGenerateEndpointCaller(host, auth, requestTimeoutSeconds, | ||||
|                 verbose); | ||||
|         OllamaResult result; | ||||
|   | ||||
| @@ -26,7 +26,7 @@ public class OllamaGenerateRequestBuilder { | ||||
|         request.setPrompt(prompt); | ||||
|         return this; | ||||
|     } | ||||
|      | ||||
|  | ||||
|     public OllamaGenerateRequestBuilder withGetJsonResponse(){ | ||||
|         this.request.setReturnFormatJson(true); | ||||
|         return this; | ||||
|   | ||||
| @@ -13,8 +13,8 @@ import lombok.Data; | ||||
| @Data | ||||
| @JsonInclude(JsonInclude.Include.NON_NULL) | ||||
| public abstract class OllamaCommonRequest { | ||||
|    | ||||
|   protected String model;   | ||||
|  | ||||
|   protected String model; | ||||
|   @JsonSerialize(using = BooleanToJsonFormatFlagSerializer.class) | ||||
|   @JsonProperty(value = "format") | ||||
|   protected Boolean returnFormatJson; | ||||
| @@ -24,7 +24,7 @@ public abstract class OllamaCommonRequest { | ||||
|   @JsonProperty(value = "keep_alive") | ||||
|   protected String keepAlive; | ||||
|  | ||||
|    | ||||
|  | ||||
|   public String toString() { | ||||
|     try { | ||||
|       return Utils.getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this); | ||||
|   | ||||
| @@ -10,4 +10,5 @@ public class ModelPullResponse { | ||||
|     private String digest; | ||||
|     private long total; | ||||
|     private long completed; | ||||
|     private String error; | ||||
| } | ||||
|   | ||||
| @@ -120,4 +120,3 @@ public class OllamaAsyncResultStreamer extends Thread { | ||||
|     } | ||||
|  | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -15,4 +15,4 @@ public class OllamaToolCallsFunction | ||||
| { | ||||
|     private String name; | ||||
|     private Map<String,Object> arguments; | ||||
| } | ||||
| } | ||||
|   | ||||
| @@ -13,4 +13,3 @@ public class ToolFunctionCallSpec { | ||||
|     private String name; | ||||
|     private Map<String, Object> arguments; | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -18,4 +18,4 @@ public class FileToBase64Serializer extends JsonSerializer<Collection<byte[]>> { | ||||
|         } | ||||
|         jsonGenerator.writeEndArray(); | ||||
|     } | ||||
| } | ||||
| } | ||||
|   | ||||
| @@ -10,10 +10,10 @@ import com.fasterxml.jackson.core.JsonProcessingException; | ||||
|  * Interface to represent a OllamaRequest as HTTP-Request Body via {@link BodyPublishers}. | ||||
|  */ | ||||
| public interface OllamaRequestBody { | ||||
|      | ||||
|  | ||||
|     /** | ||||
|      * Transforms the OllamaRequest Object to a JSON Object via Jackson. | ||||
|      *  | ||||
|      * | ||||
|      * @return JSON representation of a OllamaRequest | ||||
|      */ | ||||
|     @JsonIgnore | ||||
|   | ||||
| @@ -0,0 +1,247 @@ | ||||
| package io.github.ollama4j.unittests.jackson; | ||||
|  | ||||
| import io.github.ollama4j.models.response.ModelPullResponse; | ||||
| import org.junit.jupiter.api.Test; | ||||
|  | ||||
| import static org.junit.jupiter.api.Assertions.*; | ||||
|  | ||||
| /** | ||||
|  * Test serialization and deserialization of ModelPullResponse, | ||||
|  * This test verifies that the ModelPullResponse class can properly parse | ||||
|  * error responses from Ollama server that return HTTP 200 with error messages | ||||
|  * in the JSON body. | ||||
|  */ | ||||
| public class TestModelPullResponseSerialization extends AbstractSerializationTest<ModelPullResponse> { | ||||
|  | ||||
|     /** | ||||
|      * Test the specific error case reported in GitHub issue #138. | ||||
|      * Ollama sometimes returns HTTP 200 with error details in JSON. | ||||
|      */ | ||||
|     @Test | ||||
|     public void testDeserializationWithErrorFromGitHubIssue138() { | ||||
|         // This is the exact error JSON from GitHub issue #138 | ||||
|         String errorJson = "{\"error\":\"pull model manifest: 412: \\n\\nThe model you are attempting to pull requires a newer version of Ollama.\\n\\nPlease download the latest version at:\\n\\n\\thttps://ollama.com/download\\n\\n\"}"; | ||||
|  | ||||
|         ModelPullResponse response = deserialize(errorJson, ModelPullResponse.class); | ||||
|  | ||||
|         assertNotNull(response); | ||||
|         assertNotNull(response.getError()); | ||||
|         assertTrue(response.getError().contains("newer version of Ollama")); | ||||
|         assertTrue(response.getError().contains("https://ollama.com/download")); | ||||
|         assertNull(response.getStatus()); | ||||
|         assertNull(response.getDigest()); | ||||
|         assertEquals(0, response.getTotal()); | ||||
|         assertEquals(0, response.getCompleted()); | ||||
|     } | ||||
|  | ||||
|     /** | ||||
|      * Test deserialization of ModelPullResponse with only status field present. | ||||
|      * Verifies that the response can handle minimal JSON with just status information. | ||||
|      */ | ||||
|     @Test | ||||
|     public void testDeserializationWithStatusField() { | ||||
|         String statusJson = "{\"status\":\"pulling manifest\"}"; | ||||
|  | ||||
|         ModelPullResponse response = deserialize(statusJson, ModelPullResponse.class); | ||||
|  | ||||
|         assertNotNull(response); | ||||
|         assertEquals("pulling manifest", response.getStatus()); | ||||
|         assertNull(response.getError()); | ||||
|         assertNull(response.getDigest()); | ||||
|         assertEquals(0, response.getTotal()); | ||||
|         assertEquals(0, response.getCompleted()); | ||||
|     } | ||||
|  | ||||
|     /** | ||||
|      * Test deserialization of ModelPullResponse with progress tracking fields. | ||||
|      * Verifies that status, digest, total, and completed fields are properly parsed | ||||
|      * when downloading/pulling model data. | ||||
|      */ | ||||
|     @Test | ||||
|     public void testDeserializationWithProgressFields() { | ||||
|         String progressJson = "{\"status\":\"pulling digestname\",\"digest\":\"sha256:abc123\",\"total\":2142590208,\"completed\":241970}"; | ||||
|  | ||||
|         ModelPullResponse response = deserialize(progressJson, ModelPullResponse.class); | ||||
|  | ||||
|         assertNotNull(response); | ||||
|         assertEquals("pulling digestname", response.getStatus()); | ||||
|         assertEquals("sha256:abc123", response.getDigest()); | ||||
|         assertEquals(2142590208L, response.getTotal()); | ||||
|         assertEquals(241970L, response.getCompleted()); | ||||
|         assertNull(response.getError()); | ||||
|     } | ||||
|  | ||||
|     /** | ||||
|      * Test deserialization of ModelPullResponse with success status. | ||||
|      * Verifies that successful completion responses are properly handled. | ||||
|      */ | ||||
|     @Test | ||||
|     public void testDeserializationWithSuccessStatus() { | ||||
|         String successJson = "{\"status\":\"success\"}"; | ||||
|  | ||||
|         ModelPullResponse response = deserialize(successJson, ModelPullResponse.class); | ||||
|  | ||||
|         assertNotNull(response); | ||||
|         assertEquals("success", response.getStatus()); | ||||
|         assertNull(response.getError()); | ||||
|         assertNull(response.getDigest()); | ||||
|         assertEquals(0, response.getTotal()); | ||||
|         assertEquals(0, response.getCompleted()); | ||||
|     } | ||||
|  | ||||
|     /** | ||||
|      * Test deserialization of ModelPullResponse with all possible fields populated. | ||||
|      * Verifies that complete JSON responses with all fields are handled correctly. | ||||
|      */ | ||||
|     @Test | ||||
|     public void testDeserializationWithAllFields() { | ||||
|         String completeJson = "{\"status\":\"downloading\",\"digest\":\"sha256:def456\",\"total\":1000000,\"completed\":500000,\"error\":null}"; | ||||
|  | ||||
|         ModelPullResponse response = deserialize(completeJson, ModelPullResponse.class); | ||||
|  | ||||
|         assertNotNull(response); | ||||
|         assertEquals("downloading", response.getStatus()); | ||||
|         assertEquals("sha256:def456", response.getDigest()); | ||||
|         assertEquals(1000000L, response.getTotal()); | ||||
|         assertEquals(500000L, response.getCompleted()); | ||||
|         assertNull(response.getError()); | ||||
|     } | ||||
|  | ||||
|     /** | ||||
|      * Test deserialization of ModelPullResponse with unknown JSON fields. | ||||
|      * Verifies that unknown fields are ignored due to @JsonIgnoreProperties(ignoreUnknown = true) | ||||
|      * annotation without causing deserialization errors. | ||||
|      */ | ||||
|     @Test | ||||
|     public void testDeserializationWithUnknownFields() { | ||||
|         // Test that unknown fields are ignored due to @JsonIgnoreProperties(ignoreUnknown = true) | ||||
|         String jsonWithUnknownFields = "{\"status\":\"pulling\",\"unknown_field\":\"should_be_ignored\",\"error\":\"test error\",\"another_unknown\":123,\"nested_unknown\":{\"key\":\"value\"}}"; | ||||
|  | ||||
|         ModelPullResponse response = deserialize(jsonWithUnknownFields, ModelPullResponse.class); | ||||
|  | ||||
|         assertNotNull(response); | ||||
|         assertEquals("pulling", response.getStatus()); | ||||
|         assertEquals("test error", response.getError()); | ||||
|         assertNull(response.getDigest()); | ||||
|         assertEquals(0, response.getTotal()); | ||||
|         assertEquals(0, response.getCompleted()); | ||||
|         // Unknown fields should be ignored | ||||
|     } | ||||
|  | ||||
|     /** | ||||
|      * Test deserialization of ModelPullResponse with empty string error field. | ||||
|      * Verifies that empty error strings are preserved as empty strings, not converted to null. | ||||
|      */ | ||||
|     @Test | ||||
|     public void testEmptyErrorFieldIsNull() { | ||||
|         String emptyErrorJson = "{\"error\":\"\",\"status\":\"pulling manifest\"}"; | ||||
|  | ||||
|         ModelPullResponse response = deserialize(emptyErrorJson, ModelPullResponse.class); | ||||
|  | ||||
|         assertNotNull(response); | ||||
|         assertEquals("pulling manifest", response.getStatus()); | ||||
|         assertEquals("", response.getError()); // Empty string, not null | ||||
|     } | ||||
|  | ||||
|     /** | ||||
|      * Test deserialization of ModelPullResponse with whitespace-only error field. | ||||
|      * Verifies that whitespace characters in error fields are preserved during JSON parsing. | ||||
|      */ | ||||
|     @Test | ||||
|     public void testWhitespaceOnlyErrorField() { | ||||
|         String whitespaceErrorJson = "{\"error\":\"   \\n\\t  \",\"status\":\"pulling manifest\"}"; | ||||
|  | ||||
|         ModelPullResponse response = deserialize(whitespaceErrorJson, ModelPullResponse.class); | ||||
|  | ||||
|         assertNotNull(response); | ||||
|         assertEquals("pulling manifest", response.getStatus()); | ||||
|         assertEquals("   \n\t  ", response.getError()); // Whitespace preserved in JSON parsing | ||||
|         assertTrue(response.getError().trim().isEmpty()); // But trimmed version is empty | ||||
|     } | ||||
|  | ||||
|     /** | ||||
|      * Test serialization of ModelPullResponse with error field to ensure round-trip compatibility. | ||||
|      * Verifies that objects can be properly serialized to JSON format. | ||||
|      */ | ||||
|     @Test | ||||
|     public void testSerializationWithErrorField() { | ||||
|         ModelPullResponse response = new ModelPullResponse(); | ||||
|         response.setError("Test error message"); | ||||
|         response.setStatus("failed"); | ||||
|  | ||||
|         String jsonString = serialize(response); | ||||
|  | ||||
|         assertNotNull(jsonString); | ||||
|         assertTrue(jsonString.contains("\"error\":\"Test error message\"")); | ||||
|         assertTrue(jsonString.contains("\"status\":\"failed\"")); | ||||
|     } | ||||
|  | ||||
|     /** | ||||
|      * Test round-trip serialization and deserialization of ModelPullResponse with error data. | ||||
|      * Verifies that objects maintain integrity through serialize -> deserialize cycle. | ||||
|      */ | ||||
|     @Test | ||||
|     public void testRoundTripSerializationWithError() { | ||||
|         ModelPullResponse original = new ModelPullResponse(); | ||||
|         original.setError("Round trip test error"); | ||||
|         original.setStatus("error"); | ||||
|  | ||||
|         String json = serialize(original); | ||||
|         ModelPullResponse deserialized = deserialize(json, ModelPullResponse.class); | ||||
|  | ||||
|         assertEqualsAfterUnmarshalling(deserialized, original); | ||||
|         assertEquals("Round trip test error", deserialized.getError()); | ||||
|         assertEquals("error", deserialized.getStatus()); | ||||
|     } | ||||
|  | ||||
|     /** | ||||
|      * Test round-trip serialization and deserialization of ModelPullResponse with progress data. | ||||
|      * Verifies that progress tracking information is preserved through serialize -> deserialize cycle. | ||||
|      */ | ||||
|     @Test | ||||
|     public void testRoundTripSerializationWithProgress() { | ||||
|         ModelPullResponse original = new ModelPullResponse(); | ||||
|         original.setStatus("downloading"); | ||||
|         original.setDigest("sha256:roundtrip"); | ||||
|         original.setTotal(2000000L); | ||||
|         original.setCompleted(1500000L); | ||||
|  | ||||
|         String json = serialize(original); | ||||
|         ModelPullResponse deserialized = deserialize(json, ModelPullResponse.class); | ||||
|  | ||||
|         assertEqualsAfterUnmarshalling(deserialized, original); | ||||
|         assertEquals("downloading", deserialized.getStatus()); | ||||
|         assertEquals("sha256:roundtrip", deserialized.getDigest()); | ||||
|         assertEquals(2000000L, deserialized.getTotal()); | ||||
|         assertEquals(1500000L, deserialized.getCompleted()); | ||||
|         assertNull(deserialized.getError()); | ||||
|     } | ||||
|  | ||||
|     /** | ||||
|      * Test that verifies the error handling logic that would be used in doPullModel method. | ||||
|      * This simulates the actual error detection logic. | ||||
|      */ | ||||
|     @Test | ||||
|     public void testErrorHandlingLogic() { | ||||
|         // Error case - should trigger error handling | ||||
|         String errorJson = "{\"error\":\"test error\"}"; | ||||
|         ModelPullResponse errorResponse = deserialize(errorJson, ModelPullResponse.class); | ||||
|  | ||||
|         assertTrue(errorResponse.getError() != null && !errorResponse.getError().trim().isEmpty(), | ||||
|                 "Error response should trigger error handling logic"); | ||||
|  | ||||
|         // Normal case - should not trigger error handling | ||||
|         String normalJson = "{\"status\":\"pulling\"}"; | ||||
|         ModelPullResponse normalResponse = deserialize(normalJson, ModelPullResponse.class); | ||||
|  | ||||
|         assertFalse(normalResponse.getError() != null && !normalResponse.getError().trim().isEmpty(), | ||||
|                 "Normal response should not trigger error handling logic"); | ||||
|  | ||||
|         // Empty error case - should not trigger error handling | ||||
|         String emptyErrorJson = "{\"error\":\"\",\"status\":\"pulling\"}"; | ||||
|         ModelPullResponse emptyErrorResponse = deserialize(emptyErrorJson, ModelPullResponse.class); | ||||
|  | ||||
|         assertFalse(emptyErrorResponse.getError() != null && !emptyErrorResponse.getError().trim().isEmpty(), | ||||
|                 "Empty error response should not trigger error handling logic"); | ||||
|     } | ||||
| } | ||||
| @@ -12,4 +12,4 @@ | ||||
|  | ||||
|     <logger name="org.apache" level="WARN"/> | ||||
|     <logger name="httpclient" level="WARN"/> | ||||
| </configuration> | ||||
| </configuration> | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Travis Lyons
					Travis Lyons