diff --git a/docs/src/components/TypewriterTextarea/index.js b/docs/src/components/TypewriterTextarea/index.js index 78eb0a3..740d79e 100644 --- a/docs/src/components/TypewriterTextarea/index.js +++ b/docs/src/components/TypewriterTextarea/index.js @@ -1,6 +1,6 @@ import React, { useEffect, useState, useRef } from 'react'; -const TypewriterTextarea = ({ textContent, typingSpeed = 50, pauseBetweenSentences = 1000, height = '200px', width = '100%' }) => { +const TypewriterTextarea = ({ textContent, typingSpeed = 50, pauseBetweenSentences = 1000, height = '200px', width = '100%', align = 'left' }) => { const [text, setText] = useState(''); const [sentenceIndex, setSentenceIndex] = useState(0); const [charIndex, setCharIndex] = useState(0); @@ -56,11 +56,13 @@ const TypewriterTextarea = ({ textContent, typingSpeed = 50, pauseBetweenSentenc fontSize: '1rem', backgroundColor: '#f4f4f4', border: '1px solid #ccc', + textAlign: align, resize: 'none', whiteSpace: 'pre-wrap', + color: 'black', }} /> ); }; -export default TypewriterTextarea; +export default TypewriterTextarea; \ No newline at end of file diff --git a/docs/src/pages/index.js b/docs/src/pages/index.js index 2f9a649..12d5c55 100644 --- a/docs/src/pages/index.js +++ b/docs/src/pages/index.js @@ -32,6 +32,7 @@ function HomepageHeader() { pauseBetweenSentences={1200} height='130px' width='100%' + align='center' />
diff --git a/pom.xml b/pom.xml index 6121533..087ca96 100644 --- a/pom.xml +++ b/pom.xml @@ -223,6 +223,12 @@ 1.20.2 test + + org.testcontainers + nginx + 1.20.0 + test + diff --git a/src/main/java/io/github/ollama4j/OllamaAPI.java b/src/main/java/io/github/ollama4j/OllamaAPI.java index c837575..5689faa 100644 --- a/src/main/java/io/github/ollama4j/OllamaAPI.java +++ b/src/main/java/io/github/ollama4j/OllamaAPI.java @@ -51,7 +51,7 @@ import java.util.stream.Collectors; /** * The base Ollama API class. */ -@SuppressWarnings({ "DuplicatedCode", "resource" }) +@SuppressWarnings({"DuplicatedCode", "resource"}) public class OllamaAPI { private static final Logger logger = LoggerFactory.getLogger(OllamaAPI.class); @@ -215,7 +215,7 @@ public class OllamaAPI { * tags, tag count, and the time when model was updated. * * @return A list of {@link LibraryModel} objects representing the models - * available in the Ollama library. + * available in the Ollama library. * @throws OllamaBaseException If the HTTP request fails or the response is not * successful (non-200 status code). * @throws IOException If an I/O error occurs during the HTTP request @@ -281,7 +281,7 @@ public class OllamaAPI { * of the library model * for which the tags need to be fetched. * @return a list of {@link LibraryModelTag} objects containing the extracted - * tags and their associated metadata. + * tags and their associated metadata. * @throws OllamaBaseException if the HTTP response status code indicates an * error (i.e., not 200 OK), * or if there is any other issue during the @@ -348,7 +348,7 @@ public class OllamaAPI { * @param modelName The name of the model to search for in the library. * @param tag The tag name to search for within the specified model. * @return The {@link LibraryModelTag} associated with the specified model and - * tag. + * tag. * @throws OllamaBaseException If there is a problem with the Ollama library * operations. * @throws IOException If an I/O error occurs during the operation. @@ -778,7 +778,7 @@ public class OllamaAPI { * @param format A map containing the format specification for the structured * output. * @return An instance of {@link OllamaResult} containing the structured - * response. + * response. * @throws OllamaBaseException if the response indicates an error status. * @throws IOException if an I/O error occurs during the HTTP request. * @throws InterruptedException if the operation is interrupted. @@ -796,8 +796,9 @@ public class OllamaAPI { String jsonData = Utils.getObjectMapper().writeValueAsString(requestBody); HttpClient httpClient = HttpClient.newHttpClient(); - HttpRequest request = HttpRequest.newBuilder(uri) - .header("Content-Type", "application/json") + HttpRequest request = getRequestBuilderDefault(uri) + .header("Accept", "application/json") + .header("Content-type", "application/json") .POST(HttpRequest.BodyPublishers.ofString(jsonData)) .build(); @@ -852,8 +853,8 @@ public class OllamaAPI { * @param options Additional options or configurations to use when generating * the response. * @return {@link OllamaToolsResult} An OllamaToolsResult object containing the - * response from the AI model and the results of invoking the tools on - * that output. + * response from the AI model and the results of invoking the tools on + * that output. * @throws OllamaBaseException if the response indicates an error status * @throws IOException if an I/O error occurs during the HTTP request * @throws InterruptedException if the operation is interrupted @@ -1065,14 +1066,14 @@ public class OllamaAPI { * @param model the ollama model to ask the question to * @param messages chat history / message stack to send to the model * @return {@link OllamaChatResult} containing the api response and the message - * history including the newly acquired assistant response. - * @throws OllamaBaseException any response code than 200 has been returned - * @throws IOException in case the responseStream can not be read - * @throws InterruptedException in case the server is not reachable or network - * issues happen - * @throws OllamaBaseException if the response indicates an error status - * @throws IOException if an I/O error occurs during the HTTP request - * @throws InterruptedException if the operation is interrupted + * history including the newly acquired assistant response. + * @throws OllamaBaseException any response code than 200 has been returned + * @throws IOException in case the responseStream can not be read + * @throws InterruptedException in case the server is not reachable or network + * issues happen + * @throws OllamaBaseException if the response indicates an error status + * @throws IOException if an I/O error occurs during the HTTP request + * @throws InterruptedException if the operation is interrupted * @throws ToolInvocationException if the tool invocation fails */ public OllamaChatResult chat(String model, List messages) @@ -1089,13 +1090,13 @@ public class OllamaAPI { * * @param request request object to be sent to the server * @return {@link OllamaChatResult} - * @throws OllamaBaseException any response code than 200 has been returned - * @throws IOException in case the responseStream can not be read - * @throws InterruptedException in case the server is not reachable or network - * issues happen - * @throws OllamaBaseException if the response indicates an error status - * @throws IOException if an I/O error occurs during the HTTP request - * @throws InterruptedException if the operation is interrupted + * @throws OllamaBaseException any response code than 200 has been returned + * @throws IOException in case the responseStream can not be read + * @throws InterruptedException in case the server is not reachable or network + * issues happen + * @throws OllamaBaseException if the response indicates an error status + * @throws IOException if an I/O error occurs during the HTTP request + * @throws InterruptedException if the operation is interrupted * @throws ToolInvocationException if the tool invocation fails */ public OllamaChatResult chat(OllamaChatRequest request) @@ -1114,13 +1115,13 @@ public class OllamaAPI { * (caution: all previous tokens from stream will be * concatenated) * @return {@link OllamaChatResult} - * @throws OllamaBaseException any response code than 200 has been returned - * @throws IOException in case the responseStream can not be read - * @throws InterruptedException in case the server is not reachable or network - * issues happen - * @throws OllamaBaseException if the response indicates an error status - * @throws IOException if an I/O error occurs during the HTTP request - * @throws InterruptedException if the operation is interrupted + * @throws OllamaBaseException any response code than 200 has been returned + * @throws IOException in case the responseStream can not be read + * @throws InterruptedException in case the server is not reachable or network + * issues happen + * @throws OllamaBaseException if the response indicates an error status + * @throws IOException if an I/O error occurs during the HTTP request + * @throws InterruptedException if the operation is interrupted * @throws ToolInvocationException if the tool invocation fails */ public OllamaChatResult chat(OllamaChatRequest request, OllamaStreamHandler streamHandler) diff --git a/src/main/java/io/github/ollama4j/tools/sampletools/WeatherTool.java b/src/main/java/io/github/ollama4j/tools/sampletools/WeatherTool.java new file mode 100644 index 0000000..eb0ba72 --- /dev/null +++ b/src/main/java/io/github/ollama4j/tools/sampletools/WeatherTool.java @@ -0,0 +1,88 @@ +package io.github.ollama4j.tools.sampletools; + +import java.io.IOException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.util.Map; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.github.ollama4j.tools.Tools; + +public class WeatherTool { + private String openWeatherMapAPIKey = null; + + public WeatherTool(String openWeatherMapAPIKey) { + this.openWeatherMapAPIKey = openWeatherMapAPIKey; + } + + public String getCurrentWeather(Map arguments) { + String city = (String) arguments.get("cityName"); + System.out.println("Finding weather for city: " + city); + + String url = String.format("https://api.openweathermap.org/data/2.5/weather?q=%s&appid=%s&units=metric", + city, + this.openWeatherMapAPIKey); + + HttpClient client = HttpClient.newHttpClient(); + HttpRequest request = HttpRequest.newBuilder() + .uri(URI.create(url)) + .build(); + try { + HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); + if (response.statusCode() == 200) { + ObjectMapper mapper = new ObjectMapper(); + JsonNode root = mapper.readTree(response.body()); + JsonNode main = root.path("main"); + double temperature = main.path("temp").asDouble(); + String description = root.path("weather").get(0).path("description").asText(); + return String.format("Weather in %s: %.1f°C, %s", city, temperature, description); + } else { + return "Could not retrieve weather data for " + city + ". Status code: " + + response.statusCode(); + } + } catch (IOException | InterruptedException e) { + e.printStackTrace(); + return "Error retrieving weather data: " + e.getMessage(); + } + } + + public Tools.ToolSpecification getSpecification() { + return Tools.ToolSpecification.builder() + .functionName("weather-reporter") + .functionDescription( + "You are a tool who simply finds the city name from the user's message input/query about weather.") + .toolFunction(this::getCurrentWeather) + .toolPrompt( + Tools.PromptFuncDefinition.builder() + .type("prompt") + .function( + Tools.PromptFuncDefinition.PromptFuncSpec + .builder() + .name("get-city-name") + .description("Get the city name") + .parameters( + Tools.PromptFuncDefinition.Parameters + .builder() + .type("object") + .properties( + Map.of( + "cityName", + Tools.PromptFuncDefinition.Property + .builder() + .type("string") + .description( + "The name of the city. e.g. Bengaluru") + .required(true) + .build())) + .required(java.util.List + .of("cityName")) + .build()) + .build()) + .build()) + .build(); + } +} diff --git a/src/test/java/io/github/ollama4j/integrationtests/WithAuth.java b/src/test/java/io/github/ollama4j/integrationtests/WithAuth.java new file mode 100644 index 0000000..6531b27 --- /dev/null +++ b/src/test/java/io/github/ollama4j/integrationtests/WithAuth.java @@ -0,0 +1,194 @@ +package io.github.ollama4j.integrationtests; + +import io.github.ollama4j.OllamaAPI; +import io.github.ollama4j.exceptions.OllamaBaseException; +import io.github.ollama4j.models.response.OllamaResult; +import io.github.ollama4j.samples.AnnotatedTool; +import io.github.ollama4j.tools.annotations.OllamaToolService; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.MethodOrderer.OrderAnnotation; +import org.junit.jupiter.api.Order; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestMethodOrder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.NginxContainer; +import org.testcontainers.containers.wait.strategy.Wait; +import org.testcontainers.ollama.OllamaContainer; +import org.testcontainers.utility.DockerImageName; +import org.testcontainers.utility.MountableFile; + +import java.io.File; +import java.io.FileWriter; +import java.io.IOException; +import java.net.URISyntaxException; +import java.time.Duration; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; + +@OllamaToolService(providers = {AnnotatedTool.class}) +@TestMethodOrder(OrderAnnotation.class) +@SuppressWarnings({"HttpUrlsUsage", "SpellCheckingInspection", "resource", "ResultOfMethodCallIgnored"}) +public class WithAuth { + + private static final Logger LOG = LoggerFactory.getLogger(WithAuth.class); + private static final int NGINX_PORT = 80; + private static final int OLLAMA_INTERNAL_PORT = 11434; + private static final String OLLAMA_VERSION = "0.6.1"; + private static final String NGINX_VERSION = "nginx:1.23.4-alpine"; + private static final String BEARER_AUTH_TOKEN = "secret-token"; + private static final String CHAT_MODEL_LLAMA3 = "llama3"; + + + private static OllamaContainer ollama; + private static GenericContainer nginx; + private static OllamaAPI api; + + @BeforeAll + public static void setUp() { + ollama = createOllamaContainer(); + ollama.start(); + + nginx = createNginxContainer(ollama.getMappedPort(OLLAMA_INTERNAL_PORT)); + nginx.start(); + + LOG.info("Using Testcontainer Ollama host..."); + + api = new OllamaAPI("http://" + nginx.getHost() + ":" + nginx.getMappedPort(NGINX_PORT)); + api.setRequestTimeoutSeconds(120); + api.setVerbose(true); + api.setNumberOfRetriesForModelPull(3); + + String ollamaUrl = "http://" + ollama.getHost() + ":" + ollama.getMappedPort(OLLAMA_INTERNAL_PORT); + String nginxUrl = "http://" + nginx.getHost() + ":" + nginx.getMappedPort(NGINX_PORT); + LOG.info( + "The Ollama service is now accessible via the Nginx proxy with bearer-auth authentication mode.\n" + + "→ Ollama URL: {}\n" + + "→ Proxy URL: {}}", + ollamaUrl, nginxUrl + ); + LOG.info("OllamaAPI initialized with bearer auth token: {}", BEARER_AUTH_TOKEN); + } + + private static OllamaContainer createOllamaContainer() { + return new OllamaContainer("ollama/ollama:" + OLLAMA_VERSION).withExposedPorts(OLLAMA_INTERNAL_PORT); + } + + private static String generateNginxConfig(int ollamaPort) { + return String.format("events {}\n" + + "\n" + + "http {\n" + + " server {\n" + + " listen 80;\n" + + "\n" + + " location / {\n" + + " set $auth_header $http_authorization;\n" + + "\n" + + " if ($auth_header != \"Bearer secret-token\") {\n" + + " return 401;\n" + + " }\n" + + "\n" + + " proxy_pass http://host.docker.internal:%s/;\n" + + " proxy_set_header Host $host;\n" + + " proxy_set_header X-Real-IP $remote_addr;\n" + + " proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;\n" + + " proxy_set_header X-Forwarded-Proto $scheme;\n" + + " }\n" + + " }\n" + + "}\n", ollamaPort); + } + + public static GenericContainer createNginxContainer(int ollamaPort) { + File nginxConf; + try { + File tempDir = new File(System.getProperty("java.io.tmpdir"), "nginx-auth"); + if (!tempDir.exists()) tempDir.mkdirs(); + + nginxConf = new File(tempDir, "nginx.conf"); + try (FileWriter writer = new FileWriter(nginxConf)) { + writer.write(generateNginxConfig(ollamaPort)); + } + + return new NginxContainer<>(DockerImageName.parse(NGINX_VERSION)) + .withExposedPorts(NGINX_PORT) + .withCopyFileToContainer( + MountableFile.forHostPath(nginxConf.getAbsolutePath()), + "/etc/nginx/nginx.conf" + ) + .withExtraHost("host.docker.internal", "host-gateway") + .waitingFor( + Wait.forHttp("/") + .forStatusCode(401) + .withStartupTimeout(Duration.ofSeconds(30)) + ); + } catch (IOException e) { + throw new RuntimeException("Failed to create nginx.conf", e); + } + } + + @Test + @Order(1) + void testOllamaBehindProxy() throws InterruptedException { + api.setBearerAuth(BEARER_AUTH_TOKEN); + assertTrue(api.ping(), "Expected OllamaAPI to successfully ping through NGINX with valid auth token."); + } + + @Test + @Order(1) + void testWithWrongToken() throws InterruptedException { + api.setBearerAuth("wrong-token"); + assertFalse(api.ping(), "Expected OllamaAPI ping to fail through NGINX with an invalid auth token."); + } + + @Test + @Order(2) + void testAskModelWithStructuredOutput() + throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { + api.setBearerAuth(BEARER_AUTH_TOKEN); + + api.pullModel(CHAT_MODEL_LLAMA3); + + int timeHour = 6; + boolean isNightTime = false; + + String prompt = "The Sun is shining, and its " + timeHour + ". Its daytime."; + + Map format = new HashMap<>(); + format.put("type", "object"); + format.put("properties", new HashMap() { + { + put("timeHour", new HashMap() { + { + put("type", "integer"); + } + }); + put("isNightTime", new HashMap() { + { + put("type", "boolean"); + } + }); + } + }); + format.put("required", Arrays.asList("timeHour", "isNightTime")); + + OllamaResult result = api.generate(CHAT_MODEL_LLAMA3, prompt, format); + + assertNotNull(result); + assertNotNull(result.getResponse()); + assertFalse(result.getResponse().isEmpty()); + + assertEquals(timeHour, + result.getStructuredResponse().get("timeHour")); + assertEquals(isNightTime, + result.getStructuredResponse().get("isNightTime")); + + TimeOfDay timeOfDay = result.as(TimeOfDay.class); + + assertEquals(timeHour, timeOfDay.getTimeHour()); + assertEquals(isNightTime, timeOfDay.isNightTime()); + } +}