Refactor token handler interfaces and improve streaming

Renamed and refactored token handler interfaces for chat and generate modules to improve clarity and separation. Updated related classes and method signatures to use new handler types. Enhanced error handling and logging in chat and generate request builders. Updated tests and integration code to use new handler classes and configuration properties. Suppressed verbose logs from Docker and Testcontainers in test logging configuration.
This commit is contained in:
amithkoujalgi
2025-09-19 18:05:38 +05:30
parent d118958ac1
commit cb0f71ba63
21 changed files with 216 additions and 231 deletions

View File

@@ -13,9 +13,12 @@ import static org.junit.jupiter.api.Assertions.*;
import io.github.ollama4j.OllamaAPI;
import io.github.ollama4j.exceptions.OllamaBaseException;
import io.github.ollama4j.exceptions.ToolInvocationException;
import io.github.ollama4j.impl.ConsoleOutputChatTokenHandler;
import io.github.ollama4j.impl.ConsoleOutputGenerateTokenHandler;
import io.github.ollama4j.models.chat.*;
import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel;
import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel;
import io.github.ollama4j.models.generate.OllamaGenerateStreamObserver;
import io.github.ollama4j.models.response.Model;
import io.github.ollama4j.models.response.ModelDetail;
import io.github.ollama4j.models.response.OllamaResult;
@@ -56,10 +59,41 @@ class OllamaAPIIntegrationTest {
@BeforeAll
static void setUp() {
int requestTimeoutSeconds = 60;
int numberOfRetriesForModelPull = 5;
try {
boolean useExternalOllamaHost =
Boolean.parseBoolean(System.getenv("USE_EXTERNAL_OLLAMA_HOST"));
String ollamaHost = System.getenv("OLLAMA_HOST");
// Try to get from env vars first
String useExternalOllamaHostEnv = System.getenv("USE_EXTERNAL_OLLAMA_HOST");
String ollamaHostEnv = System.getenv("OLLAMA_HOST");
boolean useExternalOllamaHost;
String ollamaHost;
if (useExternalOllamaHostEnv == null && ollamaHostEnv == null) {
// Fallback to test-config.properties from classpath
Properties props = new Properties();
try {
props.load(
OllamaAPIIntegrationTest.class
.getClassLoader()
.getResourceAsStream("test-config.properties"));
} catch (Exception e) {
throw new RuntimeException(
"Could not load test-config.properties from classpath", e);
}
useExternalOllamaHost =
Boolean.parseBoolean(
props.getProperty("USE_EXTERNAL_OLLAMA_HOST", "false"));
ollamaHost = props.getProperty("OLLAMA_HOST");
requestTimeoutSeconds =
Integer.parseInt(props.getProperty("REQUEST_TIMEOUT_SECONDS"));
numberOfRetriesForModelPull =
Integer.parseInt(props.getProperty("NUMBER_RETRIES_FOR_MODEL_PULL"));
} else {
useExternalOllamaHost = Boolean.parseBoolean(useExternalOllamaHostEnv);
ollamaHost = ollamaHostEnv;
}
if (useExternalOllamaHost) {
LOG.info("Using external Ollama host...");
@@ -90,8 +124,8 @@ class OllamaAPIIntegrationTest {
+ ":"
+ ollama.getMappedPort(internalPort));
}
api.setRequestTimeoutSeconds(120);
api.setNumberOfRetriesForModelPull(5);
api.setRequestTimeoutSeconds(requestTimeoutSeconds);
api.setNumberOfRetriesForModelPull(numberOfRetriesForModelPull);
}
@Test
@@ -187,7 +221,7 @@ class OllamaAPIIntegrationTest {
});
format.put("required", List.of("isNoon"));
OllamaResult result = api.generate(TOOLS_MODEL, prompt, format);
OllamaResult result = api.generateWithFormat(TOOLS_MODEL, prompt, format);
assertNotNull(result);
assertNotNull(result.getResponse());
@@ -210,7 +244,8 @@ class OllamaAPIIntegrationTest {
+ " Lisa?",
raw,
thinking,
new OptionsBuilder().build());
new OptionsBuilder().build(),
new OllamaGenerateStreamObserver(null, null));
assertNotNull(result);
assertNotNull(result.getResponse());
assertFalse(result.getResponse().isEmpty());
@@ -228,8 +263,10 @@ class OllamaAPIIntegrationTest {
"What is the capital of France? And what's France's connection with Mona"
+ " Lisa?",
raw,
false,
new OptionsBuilder().build(),
LOG::info);
new OllamaGenerateStreamObserver(
null, new ConsoleOutputGenerateTokenHandler()));
assertNotNull(result);
assertNotNull(result.getResponse());
@@ -263,7 +300,7 @@ class OllamaAPIIntegrationTest {
assertNotNull(chatResult);
assertNotNull(chatResult.getResponseModel());
assertFalse(chatResult.getResponseModel().getMessage().getContent().isEmpty());
assertFalse(chatResult.getResponseModel().getMessage().getResponse().isEmpty());
}
@Test
@@ -296,9 +333,13 @@ class OllamaAPIIntegrationTest {
assertNotNull(chatResult);
assertNotNull(chatResult.getResponseModel());
assertNotNull(chatResult.getResponseModel().getMessage());
assertFalse(chatResult.getResponseModel().getMessage().getContent().isBlank());
assertFalse(chatResult.getResponseModel().getMessage().getResponse().isBlank());
assertTrue(
chatResult.getResponseModel().getMessage().getContent().contains(expectedResponse));
chatResult
.getResponseModel()
.getMessage()
.getResponse()
.contains(expectedResponse));
assertEquals(3, chatResult.getChatHistory().size());
}
@@ -515,16 +556,7 @@ class OllamaAPIIntegrationTest {
.withOptions(new OptionsBuilder().setTemperature(0.9f).build())
.build();
OllamaChatResult chatResult =
api.chat(
requestModel,
new OllamaChatStreamObserver(
s -> {
LOG.info(s.toUpperCase());
},
s -> {
LOG.info(s.toLowerCase());
}));
OllamaChatResult chatResult = api.chat(requestModel, new ConsoleOutputChatTokenHandler());
assertNotNull(chatResult, "chatResult should not be null");
assertNotNull(chatResult.getResponseModel(), "Response model should not be null");
@@ -670,20 +702,11 @@ class OllamaAPIIntegrationTest {
.build();
requestModel.setThink(false);
OllamaChatResult chatResult =
api.chat(
requestModel,
new OllamaChatStreamObserver(
s -> {
LOG.info(s.toUpperCase());
},
s -> {
LOG.info(s.toLowerCase());
}));
OllamaChatResult chatResult = api.chat(requestModel, new ConsoleOutputChatTokenHandler());
assertNotNull(chatResult);
assertNotNull(chatResult.getResponseModel());
assertNotNull(chatResult.getResponseModel().getMessage());
assertNotNull(chatResult.getResponseModel().getMessage().getContent());
assertNotNull(chatResult.getResponseModel().getMessage().getResponse());
}
@Test
@@ -706,21 +729,12 @@ class OllamaAPIIntegrationTest {
.withKeepAlive("0m")
.build();
OllamaChatResult chatResult =
api.chat(
requestModel,
new OllamaChatStreamObserver(
s -> {
LOG.info(s.toUpperCase());
},
s -> {
LOG.info(s.toLowerCase());
}));
OllamaChatResult chatResult = api.chat(requestModel, new ConsoleOutputChatTokenHandler());
assertNotNull(chatResult);
assertNotNull(chatResult.getResponseModel());
assertNotNull(chatResult.getResponseModel().getMessage());
assertNotNull(chatResult.getResponseModel().getMessage().getContent());
assertNotNull(chatResult.getResponseModel().getMessage().getResponse());
}
@Test
@@ -859,7 +873,8 @@ class OllamaAPIIntegrationTest {
"Who are you?",
raw,
think,
new OptionsBuilder().build());
new OptionsBuilder().build(),
new OllamaGenerateStreamObserver(null, null));
assertNotNull(result);
assertNotNull(result.getResponse());
assertNotNull(result.getThinking());
@@ -876,13 +891,15 @@ class OllamaAPIIntegrationTest {
THINKING_TOOL_MODEL,
"Who are you?",
raw,
true,
new OptionsBuilder().build(),
thinkingToken -> {
LOG.info(thinkingToken.toUpperCase());
},
resToken -> {
LOG.info(resToken.toLowerCase());
});
new OllamaGenerateStreamObserver(
thinkingToken -> {
LOG.info(thinkingToken.toUpperCase());
},
resToken -> {
LOG.info(resToken.toLowerCase());
}));
assertNotNull(result);
assertNotNull(result.getResponse());
assertNotNull(result.getThinking());

View File

@@ -203,7 +203,7 @@ public class WithAuth {
});
format.put("required", List.of("isNoon"));
OllamaResult result = api.generate(model, prompt, format);
OllamaResult result = api.generateWithFormat(model, prompt, format);
assertNotNull(result);
assertNotNull(result.getResponse());

View File

@@ -18,6 +18,7 @@ import io.github.ollama4j.exceptions.RoleNotFoundException;
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
import io.github.ollama4j.models.embeddings.OllamaEmbedRequestModel;
import io.github.ollama4j.models.embeddings.OllamaEmbedResponseModel;
import io.github.ollama4j.models.generate.OllamaGenerateStreamObserver;
import io.github.ollama4j.models.request.CustomModelRequest;
import io.github.ollama4j.models.response.ModelDetail;
import io.github.ollama4j.models.response.OllamaAsyncResultStreamer;
@@ -170,12 +171,13 @@ class TestMockedAPIs {
String model = "llama2";
String prompt = "some prompt text";
OptionsBuilder optionsBuilder = new OptionsBuilder();
OllamaGenerateStreamObserver observer = new OllamaGenerateStreamObserver(null, null);
try {
when(ollamaAPI.generate(model, prompt, false, false, optionsBuilder.build()))
when(ollamaAPI.generate(model, prompt, false, false, optionsBuilder.build(), observer))
.thenReturn(new OllamaResult("", "", 0, 200));
ollamaAPI.generate(model, prompt, false, false, optionsBuilder.build());
ollamaAPI.generate(model, prompt, false, false, optionsBuilder.build(), observer);
verify(ollamaAPI, times(1))
.generate(model, prompt, false, false, optionsBuilder.build());
.generate(model, prompt, false, false, optionsBuilder.build(), observer);
} catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e);
}

View File

@@ -59,6 +59,6 @@ class TestOllamaChatRequestBuilder {
assertNotNull(req.getMessages());
assert (!req.getMessages().isEmpty());
OllamaChatMessage msg = req.getMessages().get(0);
assertNotNull(msg.getContent());
assertNotNull(msg.getResponse());
}
}

View File

@@ -67,12 +67,9 @@ class TestOptionsAndUtils {
@Test
void testOptionsBuilderRejectsUnsupportedCustomType() {
OptionsBuilder builder = new OptionsBuilder();
assertThrows(
IllegalArgumentException.class,
() -> {
OptionsBuilder builder = new OptionsBuilder();
builder.setCustomOption("bad", new Object());
});
IllegalArgumentException.class, () -> builder.setCustomOption("bad", new Object()));
}
@Test

View File

@@ -10,6 +10,14 @@
<appender-ref ref="STDOUT"/>
</root>
<!-- Suppress logs from com.github.dockerjava package -->
<logger name="com.github.dockerjava" level="INFO"/>
<!-- Suppress logs from org.testcontainers package -->
<logger name="org.testcontainers" level="INFO"/>
<!-- Keep other loggers at WARN level -->
<logger name="org.apache" level="WARN"/>
<logger name="httpclient" level="WARN"/>
</configuration>

View File

@@ -1,4 +1,4 @@
ollama.url=http://localhost:11434
ollama.model=llama3.2:1b
ollama.model.image=llava:latest
ollama.request-timeout-seconds=120
USE_EXTERNAL_OLLAMA_HOST=true
OLLAMA_HOST=http://192.168.29.229:11434/
REQUEST_TIMEOUT_SECONDS=120
NUMBER_RETRIES_FOR_MODEL_PULL=3