Arguments
This commit is contained in:
parent
14d13a7d76
commit
0eec2a9756
@ -1,37 +1,33 @@
|
||||
package be.seeseemelk.llamascript;
|
||||
|
||||
import be.seeseemelk.llamascript.arguments.ArgumentBag;
|
||||
import be.seeseemelk.llamascript.arguments.ListArgumentBag;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.util.HashMap;
|
||||
|
||||
@Slf4j
|
||||
@RequiredArgsConstructor
|
||||
public class LlamaScript {
|
||||
private final LlamaScriptWorker worker;
|
||||
|
||||
public <T> T eval(String prompt, Class<T> returnType, Object... arguments) {
|
||||
var argumentMap = new HashMap<Integer, Object>();
|
||||
for (Object argument : arguments) {
|
||||
argumentMap.put(argumentMap.size(), argument);
|
||||
}
|
||||
public <T> T eval(String prompt, Class<T> returnType, ArgumentBag arguments) {
|
||||
|
||||
if (returnType.equals(Integer.class)) {
|
||||
//noinspection unchecked
|
||||
return (T) evalInt(prompt, argumentMap);
|
||||
return (T) evalInt(prompt, arguments);
|
||||
} else if (returnType.equals(String.class)) {
|
||||
//noinspection unchecked
|
||||
return (T) evalString(prompt, argumentMap);
|
||||
return (T) evalString(prompt, arguments);
|
||||
} else if (returnType.equals(Boolean.class)) {
|
||||
//noinspection unchecked
|
||||
return (T) evalBool(prompt, argumentMap);
|
||||
return (T) evalBool(prompt, arguments);
|
||||
} else {
|
||||
throw new LLamaScriptException("Unsupported return type");
|
||||
}
|
||||
}
|
||||
|
||||
public String evalString(String prompt, HashMap<Integer, Object> argumentMap) {
|
||||
var result = worker.eval(prompt, primitiveOutputType("a string"), argumentMap);
|
||||
public String evalString(String prompt, ArgumentBag arguments) {
|
||||
var result = worker.eval(prompt, primitiveOutputType("a string"), arguments);
|
||||
if (result instanceof String string)
|
||||
return string;
|
||||
else if (result != null)
|
||||
@ -41,11 +37,11 @@ public class LlamaScript {
|
||||
}
|
||||
|
||||
public String evalString(String prompt, Object... arguments) {
|
||||
return eval(prompt, String.class, arguments);
|
||||
return eval(prompt, String.class, ListArgumentBag.of(arguments));
|
||||
}
|
||||
|
||||
public Integer evalInt(String prompt, HashMap<Integer, Object> argumentMap) {
|
||||
var result = worker.eval(prompt, primitiveOutputType("an integer number"), argumentMap);
|
||||
public Integer evalInt(String prompt, ArgumentBag arguments) {
|
||||
var result = worker.eval(prompt, primitiveOutputType("an integer number"), arguments);
|
||||
if (result instanceof Integer integer)
|
||||
return integer;
|
||||
else if (result instanceof String string) {
|
||||
@ -60,11 +56,11 @@ public class LlamaScript {
|
||||
}
|
||||
|
||||
public int evalInt(String prompt, Object... arguments) {
|
||||
return eval(prompt, Integer.class, arguments);
|
||||
return eval(prompt, Integer.class, ListArgumentBag.of(arguments));
|
||||
}
|
||||
|
||||
public Boolean evalBool(String prompt, HashMap<Integer, Object> argumentMap) {
|
||||
var result = worker.eval(prompt, primitiveOutputType("a boolean value"), argumentMap);
|
||||
public Boolean evalBool(String prompt, ArgumentBag arguments) {
|
||||
var result = worker.eval(prompt, primitiveOutputType("a boolean value"), arguments);
|
||||
if (result instanceof Boolean bool)
|
||||
return bool;
|
||||
else if (result instanceof String string)
|
||||
@ -74,7 +70,7 @@ public class LlamaScript {
|
||||
}
|
||||
|
||||
public boolean evalBool(String prompt, Object... arguments) {
|
||||
return eval(prompt, Boolean.class, arguments);
|
||||
return eval(prompt, Boolean.class, ListArgumentBag.of(arguments));
|
||||
}
|
||||
|
||||
private static String primitiveOutputType(String description) {
|
||||
|
@ -1,5 +1,7 @@
|
||||
package be.seeseemelk.llamascript;
|
||||
|
||||
import be.seeseemelk.llamascript.arguments.ArgumentBag;
|
||||
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import io.github.ollama4j.OllamaAPI;
|
||||
import io.github.ollama4j.exceptions.OllamaBaseException;
|
||||
@ -11,7 +13,6 @@ import lombok.Builder;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
|
||||
@Builder
|
||||
public class LlamaScriptWorker {
|
||||
@ -19,24 +20,23 @@ public class LlamaScriptWorker {
|
||||
private final String model;
|
||||
private final ObjectMapper mapper = new ObjectMapper();
|
||||
|
||||
public Object eval(String prompt, String outputType, HashMap<Integer, Object> argumentMap) {
|
||||
public Object eval(String prompt, String outputType, ArgumentBag arguments) {
|
||||
try {
|
||||
var optionsBuilder = new OptionsBuilder();
|
||||
optionsBuilder.setSeed(20); // Chosen with 1d20. Yes, a nat 20! Let's go baby!
|
||||
optionsBuilder.setTemperature(0f);
|
||||
optionsBuilder.setTemperature(0);
|
||||
var options = optionsBuilder.build();
|
||||
var fullPrompt = new StringBuilder("""
|
||||
You will receive a prompt, and will have to consider the question in the prompt with regards to the inputs you will receive.
|
||||
Your output must be a JSON object containing the single value 'result'.
|
||||
Your answer must be correct. Do not make mistakes
|
||||
%s
|
||||
|
||||
The prompt: %s
|
||||
|
||||
Arguments:
|
||||
""".formatted(outputType, prompt));
|
||||
for (var entry : argumentMap.entrySet()) {
|
||||
fullPrompt.append(" - ").append(entry.getKey()).append(": ").append(entry.getValue()).append("\n");
|
||||
}
|
||||
arguments.writeTo(fullPrompt);
|
||||
System.out.printf("Prompt is: %s\n", fullPrompt);
|
||||
|
||||
var messages = new ArrayList<OllamaChatMessage>();
|
||||
@ -53,6 +53,7 @@ public class LlamaScriptWorker {
|
||||
}
|
||||
}
|
||||
|
||||
@JsonIgnoreProperties(ignoreUnknown = true)
|
||||
private static class Result {
|
||||
public Object result;
|
||||
}
|
||||
|
@ -1,4 +1,12 @@
|
||||
package be.seeseemelk.llamascript.arguments;
|
||||
|
||||
public abstract class ArgumentBag {
|
||||
/**
|
||||
* A bag of arguments that can be added to an evaluation request.
|
||||
*/
|
||||
public interface ArgumentBag {
|
||||
/**
|
||||
* Writes the values inside the argument bag to a StringBuilder.
|
||||
* @param builder The StringBuilder to write to.
|
||||
*/
|
||||
void writeTo(StringBuilder builder);
|
||||
}
|
||||
|
@ -0,0 +1,43 @@
|
||||
package be.seeseemelk.llamascript.arguments;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* A bag of arguments that will keep the order in which the arguments were added.
|
||||
* The arguments are added with an index, which starts at one for the first argument.
|
||||
*/
|
||||
public class ListArgumentBag implements ArgumentBag {
|
||||
private final List<Object> values = new ArrayList<>();
|
||||
|
||||
/**
|
||||
* Adds an argument to the bag.
|
||||
* @param value The argument to add.
|
||||
* @return This instance.
|
||||
*/
|
||||
public ListArgumentBag add(Object value) {
|
||||
values.add(value);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StringBuilder builder) {
|
||||
for (int i = 0; i < values.size(); i++) {
|
||||
builder.append("[").append(i + 1).append("] -> ").append(Objects.toString(values.get(i))).append('\n');
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an argument bag from an array of values.
|
||||
* @param values The values to add.
|
||||
* @return The bag with the arguments added.
|
||||
*/
|
||||
public static ListArgumentBag of(Object... values) {
|
||||
var bag = new ListArgumentBag();
|
||||
for (var value : values) {
|
||||
bag.add(value);
|
||||
}
|
||||
return bag;
|
||||
}
|
||||
}
|
@ -0,0 +1,31 @@
|
||||
package be.seeseemelk.llamascript.arguments;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* A bag of arguments that will keep the order in which the arguments were added.
|
||||
* The arguments are added with an index, which starts at one for the first argument.
|
||||
*/
|
||||
public class NamedArgumentBag implements ArgumentBag {
|
||||
private final Map<String, Object> values = new HashMap<>();
|
||||
|
||||
/**
|
||||
* Adds an argument to the bag.
|
||||
* @param value The argument to add.
|
||||
* @return This instance.
|
||||
*/
|
||||
public NamedArgumentBag set(String name, Object value) {
|
||||
values.put(name, value);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StringBuilder builder) {
|
||||
for (var entry : values.entrySet()) {
|
||||
var string = Objects.toString(entry.getValue());
|
||||
builder.append(entry.getKey()).append(" -> ").append(string).append('\n');
|
||||
}
|
||||
}
|
||||
}
|
@ -11,7 +11,7 @@ public class BooleanTests extends AbstractLlamaTest {
|
||||
@ParameterizedTest
|
||||
@ValueSource(ints = {0, 1, 2, 3, 4, 5})
|
||||
void isEven(int number) {
|
||||
boolean isEven = llama.evalBool("Return true if the number is even", number);
|
||||
boolean isEven = llama.evalBool("Return true if the number is even, false if it's odd. (btw, zero is even)", number);
|
||||
assertThat(isEven, equalTo(number % 2 == 0));
|
||||
}
|
||||
|
||||
|
@ -20,7 +20,7 @@ public class StringTests extends AbstractLlamaTest {
|
||||
|
||||
@Test
|
||||
void canInvert() {
|
||||
String value = llama.evalString("Return the string backwards", "Cat");
|
||||
String value = llama.evalString("Return the string backwards. Do not change the capitalisation of any letters. E.g.: Za becomes aZ", "Cat");
|
||||
assertThat(value, equalTo("taC"));
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user