diff --git a/docs/docs/apis-ask/ask.md b/docs/docs/apis-ask/ask.md
index e0d5918..9e71947 100644
--- a/docs/docs/apis-ask/ask.md
+++ b/docs/docs/apis-ask/ask.md
@@ -8,6 +8,11 @@
 This API lets you ask questions to the LLMs in a synchronous way.
 These APIs correlate to
 the [completion](https://github.com/jmorganca/ollama/blob/main/docs/api.md#generate-a-completion) APIs.
 
+Use the `OptionsBuilder` to build the `Options` object
+with [extra parameters](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
+Refer to [the Options Builder](/docs/apis-extras/options-builder) for details on how to build
+this object.
+
 ## Try asking a question about the model.
 
 ```java
@@ -19,11 +24,13 @@ public class Main {
 
     OllamaAPI ollamaAPI = new OllamaAPI(host);
 
-    OllamaResult result = ollamaAPI.ask(OllamaModelType.LLAMA2, "Who are you?");
+    OllamaResult result =
+        ollamaAPI.ask(OllamaModelType.LLAMA2, "Who are you?", new OptionsBuilder().build());
 
     System.out.println(result.getResponse());
   }
 }
+
 ```
 
 You will get a response similar to:
@@ -47,11 +54,13 @@ public class Main {
 
     String prompt = "List all cricket world cup teams of 2019.";
 
-    OllamaResult result = ollamaAPI.ask(OllamaModelType.LLAMA2, prompt);
+    OllamaResult result =
+        ollamaAPI.ask(OllamaModelType.LLAMA2, prompt, new OptionsBuilder().build());
 
     System.out.println(result.getResponse());
   }
 }
+
 ```
 
 You'd then get a response from the model:
@@ -84,12 +93,15 @@ public class Main {
     String host = "http://localhost:11434/";
     OllamaAPI ollamaAPI = new OllamaAPI(host);
 
-    String prompt = SamplePrompts.getSampleDatabasePromptWithQuestion(
-        "List all customer names who have bought one or more products");
-    OllamaResult result = ollamaAPI.ask(OllamaModelType.SQLCODER, prompt);
+    String prompt =
+        SamplePrompts.getSampleDatabasePromptWithQuestion(
+            "List all customer names who have bought one or more products");
+    OllamaResult result =
+        ollamaAPI.ask(OllamaModelType.SQLCODER, prompt, new OptionsBuilder().build());
 
     System.out.println(result.getResponse());
   }
 }
+
 ```
 
 _Note: Here I've used
diff --git a/docs/docs/apis-extras/options-builder.md b/docs/docs/apis-extras/options-builder.md
new file mode 100644
index 0000000..d92511d
--- /dev/null
+++ b/docs/docs/apis-extras/options-builder.md
@@ -0,0 +1,65 @@
+---
+sidebar_position: 1
+---
+
+# Options Builder
+
+This lets you build options for the `ask()` API.
+Check out the supported
+options [here](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
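+
+Once built, pass the `Options` object to the `ask()` API. A minimal usage sketch (imports and
+exception handling omitted for brevity):
+
+```java
+OllamaAPI ollamaAPI = new OllamaAPI("http://localhost:11434/");
+Options options = new OptionsBuilder().setTemperature(0.8f).build();
+OllamaResult result = ollamaAPI.ask(OllamaModelType.LLAMA2, "Who are you?", options);
+System.out.println(result.getResponse());
+```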
+
+## Build an empty Options object
+
+```java
+import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
+import io.github.amithkoujalgi.ollama4j.core.utils.Options;
+import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
+
+public class Main {
+
+  public static void main(String[] args) {
+
+    String host = "http://localhost:11434/";
+
+    OllamaAPI ollamaAPI = new OllamaAPI(host);
+
+    Options options = new OptionsBuilder().build();
+  }
+}
+```
+
+## Build an Options object with custom parameter values
+
+```java
+import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
+import io.github.amithkoujalgi.ollama4j.core.utils.Options;
+import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
+
+public class Main {
+
+  public static void main(String[] args) {
+
+    String host = "http://localhost:11434/";
+
+    OllamaAPI ollamaAPI = new OllamaAPI(host);
+
+    Options options =
+        new OptionsBuilder()
+            .setMirostat(1)
+            .setMirostatEta(0.5f)
+            .setNumGpu(2)
+            .setTemperature(1.5f)
+            .build();
+  }
+}
+```
\ No newline at end of file
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java
index d1c6cee..c61c55f 100644
--- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java
@@ -6,6 +6,7 @@ import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFileConte
 import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFilePathRequest;
 import io.github.amithkoujalgi.ollama4j.core.models.request.ModelEmbeddingsRequest;
 import io.github.amithkoujalgi.ollama4j.core.models.request.ModelRequest;
+import io.github.amithkoujalgi.ollama4j.core.utils.Options;
 import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
 import java.io.BufferedReader;
 import java.io.ByteArrayOutputStream;
@@ -332,11 +333,15 @@ public class OllamaAPI {
    *
    * @param model the ollama model to ask the question to
    * @param prompt the prompt/question text
+   * @param options the Options object - <a
+   *     href="https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values">More
+   *     details on the options</a>
    * @return OllamaResult that includes response text and time taken for response
    */
-  public OllamaResult ask(String model, String prompt)
+  public OllamaResult ask(String model, String prompt, Options options)
       throws OllamaBaseException, IOException, InterruptedException {
     OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt);
+    ollamaRequestModel.setOptions(options.getOptionsMap());
     return askSync(ollamaRequestModel);
   }
 
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaRequestModel.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaRequestModel.java
index 043f341..a2507a6 100644
--- a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaRequestModel.java
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaRequestModel.java
@@ -1,10 +1,10 @@
 package io.github.amithkoujalgi.ollama4j.core.models;
 
-
 import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper;
 
 import com.fasterxml.jackson.core.JsonProcessingException;
 import java.util.List;
+import java.util.Map;
 import lombok.Data;
 
 @Data
@@ -13,6 +13,7 @@ public class OllamaRequestModel {
   private String model;
   private String prompt;
   private List<String> images;
+  private Map<String, Object> options;
 
   public OllamaRequestModel(String model, String prompt) {
     this.model = model;
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/Options.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/Options.java
new file mode 100644
index 0000000..2339969
--- /dev/null
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/Options.java
@@ -0,0 +1,11 @@
+package io.github.amithkoujalgi.ollama4j.core.utils;
+
+import java.util.Map;
+import lombok.Data;
+
+/** Class for options for Ollama model. */
+@Data
+public class Options {
+
+  private final Map<String, Object> optionsMap;
+}
diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/OptionsBuilder.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/OptionsBuilder.java
new file mode 100644
index 0000000..d605f81
--- /dev/null
+++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/OptionsBuilder.java
@@ -0,0 +1,218 @@
+package io.github.amithkoujalgi.ollama4j.core.utils;
+
+import java.util.HashMap;
+
+/** Builder class for creating options for Ollama model. */
+public class OptionsBuilder {
+
+  private final Options options;
+
+  /** Constructs a new OptionsBuilder with an empty options map. */
+  public OptionsBuilder() {
+    this.options = new Options(new HashMap<>());
+  }
+
+  /**
+   * Enable Mirostat sampling for controlling perplexity. (default: 0, 0 = disabled, 1 = Mirostat,
+   * 2 = Mirostat 2.0)
+   *
+   * @param value The value for the "mirostat" parameter.
+   * @return The updated OptionsBuilder.
+   */
+  public OptionsBuilder setMirostat(int value) {
+    options.getOptionsMap().put("mirostat", value);
+    return this;
+  }
+
+  /**
+   * Influences how quickly the algorithm responds to feedback from the generated text. A lower
+   * learning rate will result in slower adjustments, while a higher learning rate will make the
+   * algorithm more responsive. (Default: 0.1)
+   *
+   * @param value The value for the "mirostat_eta" parameter.
+   * @return The updated OptionsBuilder.
+   */
+  public OptionsBuilder setMirostatEta(float value) {
+    options.getOptionsMap().put("mirostat_eta", value);
+    return this;
+  }
+
+  /**
+   * Controls the balance between coherence and diversity of the output. A lower value will result
+   * in more focused and coherent text. (Default: 5.0)
+   *
+   * @param value The value for the "mirostat_tau" parameter.
+   * @return The updated OptionsBuilder.
+   */
+  public OptionsBuilder setMirostatTau(float value) {
+    options.getOptionsMap().put("mirostat_tau", value);
+    return this;
+  }
+
+  /**
+   * Sets the size of the context window used to generate the next token. (Default: 2048)
+   *
+   * @param value The value for the "num_ctx" parameter.
+   * @return The updated OptionsBuilder.
+   */
+  public OptionsBuilder setNumCtx(int value) {
+    options.getOptionsMap().put("num_ctx", value);
+    return this;
+  }
+
+  /**
+   * The number of GQA groups in the transformer layer. Required for some models, for example, it
+   * is 8 for llama2:70b.
+   *
+   * @param value The value for the "num_gqa" parameter.
+   * @return The updated OptionsBuilder.
+   */
+  public OptionsBuilder setNumGqa(int value) {
+    options.getOptionsMap().put("num_gqa", value);
+    return this;
+  }
+
+  /**
+   * The number of layers to send to the GPU(s). On macOS it defaults to 1 to enable metal
+   * support, 0 to disable.
+   *
+   * @param value The value for the "num_gpu" parameter.
+   * @return The updated OptionsBuilder.
+   */
+  public OptionsBuilder setNumGpu(int value) {
+    options.getOptionsMap().put("num_gpu", value);
+    return this;
+  }
+
+  /**
+   * Sets the number of threads to use during computation. By default, Ollama will detect this for
+   * optimal performance. It is recommended to set this value to the number of physical CPU cores
+   * your system has (as opposed to the logical number of cores).
+   *
+   * @param value The value for the "num_thread" parameter.
+   * @return The updated OptionsBuilder.
+   */
+  public OptionsBuilder setNumThread(int value) {
+    options.getOptionsMap().put("num_thread", value);
+    return this;
+  }
+
+  /**
+   * Sets how far back the model should look to prevent repetition. (Default: 64, 0 = disabled,
+   * -1 = num_ctx)
+   *
+   * @param value The value for the "repeat_last_n" parameter.
+   * @return The updated OptionsBuilder.
+   */
+  public OptionsBuilder setRepeatLastN(int value) {
+    options.getOptionsMap().put("repeat_last_n", value);
+    return this;
+  }
+
+  /**
+   * Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize
+   * repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default:
+   * 1.1)
+   *
+   * @param value The value for the "repeat_penalty" parameter.
+   * @return The updated OptionsBuilder.
+   */
+  public OptionsBuilder setRepeatPenalty(float value) {
+    options.getOptionsMap().put("repeat_penalty", value);
+    return this;
+  }
+
+  /**
+   * The temperature of the model. Increasing the temperature will make the model answer more
+   * creatively. (Default: 0.8)
+   *
+   * @param value The value for the "temperature" parameter.
+   * @return The updated OptionsBuilder.
+   */
+  public OptionsBuilder setTemperature(float value) {
+    options.getOptionsMap().put("temperature", value);
+    return this;
+  }
+
+  /**
+   * Sets the random number seed to use for generation. Setting this to a specific number will
+   * make the model generate the same text for the same prompt. (Default: 0)
+   *
+   * @param value The value for the "seed" parameter.
+   * @return The updated OptionsBuilder.
+   */
+  public OptionsBuilder setSeed(int value) {
+    options.getOptionsMap().put("seed", value);
+    return this;
+  }
+
+  /**
+   * Sets the stop sequences to use. When this pattern is encountered the LLM will stop generating
+   * text and return. Multiple stop patterns may be set by specifying multiple separate `stop`
+   * parameters in a modelfile.
+   *
+   * @param value The value for the "stop" parameter.
+   * @return The updated OptionsBuilder.
+   */
+  public OptionsBuilder setStop(String value) {
+    options.getOptionsMap().put("stop", value);
+    return this;
+  }
+
+  /**
+   * Tail free sampling is used to reduce the impact of less probable tokens from the output. A
+   * higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this
+   * setting. (default: 1)
+   *
+   * @param value The value for the "tfs_z" parameter.
+   * @return The updated OptionsBuilder.
+   */
+  public OptionsBuilder setTfsZ(float value) {
+    options.getOptionsMap().put("tfs_z", value);
+    return this;
+  }
+
+  /**
+   * Maximum number of tokens to predict when generating text. (Default: 128, -1 = infinite
+   * generation, -2 = fill context)
+   *
+   * @param value The value for the "num_predict" parameter.
+   * @return The updated OptionsBuilder.
+   */
+  public OptionsBuilder setNumPredict(int value) {
+    options.getOptionsMap().put("num_predict", value);
+    return this;
+  }
+
+  /**
+   * Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more
+   * diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40)
+   *
+   * @param value The value for the "top_k" parameter.
+   * @return The updated OptionsBuilder.
+   */
+  public OptionsBuilder setTopK(int value) {
+    options.getOptionsMap().put("top_k", value);
+    return this;
+  }
+
+  /**
+   * Works together with top-k.
+   * A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5)
+   * will generate more focused and conservative text. (Default: 0.9)
+   *
+   * @param value The value for the "top_p" parameter.
+   * @return The updated OptionsBuilder.
+   */
+  public OptionsBuilder setTopP(float value) {
+    options.getOptionsMap().put("top_p", value);
+    return this;
+  }
+
+  /**
+   * Builds the options.
+   *
+   * @return The populated Options object.
+   */
+  public Options build() {
+    return options;
+  }
+}
diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/TestMockedAPIs.java b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/TestMockedAPIs.java
index b7c9977..49b9eee 100644
--- a/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/TestMockedAPIs.java
+++ b/src/test/java/io/github/amithkoujalgi/ollama4j/unittests/TestMockedAPIs.java
@@ -8,6 +8,7 @@ import io.github.amithkoujalgi.ollama4j.core.models.ModelDetail;
 import io.github.amithkoujalgi.ollama4j.core.models.OllamaAsyncResultCallback;
 import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
 import io.github.amithkoujalgi.ollama4j.core.types.OllamaModelType;
+import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
 import java.io.IOException;
 import java.net.URISyntaxException;
 import java.util.ArrayList;
@@ -100,10 +101,12 @@ class TestMockedAPIs {
     OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
     String model = OllamaModelType.LLAMA2;
     String prompt = "some prompt text";
+    OptionsBuilder optionsBuilder = new OptionsBuilder();
     try {
-      when(ollamaAPI.ask(model, prompt)).thenReturn(new OllamaResult("", 0, 200));
-      ollamaAPI.ask(model, prompt);
-      verify(ollamaAPI, times(1)).ask(model, prompt);
+      when(ollamaAPI.ask(model, prompt, optionsBuilder.build()))
+          .thenReturn(new OllamaResult("", 0, 200));
+      ollamaAPI.ask(model, prompt, optionsBuilder.build());
+      verify(ollamaAPI, times(1)).ask(model, prompt, optionsBuilder.build());
     } catch (IOException | OllamaBaseException | InterruptedException e) {
      throw new RuntimeException(e);
    }
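For context on what these changes put on the wire: the builder's map is attached to the generate
request as a nested `options` object. A minimal sketch of the resulting request shape, assuming
Jackson serialization of the Lombok `@Data` model (the class name `RequestShapeSketch` and the
direct use of `ObjectMapper` here are illustrative, not part of the change; key order in the
printed JSON may vary):

```java
import com.fasterxml.jackson.databind.ObjectMapper;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaRequestModel;
import io.github.amithkoujalgi.ollama4j.core.utils.Options;
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;

public class RequestShapeSketch {

  public static void main(String[] args) throws Exception {
    // Build the request the same way the new ask(model, prompt, options) does internally.
    OllamaRequestModel request = new OllamaRequestModel("llama2", "some prompt text");
    Options options = new OptionsBuilder().setTemperature(0.8f).setTopK(40).build();
    request.setOptions(options.getOptionsMap());

    // Serializing the model shows the extra parameters nested under "options", e.g.:
    // {"model":"llama2","prompt":"some prompt text","images":null,
    //  "options":{"temperature":0.8,"top_k":40}}
    System.out.println(new ObjectMapper().writeValueAsString(request));
  }
}
```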