From c8c30d703be41f2685854f32508da644b111320d Mon Sep 17 00:00:00 2001 From: amithkoujalgi Date: Sat, 8 Mar 2025 15:46:43 +0530 Subject: [PATCH] Refactor code to enhance robustness and clarity Refactored `OllamaChatMessageRole` to simplify custom role creation, guarding against nulls in `OllamaToolsResult`, and made `OllamaChatResult` properties immutable. Improved error handling in `OllamaAPI`, added verbose logs, and ensured safer JSON parsing for tool responses. Introduced `@JsonIgnoreProperties` for better deserialization support. --- .../java/io/github/ollama4j/OllamaAPI.java | 43 ++++++++++++++----- .../models/chat/OllamaChatMessageRole.java | 6 +-- .../models/chat/OllamaChatResult.java | 5 +-- .../tools/OllamaToolCallsFunction.java | 4 +- .../ollama4j/tools/OllamaToolsResult.java | 3 ++ 5 files changed, 44 insertions(+), 17 deletions(-) diff --git a/src/main/java/io/github/ollama4j/OllamaAPI.java b/src/main/java/io/github/ollama4j/OllamaAPI.java index 7d8385f..76af4c8 100644 --- a/src/main/java/io/github/ollama4j/OllamaAPI.java +++ b/src/main/java/io/github/ollama4j/OllamaAPI.java @@ -1,14 +1,17 @@ package io.github.ollama4j; +import com.fasterxml.jackson.core.JsonParseException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; import io.github.ollama4j.exceptions.OllamaBaseException; import io.github.ollama4j.exceptions.RoleNotFoundException; import io.github.ollama4j.exceptions.ToolInvocationException; import io.github.ollama4j.exceptions.ToolNotFoundException; import io.github.ollama4j.models.chat.*; import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel; +import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel; import io.github.ollama4j.models.embeddings.OllamaEmbeddingResponseModel; import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel; -import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel; import io.github.ollama4j.models.generate.OllamaGenerateRequest; import io.github.ollama4j.models.generate.OllamaStreamHandler; import io.github.ollama4j.models.generate.OllamaTokenHandler; @@ -22,6 +25,12 @@ import io.github.ollama4j.tools.annotations.ToolSpec; import io.github.ollama4j.utils.Options; import io.github.ollama4j.utils.Utils; import lombok.Setter; +import org.jsoup.Jsoup; +import org.jsoup.nodes.Document; +import org.jsoup.nodes.Element; +import org.jsoup.select.Elements; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.io.*; import java.lang.reflect.InvocationTargetException; @@ -39,13 +48,6 @@ import java.time.Duration; import java.util.*; import java.util.stream.Collectors; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.jsoup.Jsoup; -import org.jsoup.nodes.Document; -import org.jsoup.nodes.Element; -import org.jsoup.select.Elements; - /** * The base Ollama API class. */ @@ -92,6 +94,9 @@ public class OllamaAPI { } else { this.host = host; } + if (this.verbose) { + logger.info("Ollama API initialized with host: " + this.host); + } } /** @@ -658,7 +663,22 @@ public class OllamaAPI { toolsResponse = toolsResponse.replace("[TOOL_CALLS]", ""); } - List toolFunctionCallSpecs = Utils.getObjectMapper().readValue(toolsResponse, Utils.getObjectMapper().getTypeFactory().constructCollectionType(List.class, ToolFunctionCallSpec.class)); + List toolFunctionCallSpecs = new ArrayList<>(); + ObjectMapper objectMapper = Utils.getObjectMapper(); + + if (!toolsResponse.isEmpty()) { + try { + // Try to parse the string to see if it's a valid JSON + JsonNode jsonNode = objectMapper.readTree(toolsResponse); + } catch (JsonParseException e) { + logger.warn("Response from model does not contain any tool calls. Returning the response as is."); + return toolResult; + } + toolFunctionCallSpecs = objectMapper.readValue( + toolsResponse, + objectMapper.getTypeFactory().constructCollectionType(List.class, ToolFunctionCallSpec.class) + ); + } for (ToolFunctionCallSpec toolFunctionCallSpec : toolFunctionCallSpecs) { toolResults.put(toolFunctionCallSpec, invokeTool(toolFunctionCallSpec)); } @@ -881,6 +901,9 @@ public class OllamaAPI { */ public void registerTool(Tools.ToolSpecification toolSpecification) { toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification); + if (this.verbose) { + logger.debug("Registered tool: {}", toolSpecification.getFunctionName()); + } } /** @@ -1093,7 +1116,7 @@ public class OllamaAPI { logger.debug("Invoking function {} with arguments {}", methodName, arguments); } if (function == null) { - throw new ToolNotFoundException("No such tool: " + methodName); + throw new ToolNotFoundException("No such tool: " + methodName + ". Please register the tool before invoking it."); } return function.apply(arguments); } catch (Exception e) { diff --git a/src/main/java/io/github/ollama4j/models/chat/OllamaChatMessageRole.java b/src/main/java/io/github/ollama4j/models/chat/OllamaChatMessageRole.java index 4d00bc5..37d9d5c 100644 --- a/src/main/java/io/github/ollama4j/models/chat/OllamaChatMessageRole.java +++ b/src/main/java/io/github/ollama4j/models/chat/OllamaChatMessageRole.java @@ -28,9 +28,9 @@ public class OllamaChatMessageRole { } public static OllamaChatMessageRole newCustomRole(String roleName) { - OllamaChatMessageRole customRole = new OllamaChatMessageRole(roleName); - roles.add(customRole); - return customRole; +// OllamaChatMessageRole customRole = new OllamaChatMessageRole(roleName); +// roles.add(customRole); + return new OllamaChatMessageRole(roleName); } public static List getRoles() { diff --git a/src/main/java/io/github/ollama4j/models/chat/OllamaChatResult.java b/src/main/java/io/github/ollama4j/models/chat/OllamaChatResult.java index bf7eaea..f8ebb05 100644 --- a/src/main/java/io/github/ollama4j/models/chat/OllamaChatResult.java +++ b/src/main/java/io/github/ollama4j/models/chat/OllamaChatResult.java @@ -14,10 +14,9 @@ import static io.github.ollama4j.utils.Utils.getObjectMapper; @Getter public class OllamaChatResult { + private final List chatHistory; - private List chatHistory; - - private OllamaChatResponseModel responseModel; + private final OllamaChatResponseModel responseModel; public OllamaChatResult(OllamaChatResponseModel responseModel, List chatHistory) { this.chatHistory = chatHistory; diff --git a/src/main/java/io/github/ollama4j/tools/OllamaToolCallsFunction.java b/src/main/java/io/github/ollama4j/tools/OllamaToolCallsFunction.java index 4be7194..c0192e2 100644 --- a/src/main/java/io/github/ollama4j/tools/OllamaToolCallsFunction.java +++ b/src/main/java/io/github/ollama4j/tools/OllamaToolCallsFunction.java @@ -1,5 +1,6 @@ package io.github.ollama4j.tools; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import lombok.AllArgsConstructor; import lombok.Data; import lombok.NoArgsConstructor; @@ -9,8 +10,9 @@ import java.util.Map; @Data @NoArgsConstructor @AllArgsConstructor +@JsonIgnoreProperties(ignoreUnknown = true) public class OllamaToolCallsFunction { private String name; private Map arguments; -} +} \ No newline at end of file diff --git a/src/main/java/io/github/ollama4j/tools/OllamaToolsResult.java b/src/main/java/io/github/ollama4j/tools/OllamaToolsResult.java index c855bd2..35fada3 100644 --- a/src/main/java/io/github/ollama4j/tools/OllamaToolsResult.java +++ b/src/main/java/io/github/ollama4j/tools/OllamaToolsResult.java @@ -18,6 +18,9 @@ public class OllamaToolsResult { public List getToolResults() { List results = new ArrayList<>(); + if (this.toolResults == null) { + return results; + } for (Map.Entry r : this.toolResults.entrySet()) { results.add(new ToolResult(r.getKey().getName(), r.getKey().getArguments(), r.getValue())); }