From 0eec2a9756f4eec92efe46d1aea8bbbca1293f1c Mon Sep 17 00:00:00 2001 From: Sebastiaan de Schaetzen Date: Wed, 21 Aug 2024 14:55:53 +0200 Subject: [PATCH] Arguments --- .../seeseemelk/llamascript/LlamaScript.java | 34 +++++++-------- .../llamascript/LlamaScriptWorker.java | 13 +++--- .../llamascript/arguments/ArgumentBag.java | 10 ++++- .../arguments/ListArgumentBag.java | 43 +++++++++++++++++++ .../arguments/NamedArgumentBag.java | 31 +++++++++++++ .../seeseemelk/llamascript/BooleanTests.java | 2 +- .../seeseemelk/llamascript/StringTests.java | 2 +- 7 files changed, 107 insertions(+), 28 deletions(-) create mode 100644 libllamascript/src/main/java/be/seeseemelk/llamascript/arguments/ListArgumentBag.java create mode 100644 libllamascript/src/main/java/be/seeseemelk/llamascript/arguments/NamedArgumentBag.java diff --git a/libllamascript/src/main/java/be/seeseemelk/llamascript/LlamaScript.java b/libllamascript/src/main/java/be/seeseemelk/llamascript/LlamaScript.java index 7f6d8f1..3ebaf92 100644 --- a/libllamascript/src/main/java/be/seeseemelk/llamascript/LlamaScript.java +++ b/libllamascript/src/main/java/be/seeseemelk/llamascript/LlamaScript.java @@ -1,37 +1,33 @@ package be.seeseemelk.llamascript; +import be.seeseemelk.llamascript.arguments.ArgumentBag; +import be.seeseemelk.llamascript.arguments.ListArgumentBag; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; -import java.util.HashMap; - @Slf4j @RequiredArgsConstructor public class LlamaScript { private final LlamaScriptWorker worker; - public T eval(String prompt, Class returnType, Object... arguments) { - var argumentMap = new HashMap(); - for (Object argument : arguments) { - argumentMap.put(argumentMap.size(), argument); - } + public T eval(String prompt, Class returnType, ArgumentBag arguments) { if (returnType.equals(Integer.class)) { //noinspection unchecked - return (T) evalInt(prompt, argumentMap); + return (T) evalInt(prompt, arguments); } else if (returnType.equals(String.class)) { //noinspection unchecked - return (T) evalString(prompt, argumentMap); + return (T) evalString(prompt, arguments); } else if (returnType.equals(Boolean.class)) { //noinspection unchecked - return (T) evalBool(prompt, argumentMap); + return (T) evalBool(prompt, arguments); } else { throw new LLamaScriptException("Unsupported return type"); } } - public String evalString(String prompt, HashMap argumentMap) { - var result = worker.eval(prompt, primitiveOutputType("a string"), argumentMap); + public String evalString(String prompt, ArgumentBag arguments) { + var result = worker.eval(prompt, primitiveOutputType("a string"), arguments); if (result instanceof String string) return string; else if (result != null) @@ -41,11 +37,11 @@ public class LlamaScript { } public String evalString(String prompt, Object... arguments) { - return eval(prompt, String.class, arguments); + return eval(prompt, String.class, ListArgumentBag.of(arguments)); } - public Integer evalInt(String prompt, HashMap argumentMap) { - var result = worker.eval(prompt, primitiveOutputType("an integer number"), argumentMap); + public Integer evalInt(String prompt, ArgumentBag arguments) { + var result = worker.eval(prompt, primitiveOutputType("an integer number"), arguments); if (result instanceof Integer integer) return integer; else if (result instanceof String string) { @@ -60,11 +56,11 @@ public class LlamaScript { } public int evalInt(String prompt, Object... arguments) { - return eval(prompt, Integer.class, arguments); + return eval(prompt, Integer.class, ListArgumentBag.of(arguments)); } - public Boolean evalBool(String prompt, HashMap argumentMap) { - var result = worker.eval(prompt, primitiveOutputType("a boolean value"), argumentMap); + public Boolean evalBool(String prompt, ArgumentBag arguments) { + var result = worker.eval(prompt, primitiveOutputType("a boolean value"), arguments); if (result instanceof Boolean bool) return bool; else if (result instanceof String string) @@ -74,7 +70,7 @@ public class LlamaScript { } public boolean evalBool(String prompt, Object... arguments) { - return eval(prompt, Boolean.class, arguments); + return eval(prompt, Boolean.class, ListArgumentBag.of(arguments)); } private static String primitiveOutputType(String description) { diff --git a/libllamascript/src/main/java/be/seeseemelk/llamascript/LlamaScriptWorker.java b/libllamascript/src/main/java/be/seeseemelk/llamascript/LlamaScriptWorker.java index 7021a2e..7bff910 100644 --- a/libllamascript/src/main/java/be/seeseemelk/llamascript/LlamaScriptWorker.java +++ b/libllamascript/src/main/java/be/seeseemelk/llamascript/LlamaScriptWorker.java @@ -1,5 +1,7 @@ package be.seeseemelk.llamascript; +import be.seeseemelk.llamascript.arguments.ArgumentBag; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.databind.ObjectMapper; import io.github.ollama4j.OllamaAPI; import io.github.ollama4j.exceptions.OllamaBaseException; @@ -11,7 +13,6 @@ import lombok.Builder; import java.io.IOException; import java.util.ArrayList; -import java.util.HashMap; @Builder public class LlamaScriptWorker { @@ -19,24 +20,23 @@ public class LlamaScriptWorker { private final String model; private final ObjectMapper mapper = new ObjectMapper(); - public Object eval(String prompt, String outputType, HashMap argumentMap) { + public Object eval(String prompt, String outputType, ArgumentBag arguments) { try { var optionsBuilder = new OptionsBuilder(); optionsBuilder.setSeed(20); // Chosen with 1d20. Yes, a nat 20! Let's go baby! - optionsBuilder.setTemperature(0f); + optionsBuilder.setTemperature(0); var options = optionsBuilder.build(); var fullPrompt = new StringBuilder(""" You will receive a prompt, and will have to consider the question in the prompt with regards to the inputs you will receive. Your output must be a JSON object containing the single value 'result'. + Your answer must be correct. Do not make mistakes %s The prompt: %s Arguments: """.formatted(outputType, prompt)); - for (var entry : argumentMap.entrySet()) { - fullPrompt.append(" - ").append(entry.getKey()).append(": ").append(entry.getValue()).append("\n"); - } + arguments.writeTo(fullPrompt); System.out.printf("Prompt is: %s\n", fullPrompt); var messages = new ArrayList(); @@ -53,6 +53,7 @@ public class LlamaScriptWorker { } } + @JsonIgnoreProperties(ignoreUnknown = true) private static class Result { public Object result; } diff --git a/libllamascript/src/main/java/be/seeseemelk/llamascript/arguments/ArgumentBag.java b/libllamascript/src/main/java/be/seeseemelk/llamascript/arguments/ArgumentBag.java index 8ab82b9..99e3801 100644 --- a/libllamascript/src/main/java/be/seeseemelk/llamascript/arguments/ArgumentBag.java +++ b/libllamascript/src/main/java/be/seeseemelk/llamascript/arguments/ArgumentBag.java @@ -1,4 +1,12 @@ package be.seeseemelk.llamascript.arguments; -public abstract class ArgumentBag { +/** + * A bag of arguments that can be added to an evaluation request. + */ +public interface ArgumentBag { + /** + * Writes the values inside the argument bag to a StringBuilder. + * @param builder The StringBuilder to write to. + */ + void writeTo(StringBuilder builder); } diff --git a/libllamascript/src/main/java/be/seeseemelk/llamascript/arguments/ListArgumentBag.java b/libllamascript/src/main/java/be/seeseemelk/llamascript/arguments/ListArgumentBag.java new file mode 100644 index 0000000..f6a4084 --- /dev/null +++ b/libllamascript/src/main/java/be/seeseemelk/llamascript/arguments/ListArgumentBag.java @@ -0,0 +1,43 @@ +package be.seeseemelk.llamascript.arguments; + +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +/** + * A bag of arguments that will keep the order in which the arguments were added. + * The arguments are added with an index, which starts at one for the first argument. + */ +public class ListArgumentBag implements ArgumentBag { + private final List values = new ArrayList<>(); + + /** + * Adds an argument to the bag. + * @param value The argument to add. + * @return This instance. + */ + public ListArgumentBag add(Object value) { + values.add(value); + return this; + } + + @Override + public void writeTo(StringBuilder builder) { + for (int i = 0; i < values.size(); i++) { + builder.append("[").append(i + 1).append("] -> ").append(Objects.toString(values.get(i))).append('\n'); + } + } + + /** + * Creates an argument bag from an array of values. + * @param values The values to add. + * @return The bag with the arguments added. + */ + public static ListArgumentBag of(Object... values) { + var bag = new ListArgumentBag(); + for (var value : values) { + bag.add(value); + } + return bag; + } +} diff --git a/libllamascript/src/main/java/be/seeseemelk/llamascript/arguments/NamedArgumentBag.java b/libllamascript/src/main/java/be/seeseemelk/llamascript/arguments/NamedArgumentBag.java new file mode 100644 index 0000000..58ae7e7 --- /dev/null +++ b/libllamascript/src/main/java/be/seeseemelk/llamascript/arguments/NamedArgumentBag.java @@ -0,0 +1,31 @@ +package be.seeseemelk.llamascript.arguments; + +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +/** + * A bag of arguments that will keep the order in which the arguments were added. + * The arguments are added with an index, which starts at one for the first argument. + */ +public class NamedArgumentBag implements ArgumentBag { + private final Map values = new HashMap<>(); + + /** + * Adds an argument to the bag. + * @param value The argument to add. + * @return This instance. + */ + public NamedArgumentBag set(String name, Object value) { + values.put(name, value); + return this; + } + + @Override + public void writeTo(StringBuilder builder) { + for (var entry : values.entrySet()) { + var string = Objects.toString(entry.getValue()); + builder.append(entry.getKey()).append(" -> ").append(string).append('\n'); + } + } +} diff --git a/libllamascript/src/test/java/be/seeseemelk/llamascript/BooleanTests.java b/libllamascript/src/test/java/be/seeseemelk/llamascript/BooleanTests.java index 3462822..1489751 100644 --- a/libllamascript/src/test/java/be/seeseemelk/llamascript/BooleanTests.java +++ b/libllamascript/src/test/java/be/seeseemelk/llamascript/BooleanTests.java @@ -11,7 +11,7 @@ public class BooleanTests extends AbstractLlamaTest { @ParameterizedTest @ValueSource(ints = {0, 1, 2, 3, 4, 5}) void isEven(int number) { - boolean isEven = llama.evalBool("Return true if the number is even", number); + boolean isEven = llama.evalBool("Return true if the number is even, false if it's odd. (btw, zero is even)", number); assertThat(isEven, equalTo(number % 2 == 0)); } diff --git a/libllamascript/src/test/java/be/seeseemelk/llamascript/StringTests.java b/libllamascript/src/test/java/be/seeseemelk/llamascript/StringTests.java index 789ebd4..104ffa5 100644 --- a/libllamascript/src/test/java/be/seeseemelk/llamascript/StringTests.java +++ b/libllamascript/src/test/java/be/seeseemelk/llamascript/StringTests.java @@ -20,7 +20,7 @@ public class StringTests extends AbstractLlamaTest { @Test void canInvert() { - String value = llama.evalString("Return the string backwards", "Cat"); + String value = llama.evalString("Return the string backwards. Do not change the capitalisation of any letters. E.g.: Za becomes aZ", "Cat"); assertThat(value, equalTo("taC")); }