From fa041e202748c67d4ab5c5a19f59720794305a9b Mon Sep 17 00:00:00 2001
From: Sebastiaan de Schaetzen <sebastiaan.de.schaetzen@gmail.com>
Date: Tue, 20 Aug 2024 18:31:52 +0200
Subject: [PATCH] Some more stuff

---
 .../llamascript/LLamaScriptFactory.java       | 20 ++++-
 .../seeseemelk/llamascript/LlamaScript.java   | 76 +++++++------------
 .../llamascript/LlamaScriptWorker.java        | 59 ++++++++++++++
 .../llamascript/arguments/ArgumentBag.java    |  4 +
 .../seeseemelk/llamascript/BooleanTests.java  | 26 +++++++
 .../seeseemelk/llamascript/IntegerTests.java  | 12 +--
 .../seeseemelk/llamascript/StringTests.java   | 12 +--
 7 files changed, 144 insertions(+), 65 deletions(-)
 create mode 100644 src/main/java/be/seeseemelk/llamascript/LlamaScriptWorker.java
 create mode 100644 src/main/java/be/seeseemelk/llamascript/arguments/ArgumentBag.java
 create mode 100644 src/test/java/be/seeseemelk/llamascript/BooleanTests.java

diff --git a/src/main/java/be/seeseemelk/llamascript/LLamaScriptFactory.java b/src/main/java/be/seeseemelk/llamascript/LLamaScriptFactory.java
index b40f3d4..5dcffea 100644
--- a/src/main/java/be/seeseemelk/llamascript/LLamaScriptFactory.java
+++ b/src/main/java/be/seeseemelk/llamascript/LLamaScriptFactory.java
@@ -3,12 +3,26 @@ package be.seeseemelk.llamascript;
 import io.github.ollama4j.OllamaAPI;
 import lombok.Setter;
 
+import java.util.Objects;
+
 @Setter
 public class LLamaScriptFactory {
-    private String host = "http://localhost:11434";
-    private String model = "llama3.1:8b";
+    private String host;
+    private String model;
+
+    public LLamaScriptFactory() {
+        host = Objects.requireNonNullElse(System.getenv("OLLAMA_HOST"), "http://localhost:11434");
+        model = Objects.requireNonNullElse(System.getenv("OLLAMA_MODEL"), "llama3.1:8b");
+    }
 
     public LlamaScript build() {
-        return new LlamaScript(new OllamaAPI(host), model);
+        return new LlamaScript(buildWorker());
+    }
+
+    private LlamaScriptWorker buildWorker() {
+        return LlamaScriptWorker.builder()
+                .api(new OllamaAPI(host))
+                .model(model)
+                .build();
     }
 }
diff --git a/src/main/java/be/seeseemelk/llamascript/LlamaScript.java b/src/main/java/be/seeseemelk/llamascript/LlamaScript.java
index 30b1f79..7f6d8f1 100644
--- a/src/main/java/be/seeseemelk/llamascript/LlamaScript.java
+++ b/src/main/java/be/seeseemelk/llamascript/LlamaScript.java
@@ -1,25 +1,14 @@
 package be.seeseemelk.llamascript;
 
-import com.fasterxml.jackson.databind.ObjectMapper;
-import io.github.ollama4j.OllamaAPI;
-import io.github.ollama4j.exceptions.OllamaBaseException;
-import io.github.ollama4j.models.chat.OllamaChatMessage;
-import io.github.ollama4j.models.chat.OllamaChatMessageRole;
-import io.github.ollama4j.models.chat.OllamaChatRequest;
-import io.github.ollama4j.utils.OptionsBuilder;
 import lombok.RequiredArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
 
-import java.io.IOException;
-import java.util.ArrayList;
 import java.util.HashMap;
 
 @Slf4j
 @RequiredArgsConstructor
 public class LlamaScript {
-    private final OllamaAPI api;
-    private final String model;
-    private final ObjectMapper mapper = new ObjectMapper();
+    private final LlamaScriptWorker worker;
 
     public <T> T eval(String prompt, Class<T> returnType, Object... arguments) {
         var argumentMap = new HashMap<Integer, Object>();
@@ -33,13 +22,16 @@ public class LlamaScript {
         } else if (returnType.equals(String.class)) {
             //noinspection unchecked
             return (T) evalString(prompt, argumentMap);
+        } else if (returnType.equals(Boolean.class)) {
+            //noinspection unchecked
+            return (T) evalBool(prompt, argumentMap);
         } else {
             throw new LLamaScriptException("Unsupported return type");
         }
     }
 
-    private String evalString(String prompt, HashMap<Integer, Object> argumentMap) {
-        var result = eval(prompt, primitiveOutputType("a string"), argumentMap);
+    public String evalString(String prompt, HashMap<Integer, Object> argumentMap) {
+        var result = worker.eval(prompt, primitiveOutputType("a string"), argumentMap);
         if (result instanceof String string)
             return string;
         else if (result != null)
@@ -48,8 +40,12 @@ public class LlamaScript {
             throw new LLamaScriptException("No response");
     }
 
-    private Integer evalInt(String prompt, HashMap<Integer, Object> argumentMap) {
-        var result = eval(prompt, primitiveOutputType("an integer number"), argumentMap);
+    public String evalString(String prompt, Object... arguments) {
+        return eval(prompt, String.class, arguments);
+    }
+
+    public Integer evalInt(String prompt, HashMap<Integer, Object> argumentMap) {
+        var result = worker.eval(prompt, primitiveOutputType("an integer number"), argumentMap);
         if (result instanceof Integer integer)
             return integer;
         else if (result instanceof String string) {
@@ -63,42 +59,22 @@ public class LlamaScript {
             throw new LLamaScriptException("Invalid integer type");
     }
 
-    private Object eval(String prompt, String outputType, HashMap<Integer, Object> argumentMap) {
-        try {
-            var optionsBuilder = new OptionsBuilder();
-            optionsBuilder.setSeed(20); // Chosen with 1d20. Yes, a nat 20! Let's go baby!
-            optionsBuilder.setTemperature(0f);
-            var options = optionsBuilder.build();
-            var fullPrompt = new StringBuilder("""
-                You will receive a prompt, and will have to consider the question in the prompt with regards to the inputs you will receive.
-                Your output must be a JSON object containing the single value 'result'.
-                %s
-                
-                The prompt: %s
-                
-                Arguments:
-                """.formatted(outputType, prompt));
-            for (var entry : argumentMap.entrySet()) {
-                fullPrompt.append(" - ").append(entry.getKey()).append(": ").append(entry.getValue()).append("\n");
-            }
-            System.out.printf("Prompt is: %s\n", fullPrompt);
-
-            var messages = new ArrayList<OllamaChatMessage>();
-            messages.add(new OllamaChatMessage(OllamaChatMessageRole.USER, fullPrompt.toString()));
-            var chatRequest = new OllamaChatRequest(model, messages);
-            chatRequest.setReturnFormatJson(true);
-            chatRequest.setOptions(options.getOptionsMap());
-
-            var result = api.chat(chatRequest);
-            System.out.printf("Result is: %s\n", result.getResponse());
-            return mapper.readValue(result.getResponse(), Result.class).result;
-        } catch (OllamaBaseException | InterruptedException | IOException e) {
-            throw new LLamaScriptException(e);
-        }
+    public int evalInt(String prompt, Object... arguments) {
+        return eval(prompt, Integer.class, arguments);
     }
 
-    private static class Result {
-        public Object result;
+    public Boolean evalBool(String prompt, HashMap<Integer, Object> argumentMap) {
+        var result = worker.eval(prompt, primitiveOutputType("a boolean value"), argumentMap);
+        if (result instanceof Boolean bool)
+            return bool;
+        else if (result instanceof String string)
+            return Boolean.parseBoolean(string);
+        else
+            throw new LLamaScriptException("Invalid boolean type");
+    }
+
+    public boolean evalBool(String prompt, Object... arguments) {
+        return eval(prompt, Boolean.class, arguments);
     }
 
     private static String primitiveOutputType(String description) {
diff --git a/src/main/java/be/seeseemelk/llamascript/LlamaScriptWorker.java b/src/main/java/be/seeseemelk/llamascript/LlamaScriptWorker.java
new file mode 100644
index 0000000..7021a2e
--- /dev/null
+++ b/src/main/java/be/seeseemelk/llamascript/LlamaScriptWorker.java
@@ -0,0 +1,59 @@
+package be.seeseemelk.llamascript;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import io.github.ollama4j.OllamaAPI;
+import io.github.ollama4j.exceptions.OllamaBaseException;
+import io.github.ollama4j.models.chat.OllamaChatMessage;
+import io.github.ollama4j.models.chat.OllamaChatMessageRole;
+import io.github.ollama4j.models.chat.OllamaChatRequest;
+import io.github.ollama4j.utils.OptionsBuilder;
+import lombok.Builder;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+
+@Builder
+public class LlamaScriptWorker {
+    private final OllamaAPI api;
+    private final String model;
+    private final ObjectMapper mapper = new ObjectMapper();
+
+    public Object eval(String prompt, String outputType, HashMap<Integer, Object> argumentMap) {
+        try {
+            var optionsBuilder = new OptionsBuilder();
+            optionsBuilder.setSeed(20); // Chosen with 1d20. Yes, a nat 20! Let's go baby!
+            optionsBuilder.setTemperature(0f);
+            var options = optionsBuilder.build();
+            var fullPrompt = new StringBuilder("""
+                You will receive a prompt, and will have to consider the question in the prompt with regards to the inputs you will receive.
+                Your output must be a JSON object containing the single value 'result'.
+                %s
+                
+                The prompt: %s
+                
+                Arguments:
+                """.formatted(outputType, prompt));
+            for (var entry : argumentMap.entrySet()) {
+                fullPrompt.append(" - ").append(entry.getKey()).append(": ").append(entry.getValue()).append("\n");
+            }
+            System.out.printf("Prompt is: %s\n", fullPrompt);
+
+            var messages = new ArrayList<OllamaChatMessage>();
+            messages.add(new OllamaChatMessage(OllamaChatMessageRole.USER, fullPrompt.toString()));
+            var chatRequest = new OllamaChatRequest(model, messages);
+            chatRequest.setReturnFormatJson(true);
+            chatRequest.setOptions(options.getOptionsMap());
+
+            var result = api.chat(chatRequest);
+            System.out.printf("Result is: %s\n", result.getResponse());
+            return mapper.readValue(result.getResponse(), Result.class).result;
+        } catch (OllamaBaseException | InterruptedException | IOException e) {
+            throw new LLamaScriptException(e);
+        }
+    }
+
+    private static class Result {
+        public Object result;
+    }
+}
diff --git a/src/main/java/be/seeseemelk/llamascript/arguments/ArgumentBag.java b/src/main/java/be/seeseemelk/llamascript/arguments/ArgumentBag.java
new file mode 100644
index 0000000..8ab82b9
--- /dev/null
+++ b/src/main/java/be/seeseemelk/llamascript/arguments/ArgumentBag.java
@@ -0,0 +1,4 @@
+package be.seeseemelk.llamascript.arguments;
+
+public abstract class ArgumentBag {
+}
diff --git a/src/test/java/be/seeseemelk/llamascript/BooleanTests.java b/src/test/java/be/seeseemelk/llamascript/BooleanTests.java
new file mode 100644
index 0000000..3462822
--- /dev/null
+++ b/src/test/java/be/seeseemelk/llamascript/BooleanTests.java
@@ -0,0 +1,26 @@
+package be.seeseemelk.llamascript;
+
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
+
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.equalTo;
+
+public class BooleanTests extends AbstractLlamaTest {
+    @ParameterizedTest
+    @ValueSource(ints = {0, 1, 2, 3, 4, 5})
+    void isEven(int number) {
+        boolean isEven = llama.evalBool("Return true if the number is even", number);
+        assertThat(isEven, equalTo(number % 2 == 0));
+    }
+
+    @Test
+    void isABuilding() {
+        boolean isBuilding = llama.evalBool("Return true if the argument is a building or skyscraper", "Dalai Lama");
+        assertThat(isBuilding, equalTo(false));
+
+        isBuilding = llama.evalBool("Return true if the argument is a building or skyscraper", "Burj Khalifa");
+        assertThat(isBuilding, equalTo(true));
+    }
+}
diff --git a/src/test/java/be/seeseemelk/llamascript/IntegerTests.java b/src/test/java/be/seeseemelk/llamascript/IntegerTests.java
index 0bad1df..f517e29 100644
--- a/src/test/java/be/seeseemelk/llamascript/IntegerTests.java
+++ b/src/test/java/be/seeseemelk/llamascript/IntegerTests.java
@@ -8,34 +8,34 @@ import static org.hamcrest.Matchers.equalTo;
 public class IntegerTests extends AbstractLlamaTest {
     @Test
     void canIncrementNumber() {
-        int value = llama.eval("Increment the number", Integer.class, 4);
+        int value = llama.evalInt("Increment the number", 4);
         assertThat(value, equalTo(5));
     }
 
     @Test
     void canMultiply() {
-        int value = llama.eval("Multiply the numbers", Integer.class, 2, 3);
+        int value = llama.evalInt("Multiply the numbers", 2, 3);
         assertThat(value, equalTo(2*3));
     }
 
     @Test
     void canMax() {
-        int value = llama.eval("Select the largest number", Integer.class, -2, 8);
+        int value = llama.evalInt("Select the largest number", -2, 8);
         assertThat(value, equalTo(8));
     }
 
     @Test
     void canDoWeirdStuff() {
-        int value = llama.eval("Select the number with the most '5's in it.", Integer.class, 12, 5, 1023978, 158525);
+        int value = llama.evalInt("Select the number with the most '5's in it.", 12, 5, 1023978, 158525);
         assertThat(value, equalTo(158525));
     }
 
     @Test
     void givePositivity() {
-        int value = llama.eval("If the string is something positive, return 1. Else, return 0.", Integer.class, "I like rainbows");
+        int value = llama.evalInt("If the string is something positive, return 1. Else, return 0.", "I like rainbows");
         assertThat(value, equalTo(1));
 
-        value = llama.eval("If the string is something positive, return 1. Else, return 0.", Integer.class, "Death to all");
+        value = llama.evalInt("If the string is something positive, return 1. Else, return 0.", "Death to all");
         assertThat(value, equalTo(0));
     }
 }
diff --git a/src/test/java/be/seeseemelk/llamascript/StringTests.java b/src/test/java/be/seeseemelk/llamascript/StringTests.java
index e89ebb4..789ebd4 100644
--- a/src/test/java/be/seeseemelk/llamascript/StringTests.java
+++ b/src/test/java/be/seeseemelk/llamascript/StringTests.java
@@ -8,37 +8,37 @@ import static org.hamcrest.Matchers.equalTo;
 public class StringTests extends AbstractLlamaTest {
     @Test
     void canUppercase() {
-        String value = llama.eval("Return the string in uppercase", String.class, "Cat");
+        String value = llama.evalString("Return the string in uppercase", "Cat");
         assertThat(value, equalTo("CAT"));
     }
 
     @Test
     void canLowercase() {
-        String value = llama.eval("Return the string in lowercase", String.class, "Cat");
+        String value = llama.evalString("Return the string in lowercase", "Cat");
         assertThat(value, equalTo("cat"));
     }
 
     @Test
     void canInvert() {
-        String value = llama.eval("Return the string backwards", String.class, "Cat");
+        String value = llama.evalString("Return the string backwards", "Cat");
         assertThat(value, equalTo("taC"));
     }
 
     @Test
     void canSelectLongest() {
-        String value = llama.eval("Return the longest string", String.class, "Cat", "Dog", "Horse");
+        String value = llama.evalString("Return the longest string", "Cat", "Dog", "Horse");
         assertThat(value, equalTo("Horse"));
     }
 
     @Test
     void canDoWeirdStuff() {
-        String value = llama.eval("Return the string that does not fit", String.class, "Cat", "Dog", "Horse", "Bicycle");
+        String value = llama.evalString("Return the string that does not fit", "Cat", "Dog", "Horse", "Bicycle");
         assertThat(value, equalTo("Bicycle"));
     }
 
     @Test
     void canDoReallyWeirdStuff() {
-        String value = llama.eval("Sort the letters alphabetically", String.class, "horse");
+        String value = llama.evalString("Sort the letters alphabetically", "horse");
         assertThat(value, equalTo("ehors"));
     }
 }