Merge pull request #72 from daguava/minp-custom-options

Add minp option and ability to set custom options
This commit is contained in:
Amith Koujalgi 2024-10-27 20:05:24 +05:30 committed by GitHub
commit d949a3cb69
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 50 additions and 0 deletions

View File

@ -1,5 +1,6 @@
package io.github.ollama4j.utils; package io.github.ollama4j.utils;
import java.io.IOException;
import java.util.HashMap; import java.util.HashMap;
/** Builder class for creating options for Ollama model. */ /** Builder class for creating options for Ollama model. */
@ -207,6 +208,34 @@ public class OptionsBuilder {
return this; return this;
} }
/**
* Alternative to the top_p, and aims to ensure a balance of qualityand variety. The parameter p
* represents the minimum probability for a token to be considered, relative to the probability
* of the most likely token. For example, with p=0.05 and the most likely token having a
* probability of 0.9, logits with a value less than 0.045 are filtered out. (Default: 0.0)
*/
public OptionsBuilder setMinP(float value) {
options.getOptionsMap().put("min_p", value);
return this;
}
/**
* Allows passing an option not formally supported by the library
* @param name The option name for the parameter.
* @param value The value for the "{name}" parameter.
* @return The updated OptionsBuilder.
* @throws IllegalArgumentException if parameter has an unsupported type
*/
public OptionsBuilder setCustomOption(String name, Object value) throws IllegalArgumentException {
if (!(value instanceof Integer || value instanceof Float || value instanceof String)) {
throw new IllegalArgumentException("Invalid type for parameter. Allowed types are: Integer, Float, or String.");
}
options.getOptionsMap().put(name, value);
return this;
}
/** /**
* Builds the options map. * Builds the options map.
* *
@ -215,4 +244,6 @@ public class OptionsBuilder {
public Options build() { public Options build() {
return options; return options;
} }
} }

View File

@ -1,6 +1,7 @@
package io.github.ollama4j.unittests.jackson; package io.github.ollama4j.unittests.jackson;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrowsExactly;
import java.io.File; import java.io.File;
import java.util.List; import java.util.List;
@ -59,6 +60,10 @@ public class TestChatRequestSerialization extends AbstractSerializationTest<Olla
.withOptions(b.setSeed(1).build()) .withOptions(b.setSeed(1).build())
.withOptions(b.setTopK(1).build()) .withOptions(b.setTopK(1).build())
.withOptions(b.setTopP(1).build()) .withOptions(b.setTopP(1).build())
.withOptions(b.setMinP(1).build())
.withOptions(b.setCustomOption("cust_float", 1.0f).build())
.withOptions(b.setCustomOption("cust_int", 1).build())
.withOptions(b.setCustomOption("cust_str", "custom").build())
.build(); .build();
String jsonRequest = serialize(req); String jsonRequest = serialize(req);
@ -72,6 +77,20 @@ public class TestChatRequestSerialization extends AbstractSerializationTest<Olla
assertEquals(1, deserializeRequest.getOptions().get("seed")); assertEquals(1, deserializeRequest.getOptions().get("seed"));
assertEquals(1, deserializeRequest.getOptions().get("top_k")); assertEquals(1, deserializeRequest.getOptions().get("top_k"));
assertEquals(1.0, deserializeRequest.getOptions().get("top_p")); assertEquals(1.0, deserializeRequest.getOptions().get("top_p"));
assertEquals(1.0, deserializeRequest.getOptions().get("min_p"));
assertEquals(1.0, deserializeRequest.getOptions().get("cust_float"));
assertEquals(1, deserializeRequest.getOptions().get("cust_int"));
assertEquals("custom", deserializeRequest.getOptions().get("cust_str"));
}
@Test
public void testRequestWithInvalidCustomOption() {
OptionsBuilder b = new OptionsBuilder();
assertThrowsExactly(IllegalArgumentException.class, () -> {
OllamaChatRequest req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt")
.withOptions(b.setCustomOption("cust_obj", new Object()).build())
.build();
});
} }
@Test @Test