mirror of
https://github.com/amithkoujalgi/ollama4j.git
synced 2025-05-15 03:47:13 +02:00
Refactor code to enhance robustness and clarity
Refactored `OllamaChatMessageRole` to simplify custom role creation, guarding against nulls in `OllamaToolsResult`, and made `OllamaChatResult` properties immutable. Improved error handling in `OllamaAPI`, added verbose logs, and ensured safer JSON parsing for tool responses. Introduced `@JsonIgnoreProperties` for better deserialization support.
This commit is contained in:
parent
419b0369c9
commit
c8c30d703b
@ -1,14 +1,17 @@
|
|||||||
package io.github.ollama4j;
|
package io.github.ollama4j;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.core.JsonParseException;
|
||||||
|
import com.fasterxml.jackson.databind.JsonNode;
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
import io.github.ollama4j.exceptions.OllamaBaseException;
|
import io.github.ollama4j.exceptions.OllamaBaseException;
|
||||||
import io.github.ollama4j.exceptions.RoleNotFoundException;
|
import io.github.ollama4j.exceptions.RoleNotFoundException;
|
||||||
import io.github.ollama4j.exceptions.ToolInvocationException;
|
import io.github.ollama4j.exceptions.ToolInvocationException;
|
||||||
import io.github.ollama4j.exceptions.ToolNotFoundException;
|
import io.github.ollama4j.exceptions.ToolNotFoundException;
|
||||||
import io.github.ollama4j.models.chat.*;
|
import io.github.ollama4j.models.chat.*;
|
||||||
import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel;
|
import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel;
|
||||||
|
import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel;
|
||||||
import io.github.ollama4j.models.embeddings.OllamaEmbeddingResponseModel;
|
import io.github.ollama4j.models.embeddings.OllamaEmbeddingResponseModel;
|
||||||
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
|
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
|
||||||
import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel;
|
|
||||||
import io.github.ollama4j.models.generate.OllamaGenerateRequest;
|
import io.github.ollama4j.models.generate.OllamaGenerateRequest;
|
||||||
import io.github.ollama4j.models.generate.OllamaStreamHandler;
|
import io.github.ollama4j.models.generate.OllamaStreamHandler;
|
||||||
import io.github.ollama4j.models.generate.OllamaTokenHandler;
|
import io.github.ollama4j.models.generate.OllamaTokenHandler;
|
||||||
@ -22,6 +25,12 @@ import io.github.ollama4j.tools.annotations.ToolSpec;
|
|||||||
import io.github.ollama4j.utils.Options;
|
import io.github.ollama4j.utils.Options;
|
||||||
import io.github.ollama4j.utils.Utils;
|
import io.github.ollama4j.utils.Utils;
|
||||||
import lombok.Setter;
|
import lombok.Setter;
|
||||||
|
import org.jsoup.Jsoup;
|
||||||
|
import org.jsoup.nodes.Document;
|
||||||
|
import org.jsoup.nodes.Element;
|
||||||
|
import org.jsoup.select.Elements;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
import java.io.*;
|
import java.io.*;
|
||||||
import java.lang.reflect.InvocationTargetException;
|
import java.lang.reflect.InvocationTargetException;
|
||||||
@ -39,13 +48,6 @@ import java.time.Duration;
|
|||||||
import java.util.*;
|
import java.util.*;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
import org.slf4j.Logger;
|
|
||||||
import org.slf4j.LoggerFactory;
|
|
||||||
import org.jsoup.Jsoup;
|
|
||||||
import org.jsoup.nodes.Document;
|
|
||||||
import org.jsoup.nodes.Element;
|
|
||||||
import org.jsoup.select.Elements;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The base Ollama API class.
|
* The base Ollama API class.
|
||||||
*/
|
*/
|
||||||
@ -92,6 +94,9 @@ public class OllamaAPI {
|
|||||||
} else {
|
} else {
|
||||||
this.host = host;
|
this.host = host;
|
||||||
}
|
}
|
||||||
|
if (this.verbose) {
|
||||||
|
logger.info("Ollama API initialized with host: " + this.host);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -658,7 +663,22 @@ public class OllamaAPI {
|
|||||||
toolsResponse = toolsResponse.replace("[TOOL_CALLS]", "");
|
toolsResponse = toolsResponse.replace("[TOOL_CALLS]", "");
|
||||||
}
|
}
|
||||||
|
|
||||||
List<ToolFunctionCallSpec> toolFunctionCallSpecs = Utils.getObjectMapper().readValue(toolsResponse, Utils.getObjectMapper().getTypeFactory().constructCollectionType(List.class, ToolFunctionCallSpec.class));
|
List<ToolFunctionCallSpec> toolFunctionCallSpecs = new ArrayList<>();
|
||||||
|
ObjectMapper objectMapper = Utils.getObjectMapper();
|
||||||
|
|
||||||
|
if (!toolsResponse.isEmpty()) {
|
||||||
|
try {
|
||||||
|
// Try to parse the string to see if it's a valid JSON
|
||||||
|
JsonNode jsonNode = objectMapper.readTree(toolsResponse);
|
||||||
|
} catch (JsonParseException e) {
|
||||||
|
logger.warn("Response from model does not contain any tool calls. Returning the response as is.");
|
||||||
|
return toolResult;
|
||||||
|
}
|
||||||
|
toolFunctionCallSpecs = objectMapper.readValue(
|
||||||
|
toolsResponse,
|
||||||
|
objectMapper.getTypeFactory().constructCollectionType(List.class, ToolFunctionCallSpec.class)
|
||||||
|
);
|
||||||
|
}
|
||||||
for (ToolFunctionCallSpec toolFunctionCallSpec : toolFunctionCallSpecs) {
|
for (ToolFunctionCallSpec toolFunctionCallSpec : toolFunctionCallSpecs) {
|
||||||
toolResults.put(toolFunctionCallSpec, invokeTool(toolFunctionCallSpec));
|
toolResults.put(toolFunctionCallSpec, invokeTool(toolFunctionCallSpec));
|
||||||
}
|
}
|
||||||
@ -881,6 +901,9 @@ public class OllamaAPI {
|
|||||||
*/
|
*/
|
||||||
public void registerTool(Tools.ToolSpecification toolSpecification) {
|
public void registerTool(Tools.ToolSpecification toolSpecification) {
|
||||||
toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification);
|
toolRegistry.addTool(toolSpecification.getFunctionName(), toolSpecification);
|
||||||
|
if (this.verbose) {
|
||||||
|
logger.debug("Registered tool: {}", toolSpecification.getFunctionName());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -1093,7 +1116,7 @@ public class OllamaAPI {
|
|||||||
logger.debug("Invoking function {} with arguments {}", methodName, arguments);
|
logger.debug("Invoking function {} with arguments {}", methodName, arguments);
|
||||||
}
|
}
|
||||||
if (function == null) {
|
if (function == null) {
|
||||||
throw new ToolNotFoundException("No such tool: " + methodName);
|
throw new ToolNotFoundException("No such tool: " + methodName + ". Please register the tool before invoking it.");
|
||||||
}
|
}
|
||||||
return function.apply(arguments);
|
return function.apply(arguments);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
|
@ -28,9 +28,9 @@ public class OllamaChatMessageRole {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public static OllamaChatMessageRole newCustomRole(String roleName) {
|
public static OllamaChatMessageRole newCustomRole(String roleName) {
|
||||||
OllamaChatMessageRole customRole = new OllamaChatMessageRole(roleName);
|
// OllamaChatMessageRole customRole = new OllamaChatMessageRole(roleName);
|
||||||
roles.add(customRole);
|
// roles.add(customRole);
|
||||||
return customRole;
|
return new OllamaChatMessageRole(roleName);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static List<OllamaChatMessageRole> getRoles() {
|
public static List<OllamaChatMessageRole> getRoles() {
|
||||||
|
@ -14,10 +14,9 @@ import static io.github.ollama4j.utils.Utils.getObjectMapper;
|
|||||||
@Getter
|
@Getter
|
||||||
public class OllamaChatResult {
|
public class OllamaChatResult {
|
||||||
|
|
||||||
|
private final List<OllamaChatMessage> chatHistory;
|
||||||
|
|
||||||
private List<OllamaChatMessage> chatHistory;
|
private final OllamaChatResponseModel responseModel;
|
||||||
|
|
||||||
private OllamaChatResponseModel responseModel;
|
|
||||||
|
|
||||||
public OllamaChatResult(OllamaChatResponseModel responseModel, List<OllamaChatMessage> chatHistory) {
|
public OllamaChatResult(OllamaChatResponseModel responseModel, List<OllamaChatMessage> chatHistory) {
|
||||||
this.chatHistory = chatHistory;
|
this.chatHistory = chatHistory;
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
package io.github.ollama4j.tools;
|
package io.github.ollama4j.tools;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
@ -9,6 +10,7 @@ import java.util.Map;
|
|||||||
@Data
|
@Data
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
|
@JsonIgnoreProperties(ignoreUnknown = true)
|
||||||
public class OllamaToolCallsFunction
|
public class OllamaToolCallsFunction
|
||||||
{
|
{
|
||||||
private String name;
|
private String name;
|
||||||
|
@ -18,6 +18,9 @@ public class OllamaToolsResult {
|
|||||||
|
|
||||||
public List<ToolResult> getToolResults() {
|
public List<ToolResult> getToolResults() {
|
||||||
List<ToolResult> results = new ArrayList<>();
|
List<ToolResult> results = new ArrayList<>();
|
||||||
|
if (this.toolResults == null) {
|
||||||
|
return results;
|
||||||
|
}
|
||||||
for (Map.Entry<ToolFunctionCallSpec, Object> r : this.toolResults.entrySet()) {
|
for (Map.Entry<ToolFunctionCallSpec, Object> r : this.toolResults.entrySet()) {
|
||||||
results.add(new ToolResult(r.getKey().getName(), r.getKey().getArguments(), r.getValue()));
|
results.add(new ToolResult(r.getKey().getName(), r.getKey().getArguments(), r.getValue()));
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user