Arguments

This commit is contained in:
Sebastiaan de Schaetzen 2024-08-21 14:55:53 +02:00
parent 14d13a7d76
commit 0eec2a9756
7 changed files with 107 additions and 28 deletions

View File

@ -1,37 +1,33 @@
package be.seeseemelk.llamascript; package be.seeseemelk.llamascript;
import be.seeseemelk.llamascript.arguments.ArgumentBag;
import be.seeseemelk.llamascript.arguments.ListArgumentBag;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import java.util.HashMap;
@Slf4j @Slf4j
@RequiredArgsConstructor @RequiredArgsConstructor
public class LlamaScript { public class LlamaScript {
private final LlamaScriptWorker worker; private final LlamaScriptWorker worker;
public <T> T eval(String prompt, Class<T> returnType, Object... arguments) { public <T> T eval(String prompt, Class<T> returnType, ArgumentBag arguments) {
var argumentMap = new HashMap<Integer, Object>();
for (Object argument : arguments) {
argumentMap.put(argumentMap.size(), argument);
}
if (returnType.equals(Integer.class)) { if (returnType.equals(Integer.class)) {
//noinspection unchecked //noinspection unchecked
return (T) evalInt(prompt, argumentMap); return (T) evalInt(prompt, arguments);
} else if (returnType.equals(String.class)) { } else if (returnType.equals(String.class)) {
//noinspection unchecked //noinspection unchecked
return (T) evalString(prompt, argumentMap); return (T) evalString(prompt, arguments);
} else if (returnType.equals(Boolean.class)) { } else if (returnType.equals(Boolean.class)) {
//noinspection unchecked //noinspection unchecked
return (T) evalBool(prompt, argumentMap); return (T) evalBool(prompt, arguments);
} else { } else {
throw new LLamaScriptException("Unsupported return type"); throw new LLamaScriptException("Unsupported return type");
} }
} }
public String evalString(String prompt, HashMap<Integer, Object> argumentMap) { public String evalString(String prompt, ArgumentBag arguments) {
var result = worker.eval(prompt, primitiveOutputType("a string"), argumentMap); var result = worker.eval(prompt, primitiveOutputType("a string"), arguments);
if (result instanceof String string) if (result instanceof String string)
return string; return string;
else if (result != null) else if (result != null)
@ -41,11 +37,11 @@ public class LlamaScript {
} }
public String evalString(String prompt, Object... arguments) { public String evalString(String prompt, Object... arguments) {
return eval(prompt, String.class, arguments); return eval(prompt, String.class, ListArgumentBag.of(arguments));
} }
public Integer evalInt(String prompt, HashMap<Integer, Object> argumentMap) { public Integer evalInt(String prompt, ArgumentBag arguments) {
var result = worker.eval(prompt, primitiveOutputType("an integer number"), argumentMap); var result = worker.eval(prompt, primitiveOutputType("an integer number"), arguments);
if (result instanceof Integer integer) if (result instanceof Integer integer)
return integer; return integer;
else if (result instanceof String string) { else if (result instanceof String string) {
@ -60,11 +56,11 @@ public class LlamaScript {
} }
public int evalInt(String prompt, Object... arguments) { public int evalInt(String prompt, Object... arguments) {
return eval(prompt, Integer.class, arguments); return eval(prompt, Integer.class, ListArgumentBag.of(arguments));
} }
public Boolean evalBool(String prompt, HashMap<Integer, Object> argumentMap) { public Boolean evalBool(String prompt, ArgumentBag arguments) {
var result = worker.eval(prompt, primitiveOutputType("a boolean value"), argumentMap); var result = worker.eval(prompt, primitiveOutputType("a boolean value"), arguments);
if (result instanceof Boolean bool) if (result instanceof Boolean bool)
return bool; return bool;
else if (result instanceof String string) else if (result instanceof String string)
@ -74,7 +70,7 @@ public class LlamaScript {
} }
public boolean evalBool(String prompt, Object... arguments) { public boolean evalBool(String prompt, Object... arguments) {
return eval(prompt, Boolean.class, arguments); return eval(prompt, Boolean.class, ListArgumentBag.of(arguments));
} }
private static String primitiveOutputType(String description) { private static String primitiveOutputType(String description) {

View File

@ -1,5 +1,7 @@
package be.seeseemelk.llamascript; package be.seeseemelk.llamascript;
import be.seeseemelk.llamascript.arguments.ArgumentBag;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import io.github.ollama4j.OllamaAPI; import io.github.ollama4j.OllamaAPI;
import io.github.ollama4j.exceptions.OllamaBaseException; import io.github.ollama4j.exceptions.OllamaBaseException;
@ -11,7 +13,6 @@ import lombok.Builder;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap;
@Builder @Builder
public class LlamaScriptWorker { public class LlamaScriptWorker {
@ -19,24 +20,23 @@ public class LlamaScriptWorker {
private final String model; private final String model;
private final ObjectMapper mapper = new ObjectMapper(); private final ObjectMapper mapper = new ObjectMapper();
public Object eval(String prompt, String outputType, HashMap<Integer, Object> argumentMap) { public Object eval(String prompt, String outputType, ArgumentBag arguments) {
try { try {
var optionsBuilder = new OptionsBuilder(); var optionsBuilder = new OptionsBuilder();
optionsBuilder.setSeed(20); // Chosen with 1d20. Yes, a nat 20! Let's go baby! optionsBuilder.setSeed(20); // Chosen with 1d20. Yes, a nat 20! Let's go baby!
optionsBuilder.setTemperature(0f); optionsBuilder.setTemperature(0);
var options = optionsBuilder.build(); var options = optionsBuilder.build();
var fullPrompt = new StringBuilder(""" var fullPrompt = new StringBuilder("""
You will receive a prompt, and will have to consider the question in the prompt with regards to the inputs you will receive. You will receive a prompt, and will have to consider the question in the prompt with regards to the inputs you will receive.
Your output must be a JSON object containing the single value 'result'. Your output must be a JSON object containing the single value 'result'.
Your answer must be correct. Do not make mistakes
%s %s
The prompt: %s The prompt: %s
Arguments: Arguments:
""".formatted(outputType, prompt)); """.formatted(outputType, prompt));
for (var entry : argumentMap.entrySet()) { arguments.writeTo(fullPrompt);
fullPrompt.append(" - ").append(entry.getKey()).append(": ").append(entry.getValue()).append("\n");
}
System.out.printf("Prompt is: %s\n", fullPrompt); System.out.printf("Prompt is: %s\n", fullPrompt);
var messages = new ArrayList<OllamaChatMessage>(); var messages = new ArrayList<OllamaChatMessage>();
@ -53,6 +53,7 @@ public class LlamaScriptWorker {
} }
} }
@JsonIgnoreProperties(ignoreUnknown = true)
private static class Result { private static class Result {
public Object result; public Object result;
} }

View File

@ -1,4 +1,12 @@
package be.seeseemelk.llamascript.arguments; package be.seeseemelk.llamascript.arguments;
public abstract class ArgumentBag { /**
* A bag of arguments that can be added to an evaluation request.
*/
public interface ArgumentBag {
/**
* Writes the values inside the argument bag to a StringBuilder.
* @param builder The StringBuilder to write to.
*/
void writeTo(StringBuilder builder);
} }

View File

@ -0,0 +1,43 @@
package be.seeseemelk.llamascript.arguments;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
/**
* A bag of arguments that will keep the order in which the arguments were added.
* The arguments are added with an index, which starts at one for the first argument.
*/
public class ListArgumentBag implements ArgumentBag {
private final List<Object> values = new ArrayList<>();
/**
* Adds an argument to the bag.
* @param value The argument to add.
* @return This instance.
*/
public ListArgumentBag add(Object value) {
values.add(value);
return this;
}
@Override
public void writeTo(StringBuilder builder) {
for (int i = 0; i < values.size(); i++) {
builder.append("[").append(i + 1).append("] -> ").append(Objects.toString(values.get(i))).append('\n');
}
}
/**
* Creates an argument bag from an array of values.
* @param values The values to add.
* @return The bag with the arguments added.
*/
public static ListArgumentBag of(Object... values) {
var bag = new ListArgumentBag();
for (var value : values) {
bag.add(value);
}
return bag;
}
}

View File

@ -0,0 +1,31 @@
package be.seeseemelk.llamascript.arguments;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
/**
* A bag of arguments that will keep the order in which the arguments were added.
* The arguments are added with an index, which starts at one for the first argument.
*/
public class NamedArgumentBag implements ArgumentBag {
private final Map<String, Object> values = new HashMap<>();
/**
* Adds an argument to the bag.
* @param value The argument to add.
* @return This instance.
*/
public NamedArgumentBag set(String name, Object value) {
values.put(name, value);
return this;
}
@Override
public void writeTo(StringBuilder builder) {
for (var entry : values.entrySet()) {
var string = Objects.toString(entry.getValue());
builder.append(entry.getKey()).append(" -> ").append(string).append('\n');
}
}
}

View File

@ -11,7 +11,7 @@ public class BooleanTests extends AbstractLlamaTest {
@ParameterizedTest @ParameterizedTest
@ValueSource(ints = {0, 1, 2, 3, 4, 5}) @ValueSource(ints = {0, 1, 2, 3, 4, 5})
void isEven(int number) { void isEven(int number) {
boolean isEven = llama.evalBool("Return true if the number is even", number); boolean isEven = llama.evalBool("Return true if the number is even, false if it's odd. (btw, zero is even)", number);
assertThat(isEven, equalTo(number % 2 == 0)); assertThat(isEven, equalTo(number % 2 == 0));
} }

View File

@ -20,7 +20,7 @@ public class StringTests extends AbstractLlamaTest {
@Test @Test
void canInvert() { void canInvert() {
String value = llama.evalString("Return the string backwards", "Cat"); String value = llama.evalString("Return the string backwards. Do not change the capitalisation of any letters. E.g.: Za becomes aZ", "Cat");
assertThat(value, equalTo("taC")); assertThat(value, equalTo("taC"));
} }