diff --git a/src/main/java/be/seeseemelk/llamascript/LLamaScriptFactory.java b/src/main/java/be/seeseemelk/llamascript/LLamaScriptFactory.java index b40f3d4..5dcffea 100644 --- a/src/main/java/be/seeseemelk/llamascript/LLamaScriptFactory.java +++ b/src/main/java/be/seeseemelk/llamascript/LLamaScriptFactory.java @@ -3,12 +3,26 @@ package be.seeseemelk.llamascript; import io.github.ollama4j.OllamaAPI; import lombok.Setter; +import java.util.Objects; + @Setter public class LLamaScriptFactory { - private String host = "http://localhost:11434"; - private String model = "llama3.1:8b"; + private String host; + private String model; + + public LLamaScriptFactory() { + host = Objects.requireNonNullElse(System.getenv("OLLAMA_HOST"), "http://localhost:11434"); + model = Objects.requireNonNullElse(System.getenv("OLLAMA_MODEL"), "llama3.1:8b"); + } public LlamaScript build() { - return new LlamaScript(new OllamaAPI(host), model); + return new LlamaScript(buildWorker()); + } + + private LlamaScriptWorker buildWorker() { + return LlamaScriptWorker.builder() + .api(new OllamaAPI(host)) + .model(model) + .build(); } } diff --git a/src/main/java/be/seeseemelk/llamascript/LlamaScript.java b/src/main/java/be/seeseemelk/llamascript/LlamaScript.java index 30b1f79..7f6d8f1 100644 --- a/src/main/java/be/seeseemelk/llamascript/LlamaScript.java +++ b/src/main/java/be/seeseemelk/llamascript/LlamaScript.java @@ -1,25 +1,14 @@ package be.seeseemelk.llamascript; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.github.ollama4j.OllamaAPI; -import io.github.ollama4j.exceptions.OllamaBaseException; -import io.github.ollama4j.models.chat.OllamaChatMessage; -import io.github.ollama4j.models.chat.OllamaChatMessageRole; -import io.github.ollama4j.models.chat.OllamaChatRequest; -import io.github.ollama4j.utils.OptionsBuilder; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; -import java.io.IOException; -import java.util.ArrayList; import java.util.HashMap; @Slf4j @RequiredArgsConstructor public class LlamaScript { - private final OllamaAPI api; - private final String model; - private final ObjectMapper mapper = new ObjectMapper(); + private final LlamaScriptWorker worker; public T eval(String prompt, Class returnType, Object... arguments) { var argumentMap = new HashMap(); @@ -33,13 +22,16 @@ public class LlamaScript { } else if (returnType.equals(String.class)) { //noinspection unchecked return (T) evalString(prompt, argumentMap); + } else if (returnType.equals(Boolean.class)) { + //noinspection unchecked + return (T) evalBool(prompt, argumentMap); } else { throw new LLamaScriptException("Unsupported return type"); } } - private String evalString(String prompt, HashMap argumentMap) { - var result = eval(prompt, primitiveOutputType("a string"), argumentMap); + public String evalString(String prompt, HashMap argumentMap) { + var result = worker.eval(prompt, primitiveOutputType("a string"), argumentMap); if (result instanceof String string) return string; else if (result != null) @@ -48,8 +40,12 @@ public class LlamaScript { throw new LLamaScriptException("No response"); } - private Integer evalInt(String prompt, HashMap argumentMap) { - var result = eval(prompt, primitiveOutputType("an integer number"), argumentMap); + public String evalString(String prompt, Object... arguments) { + return eval(prompt, String.class, arguments); + } + + public Integer evalInt(String prompt, HashMap argumentMap) { + var result = worker.eval(prompt, primitiveOutputType("an integer number"), argumentMap); if (result instanceof Integer integer) return integer; else if (result instanceof String string) { @@ -63,42 +59,22 @@ public class LlamaScript { throw new LLamaScriptException("Invalid integer type"); } - private Object eval(String prompt, String outputType, HashMap argumentMap) { - try { - var optionsBuilder = new OptionsBuilder(); - optionsBuilder.setSeed(20); // Chosen with 1d20. Yes, a nat 20! Let's go baby! - optionsBuilder.setTemperature(0f); - var options = optionsBuilder.build(); - var fullPrompt = new StringBuilder(""" - You will receive a prompt, and will have to consider the question in the prompt with regards to the inputs you will receive. - Your output must be a JSON object containing the single value 'result'. - %s - - The prompt: %s - - Arguments: - """.formatted(outputType, prompt)); - for (var entry : argumentMap.entrySet()) { - fullPrompt.append(" - ").append(entry.getKey()).append(": ").append(entry.getValue()).append("\n"); - } - System.out.printf("Prompt is: %s\n", fullPrompt); - - var messages = new ArrayList(); - messages.add(new OllamaChatMessage(OllamaChatMessageRole.USER, fullPrompt.toString())); - var chatRequest = new OllamaChatRequest(model, messages); - chatRequest.setReturnFormatJson(true); - chatRequest.setOptions(options.getOptionsMap()); - - var result = api.chat(chatRequest); - System.out.printf("Result is: %s\n", result.getResponse()); - return mapper.readValue(result.getResponse(), Result.class).result; - } catch (OllamaBaseException | InterruptedException | IOException e) { - throw new LLamaScriptException(e); - } + public int evalInt(String prompt, Object... arguments) { + return eval(prompt, Integer.class, arguments); } - private static class Result { - public Object result; + public Boolean evalBool(String prompt, HashMap argumentMap) { + var result = worker.eval(prompt, primitiveOutputType("a boolean value"), argumentMap); + if (result instanceof Boolean bool) + return bool; + else if (result instanceof String string) + return Boolean.parseBoolean(string); + else + throw new LLamaScriptException("Invalid boolean type"); + } + + public boolean evalBool(String prompt, Object... arguments) { + return eval(prompt, Boolean.class, arguments); } private static String primitiveOutputType(String description) { diff --git a/src/main/java/be/seeseemelk/llamascript/LlamaScriptWorker.java b/src/main/java/be/seeseemelk/llamascript/LlamaScriptWorker.java new file mode 100644 index 0000000..7021a2e --- /dev/null +++ b/src/main/java/be/seeseemelk/llamascript/LlamaScriptWorker.java @@ -0,0 +1,59 @@ +package be.seeseemelk.llamascript; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.github.ollama4j.OllamaAPI; +import io.github.ollama4j.exceptions.OllamaBaseException; +import io.github.ollama4j.models.chat.OllamaChatMessage; +import io.github.ollama4j.models.chat.OllamaChatMessageRole; +import io.github.ollama4j.models.chat.OllamaChatRequest; +import io.github.ollama4j.utils.OptionsBuilder; +import lombok.Builder; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; + +@Builder +public class LlamaScriptWorker { + private final OllamaAPI api; + private final String model; + private final ObjectMapper mapper = new ObjectMapper(); + + public Object eval(String prompt, String outputType, HashMap argumentMap) { + try { + var optionsBuilder = new OptionsBuilder(); + optionsBuilder.setSeed(20); // Chosen with 1d20. Yes, a nat 20! Let's go baby! + optionsBuilder.setTemperature(0f); + var options = optionsBuilder.build(); + var fullPrompt = new StringBuilder(""" + You will receive a prompt, and will have to consider the question in the prompt with regards to the inputs you will receive. + Your output must be a JSON object containing the single value 'result'. + %s + + The prompt: %s + + Arguments: + """.formatted(outputType, prompt)); + for (var entry : argumentMap.entrySet()) { + fullPrompt.append(" - ").append(entry.getKey()).append(": ").append(entry.getValue()).append("\n"); + } + System.out.printf("Prompt is: %s\n", fullPrompt); + + var messages = new ArrayList(); + messages.add(new OllamaChatMessage(OllamaChatMessageRole.USER, fullPrompt.toString())); + var chatRequest = new OllamaChatRequest(model, messages); + chatRequest.setReturnFormatJson(true); + chatRequest.setOptions(options.getOptionsMap()); + + var result = api.chat(chatRequest); + System.out.printf("Result is: %s\n", result.getResponse()); + return mapper.readValue(result.getResponse(), Result.class).result; + } catch (OllamaBaseException | InterruptedException | IOException e) { + throw new LLamaScriptException(e); + } + } + + private static class Result { + public Object result; + } +} diff --git a/src/main/java/be/seeseemelk/llamascript/arguments/ArgumentBag.java b/src/main/java/be/seeseemelk/llamascript/arguments/ArgumentBag.java new file mode 100644 index 0000000..8ab82b9 --- /dev/null +++ b/src/main/java/be/seeseemelk/llamascript/arguments/ArgumentBag.java @@ -0,0 +1,4 @@ +package be.seeseemelk.llamascript.arguments; + +public abstract class ArgumentBag { +} diff --git a/src/test/java/be/seeseemelk/llamascript/BooleanTests.java b/src/test/java/be/seeseemelk/llamascript/BooleanTests.java new file mode 100644 index 0000000..3462822 --- /dev/null +++ b/src/test/java/be/seeseemelk/llamascript/BooleanTests.java @@ -0,0 +1,26 @@ +package be.seeseemelk.llamascript; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; + +public class BooleanTests extends AbstractLlamaTest { + @ParameterizedTest + @ValueSource(ints = {0, 1, 2, 3, 4, 5}) + void isEven(int number) { + boolean isEven = llama.evalBool("Return true if the number is even", number); + assertThat(isEven, equalTo(number % 2 == 0)); + } + + @Test + void isABuilding() { + boolean isBuilding = llama.evalBool("Return true if the argument is a building or skyscraper", "Dalai Lama"); + assertThat(isBuilding, equalTo(false)); + + isBuilding = llama.evalBool("Return true if the argument is a building or skyscraper", "Burj Khalifa"); + assertThat(isBuilding, equalTo(true)); + } +} diff --git a/src/test/java/be/seeseemelk/llamascript/IntegerTests.java b/src/test/java/be/seeseemelk/llamascript/IntegerTests.java index 0bad1df..f517e29 100644 --- a/src/test/java/be/seeseemelk/llamascript/IntegerTests.java +++ b/src/test/java/be/seeseemelk/llamascript/IntegerTests.java @@ -8,34 +8,34 @@ import static org.hamcrest.Matchers.equalTo; public class IntegerTests extends AbstractLlamaTest { @Test void canIncrementNumber() { - int value = llama.eval("Increment the number", Integer.class, 4); + int value = llama.evalInt("Increment the number", 4); assertThat(value, equalTo(5)); } @Test void canMultiply() { - int value = llama.eval("Multiply the numbers", Integer.class, 2, 3); + int value = llama.evalInt("Multiply the numbers", 2, 3); assertThat(value, equalTo(2*3)); } @Test void canMax() { - int value = llama.eval("Select the largest number", Integer.class, -2, 8); + int value = llama.evalInt("Select the largest number", -2, 8); assertThat(value, equalTo(8)); } @Test void canDoWeirdStuff() { - int value = llama.eval("Select the number with the most '5's in it.", Integer.class, 12, 5, 1023978, 158525); + int value = llama.evalInt("Select the number with the most '5's in it.", 12, 5, 1023978, 158525); assertThat(value, equalTo(158525)); } @Test void givePositivity() { - int value = llama.eval("If the string is something positive, return 1. Else, return 0.", Integer.class, "I like rainbows"); + int value = llama.evalInt("If the string is something positive, return 1. Else, return 0.", "I like rainbows"); assertThat(value, equalTo(1)); - value = llama.eval("If the string is something positive, return 1. Else, return 0.", Integer.class, "Death to all"); + value = llama.evalInt("If the string is something positive, return 1. Else, return 0.", "Death to all"); assertThat(value, equalTo(0)); } } diff --git a/src/test/java/be/seeseemelk/llamascript/StringTests.java b/src/test/java/be/seeseemelk/llamascript/StringTests.java index e89ebb4..789ebd4 100644 --- a/src/test/java/be/seeseemelk/llamascript/StringTests.java +++ b/src/test/java/be/seeseemelk/llamascript/StringTests.java @@ -8,37 +8,37 @@ import static org.hamcrest.Matchers.equalTo; public class StringTests extends AbstractLlamaTest { @Test void canUppercase() { - String value = llama.eval("Return the string in uppercase", String.class, "Cat"); + String value = llama.evalString("Return the string in uppercase", "Cat"); assertThat(value, equalTo("CAT")); } @Test void canLowercase() { - String value = llama.eval("Return the string in lowercase", String.class, "Cat"); + String value = llama.evalString("Return the string in lowercase", "Cat"); assertThat(value, equalTo("cat")); } @Test void canInvert() { - String value = llama.eval("Return the string backwards", String.class, "Cat"); + String value = llama.evalString("Return the string backwards", "Cat"); assertThat(value, equalTo("taC")); } @Test void canSelectLongest() { - String value = llama.eval("Return the longest string", String.class, "Cat", "Dog", "Horse"); + String value = llama.evalString("Return the longest string", "Cat", "Dog", "Horse"); assertThat(value, equalTo("Horse")); } @Test void canDoWeirdStuff() { - String value = llama.eval("Return the string that does not fit", String.class, "Cat", "Dog", "Horse", "Bicycle"); + String value = llama.evalString("Return the string that does not fit", "Cat", "Dog", "Horse", "Bicycle"); assertThat(value, equalTo("Bicycle")); } @Test void canDoReallyWeirdStuff() { - String value = llama.eval("Sort the letters alphabetically", String.class, "horse"); + String value = llama.evalString("Sort the letters alphabetically", "horse"); assertThat(value, equalTo("ehors")); } }