Some more stuff
This commit is contained in:
parent
1da1909642
commit
fa041e2027
@ -3,12 +3,26 @@ package be.seeseemelk.llamascript;
|
||||
import io.github.ollama4j.OllamaAPI;
|
||||
import lombok.Setter;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
@Setter
|
||||
public class LLamaScriptFactory {
|
||||
private String host = "http://localhost:11434";
|
||||
private String model = "llama3.1:8b";
|
||||
private String host;
|
||||
private String model;
|
||||
|
||||
public LLamaScriptFactory() {
|
||||
host = Objects.requireNonNullElse(System.getenv("OLLAMA_HOST"), "http://localhost:11434");
|
||||
model = Objects.requireNonNullElse(System.getenv("OLLAMA_MODEL"), "llama3.1:8b");
|
||||
}
|
||||
|
||||
public LlamaScript build() {
|
||||
return new LlamaScript(new OllamaAPI(host), model);
|
||||
return new LlamaScript(buildWorker());
|
||||
}
|
||||
|
||||
private LlamaScriptWorker buildWorker() {
|
||||
return LlamaScriptWorker.builder()
|
||||
.api(new OllamaAPI(host))
|
||||
.model(model)
|
||||
.build();
|
||||
}
|
||||
}
|
||||
|
@ -1,25 +1,14 @@
|
||||
package be.seeseemelk.llamascript;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import io.github.ollama4j.OllamaAPI;
|
||||
import io.github.ollama4j.exceptions.OllamaBaseException;
|
||||
import io.github.ollama4j.models.chat.OllamaChatMessage;
|
||||
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
|
||||
import io.github.ollama4j.models.chat.OllamaChatRequest;
|
||||
import io.github.ollama4j.utils.OptionsBuilder;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
|
||||
@Slf4j
|
||||
@RequiredArgsConstructor
|
||||
public class LlamaScript {
|
||||
private final OllamaAPI api;
|
||||
private final String model;
|
||||
private final ObjectMapper mapper = new ObjectMapper();
|
||||
private final LlamaScriptWorker worker;
|
||||
|
||||
public <T> T eval(String prompt, Class<T> returnType, Object... arguments) {
|
||||
var argumentMap = new HashMap<Integer, Object>();
|
||||
@ -33,13 +22,16 @@ public class LlamaScript {
|
||||
} else if (returnType.equals(String.class)) {
|
||||
//noinspection unchecked
|
||||
return (T) evalString(prompt, argumentMap);
|
||||
} else if (returnType.equals(Boolean.class)) {
|
||||
//noinspection unchecked
|
||||
return (T) evalBool(prompt, argumentMap);
|
||||
} else {
|
||||
throw new LLamaScriptException("Unsupported return type");
|
||||
}
|
||||
}
|
||||
|
||||
private String evalString(String prompt, HashMap<Integer, Object> argumentMap) {
|
||||
var result = eval(prompt, primitiveOutputType("a string"), argumentMap);
|
||||
public String evalString(String prompt, HashMap<Integer, Object> argumentMap) {
|
||||
var result = worker.eval(prompt, primitiveOutputType("a string"), argumentMap);
|
||||
if (result instanceof String string)
|
||||
return string;
|
||||
else if (result != null)
|
||||
@ -48,8 +40,12 @@ public class LlamaScript {
|
||||
throw new LLamaScriptException("No response");
|
||||
}
|
||||
|
||||
private Integer evalInt(String prompt, HashMap<Integer, Object> argumentMap) {
|
||||
var result = eval(prompt, primitiveOutputType("an integer number"), argumentMap);
|
||||
public String evalString(String prompt, Object... arguments) {
|
||||
return eval(prompt, String.class, arguments);
|
||||
}
|
||||
|
||||
public Integer evalInt(String prompt, HashMap<Integer, Object> argumentMap) {
|
||||
var result = worker.eval(prompt, primitiveOutputType("an integer number"), argumentMap);
|
||||
if (result instanceof Integer integer)
|
||||
return integer;
|
||||
else if (result instanceof String string) {
|
||||
@ -63,42 +59,22 @@ public class LlamaScript {
|
||||
throw new LLamaScriptException("Invalid integer type");
|
||||
}
|
||||
|
||||
private Object eval(String prompt, String outputType, HashMap<Integer, Object> argumentMap) {
|
||||
try {
|
||||
var optionsBuilder = new OptionsBuilder();
|
||||
optionsBuilder.setSeed(20); // Chosen with 1d20. Yes, a nat 20! Let's go baby!
|
||||
optionsBuilder.setTemperature(0f);
|
||||
var options = optionsBuilder.build();
|
||||
var fullPrompt = new StringBuilder("""
|
||||
You will receive a prompt, and will have to consider the question in the prompt with regards to the inputs you will receive.
|
||||
Your output must be a JSON object containing the single value 'result'.
|
||||
%s
|
||||
|
||||
The prompt: %s
|
||||
|
||||
Arguments:
|
||||
""".formatted(outputType, prompt));
|
||||
for (var entry : argumentMap.entrySet()) {
|
||||
fullPrompt.append(" - ").append(entry.getKey()).append(": ").append(entry.getValue()).append("\n");
|
||||
}
|
||||
System.out.printf("Prompt is: %s\n", fullPrompt);
|
||||
|
||||
var messages = new ArrayList<OllamaChatMessage>();
|
||||
messages.add(new OllamaChatMessage(OllamaChatMessageRole.USER, fullPrompt.toString()));
|
||||
var chatRequest = new OllamaChatRequest(model, messages);
|
||||
chatRequest.setReturnFormatJson(true);
|
||||
chatRequest.setOptions(options.getOptionsMap());
|
||||
|
||||
var result = api.chat(chatRequest);
|
||||
System.out.printf("Result is: %s\n", result.getResponse());
|
||||
return mapper.readValue(result.getResponse(), Result.class).result;
|
||||
} catch (OllamaBaseException | InterruptedException | IOException e) {
|
||||
throw new LLamaScriptException(e);
|
||||
}
|
||||
public int evalInt(String prompt, Object... arguments) {
|
||||
return eval(prompt, Integer.class, arguments);
|
||||
}
|
||||
|
||||
private static class Result {
|
||||
public Object result;
|
||||
public Boolean evalBool(String prompt, HashMap<Integer, Object> argumentMap) {
|
||||
var result = worker.eval(prompt, primitiveOutputType("a boolean value"), argumentMap);
|
||||
if (result instanceof Boolean bool)
|
||||
return bool;
|
||||
else if (result instanceof String string)
|
||||
return Boolean.parseBoolean(string);
|
||||
else
|
||||
throw new LLamaScriptException("Invalid boolean type");
|
||||
}
|
||||
|
||||
public boolean evalBool(String prompt, Object... arguments) {
|
||||
return eval(prompt, Boolean.class, arguments);
|
||||
}
|
||||
|
||||
private static String primitiveOutputType(String description) {
|
||||
|
@ -0,0 +1,59 @@
|
||||
package be.seeseemelk.llamascript;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import io.github.ollama4j.OllamaAPI;
|
||||
import io.github.ollama4j.exceptions.OllamaBaseException;
|
||||
import io.github.ollama4j.models.chat.OllamaChatMessage;
|
||||
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
|
||||
import io.github.ollama4j.models.chat.OllamaChatRequest;
|
||||
import io.github.ollama4j.utils.OptionsBuilder;
|
||||
import lombok.Builder;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
|
||||
@Builder
|
||||
public class LlamaScriptWorker {
|
||||
private final OllamaAPI api;
|
||||
private final String model;
|
||||
private final ObjectMapper mapper = new ObjectMapper();
|
||||
|
||||
public Object eval(String prompt, String outputType, HashMap<Integer, Object> argumentMap) {
|
||||
try {
|
||||
var optionsBuilder = new OptionsBuilder();
|
||||
optionsBuilder.setSeed(20); // Chosen with 1d20. Yes, a nat 20! Let's go baby!
|
||||
optionsBuilder.setTemperature(0f);
|
||||
var options = optionsBuilder.build();
|
||||
var fullPrompt = new StringBuilder("""
|
||||
You will receive a prompt, and will have to consider the question in the prompt with regards to the inputs you will receive.
|
||||
Your output must be a JSON object containing the single value 'result'.
|
||||
%s
|
||||
|
||||
The prompt: %s
|
||||
|
||||
Arguments:
|
||||
""".formatted(outputType, prompt));
|
||||
for (var entry : argumentMap.entrySet()) {
|
||||
fullPrompt.append(" - ").append(entry.getKey()).append(": ").append(entry.getValue()).append("\n");
|
||||
}
|
||||
System.out.printf("Prompt is: %s\n", fullPrompt);
|
||||
|
||||
var messages = new ArrayList<OllamaChatMessage>();
|
||||
messages.add(new OllamaChatMessage(OllamaChatMessageRole.USER, fullPrompt.toString()));
|
||||
var chatRequest = new OllamaChatRequest(model, messages);
|
||||
chatRequest.setReturnFormatJson(true);
|
||||
chatRequest.setOptions(options.getOptionsMap());
|
||||
|
||||
var result = api.chat(chatRequest);
|
||||
System.out.printf("Result is: %s\n", result.getResponse());
|
||||
return mapper.readValue(result.getResponse(), Result.class).result;
|
||||
} catch (OllamaBaseException | InterruptedException | IOException e) {
|
||||
throw new LLamaScriptException(e);
|
||||
}
|
||||
}
|
||||
|
||||
private static class Result {
|
||||
public Object result;
|
||||
}
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
package be.seeseemelk.llamascript.arguments;
|
||||
|
||||
public abstract class ArgumentBag {
|
||||
}
|
26
src/test/java/be/seeseemelk/llamascript/BooleanTests.java
Normal file
26
src/test/java/be/seeseemelk/llamascript/BooleanTests.java
Normal file
@ -0,0 +1,26 @@
|
||||
package be.seeseemelk.llamascript;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.ValueSource;
|
||||
|
||||
import static org.hamcrest.MatcherAssert.assertThat;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public class BooleanTests extends AbstractLlamaTest {
|
||||
@ParameterizedTest
|
||||
@ValueSource(ints = {0, 1, 2, 3, 4, 5})
|
||||
void isEven(int number) {
|
||||
boolean isEven = llama.evalBool("Return true if the number is even", number);
|
||||
assertThat(isEven, equalTo(number % 2 == 0));
|
||||
}
|
||||
|
||||
@Test
|
||||
void isABuilding() {
|
||||
boolean isBuilding = llama.evalBool("Return true if the argument is a building or skyscraper", "Dalai Lama");
|
||||
assertThat(isBuilding, equalTo(false));
|
||||
|
||||
isBuilding = llama.evalBool("Return true if the argument is a building or skyscraper", "Burj Khalifa");
|
||||
assertThat(isBuilding, equalTo(true));
|
||||
}
|
||||
}
|
@ -8,34 +8,34 @@ import static org.hamcrest.Matchers.equalTo;
|
||||
public class IntegerTests extends AbstractLlamaTest {
|
||||
@Test
|
||||
void canIncrementNumber() {
|
||||
int value = llama.eval("Increment the number", Integer.class, 4);
|
||||
int value = llama.evalInt("Increment the number", 4);
|
||||
assertThat(value, equalTo(5));
|
||||
}
|
||||
|
||||
@Test
|
||||
void canMultiply() {
|
||||
int value = llama.eval("Multiply the numbers", Integer.class, 2, 3);
|
||||
int value = llama.evalInt("Multiply the numbers", 2, 3);
|
||||
assertThat(value, equalTo(2*3));
|
||||
}
|
||||
|
||||
@Test
|
||||
void canMax() {
|
||||
int value = llama.eval("Select the largest number", Integer.class, -2, 8);
|
||||
int value = llama.evalInt("Select the largest number", -2, 8);
|
||||
assertThat(value, equalTo(8));
|
||||
}
|
||||
|
||||
@Test
|
||||
void canDoWeirdStuff() {
|
||||
int value = llama.eval("Select the number with the most '5's in it.", Integer.class, 12, 5, 1023978, 158525);
|
||||
int value = llama.evalInt("Select the number with the most '5's in it.", 12, 5, 1023978, 158525);
|
||||
assertThat(value, equalTo(158525));
|
||||
}
|
||||
|
||||
@Test
|
||||
void givePositivity() {
|
||||
int value = llama.eval("If the string is something positive, return 1. Else, return 0.", Integer.class, "I like rainbows");
|
||||
int value = llama.evalInt("If the string is something positive, return 1. Else, return 0.", "I like rainbows");
|
||||
assertThat(value, equalTo(1));
|
||||
|
||||
value = llama.eval("If the string is something positive, return 1. Else, return 0.", Integer.class, "Death to all");
|
||||
value = llama.evalInt("If the string is something positive, return 1. Else, return 0.", "Death to all");
|
||||
assertThat(value, equalTo(0));
|
||||
}
|
||||
}
|
||||
|
@ -8,37 +8,37 @@ import static org.hamcrest.Matchers.equalTo;
|
||||
public class StringTests extends AbstractLlamaTest {
|
||||
@Test
|
||||
void canUppercase() {
|
||||
String value = llama.eval("Return the string in uppercase", String.class, "Cat");
|
||||
String value = llama.evalString("Return the string in uppercase", "Cat");
|
||||
assertThat(value, equalTo("CAT"));
|
||||
}
|
||||
|
||||
@Test
|
||||
void canLowercase() {
|
||||
String value = llama.eval("Return the string in lowercase", String.class, "Cat");
|
||||
String value = llama.evalString("Return the string in lowercase", "Cat");
|
||||
assertThat(value, equalTo("cat"));
|
||||
}
|
||||
|
||||
@Test
|
||||
void canInvert() {
|
||||
String value = llama.eval("Return the string backwards", String.class, "Cat");
|
||||
String value = llama.evalString("Return the string backwards", "Cat");
|
||||
assertThat(value, equalTo("taC"));
|
||||
}
|
||||
|
||||
@Test
|
||||
void canSelectLongest() {
|
||||
String value = llama.eval("Return the longest string", String.class, "Cat", "Dog", "Horse");
|
||||
String value = llama.evalString("Return the longest string", "Cat", "Dog", "Horse");
|
||||
assertThat(value, equalTo("Horse"));
|
||||
}
|
||||
|
||||
@Test
|
||||
void canDoWeirdStuff() {
|
||||
String value = llama.eval("Return the string that does not fit", String.class, "Cat", "Dog", "Horse", "Bicycle");
|
||||
String value = llama.evalString("Return the string that does not fit", "Cat", "Dog", "Horse", "Bicycle");
|
||||
assertThat(value, equalTo("Bicycle"));
|
||||
}
|
||||
|
||||
@Test
|
||||
void canDoReallyWeirdStuff() {
|
||||
String value = llama.eval("Sort the letters alphabetically", String.class, "horse");
|
||||
String value = llama.evalString("Sort the letters alphabetically", "horse");
|
||||
assertThat(value, equalTo("ehors"));
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user