diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/Main.java b/src/main/java/io/github/amithkoujalgi/ollama4j/Main.java new file mode 100644 index 0000000..d62dc76 --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/Main.java @@ -0,0 +1,20 @@ +package io.github.amithkoujalgi.ollama4j; + +import io.github.amithkoujalgi.ollama4j.core.OllamaAPI; +import io.github.amithkoujalgi.ollama4j.core.types.OllamaModelType; +import io.github.amithkoujalgi.ollama4j.core.utils.SamplePrompts; + +public class Main { + public static void main(String[] args) throws Exception { + String host = "http://localhost:11434/"; + OllamaAPI ollamaAPI = new OllamaAPI(host); + + String prompt1 = SamplePrompts.getSampleDatabasePromptWithQuestion("List all customer names who have bought one or more products"); + String response1 = ollamaAPI.ask(OllamaModelType.LLAMA2, prompt1); + System.out.println(response1); + + String prompt2 = "Give me a list of world cup cricket teams."; + String response2 = ollamaAPI.ask(OllamaModelType.LLAMA2, prompt2); + System.out.println(response2); + } +} diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/OllamaAPI.java b/src/main/java/io/github/amithkoujalgi/ollama4j/OllamaAPI.java deleted file mode 100644 index 74b42f2..0000000 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/OllamaAPI.java +++ /dev/null @@ -1,102 +0,0 @@ -package io.github.amithkoujalgi.ollama4j; - -import com.google.gson.Gson; -import org.apache.hc.client5.http.classic.methods.HttpPost; -import org.apache.hc.client5.http.impl.classic.CloseableHttpClient; -import org.apache.hc.client5.http.impl.classic.CloseableHttpResponse; -import org.apache.hc.client5.http.impl.classic.HttpClients; -import org.apache.hc.core5.http.HttpEntity; -import org.apache.hc.core5.http.ParseException; -import org.apache.hc.core5.http.io.entity.EntityUtils; -import org.apache.hc.core5.http.io.entity.StringEntity; - -import java.io.BufferedReader; -import java.io.DataOutputStream; -import java.io.IOException; -import java.io.InputStreamReader; -import java.net.HttpURLConnection; -import java.net.URL; - -@SuppressWarnings("deprecation") -public class OllamaAPI { - private final String host; - - public OllamaAPI(String host) { - if (host.endsWith("/")) { - this.host = host.substring(0, host.length() - 1); - } else { - this.host = host; - } - } - - public void pullModel(OllamaModel model) throws IOException, ParseException, OllamaBaseException { - String url = this.host + "/api/pull"; - String jsonData = String.format("{\"name\": \"%s\"}", model.getModel()); - final HttpPost httpPost = new HttpPost(url); - final StringEntity entity = new StringEntity(jsonData); - httpPost.setEntity(entity); - httpPost.setHeader("Accept", "application/json"); - httpPost.setHeader("Content-type", "application/json"); - try (CloseableHttpClient client = HttpClients.createDefault(); - CloseableHttpResponse response = client.execute(httpPost)) { - final int statusCode = response.getCode(); - HttpEntity responseEntity = response.getEntity(); - String responseString = ""; - if (responseEntity != null) { - responseString = EntityUtils.toString(responseEntity, "UTF-8"); - } - if (statusCode == 200) { - System.out.println(responseString); - } else { - throw new OllamaBaseException(statusCode + " - " + responseString); - } - } - } - - public String runSync(OllamaModel ollamaModel, String promptText) throws OllamaBaseException, IOException { - Gson gson = new Gson(); - OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(ollamaModel.getModel(), promptText); - URL obj = new URL(this.host + "/api/generate"); - HttpURLConnection con = (HttpURLConnection) obj.openConnection(); - con.setRequestMethod("POST"); - con.setDoOutput(true); - con.setRequestProperty("Content-Type", "application/json"); - try (DataOutputStream wr = new DataOutputStream(con.getOutputStream())) { - wr.writeBytes(ollamaRequestModel.toString()); - } - int responseCode = con.getResponseCode(); - if (responseCode == HttpURLConnection.HTTP_OK) { - try (BufferedReader in = new BufferedReader(new InputStreamReader(con.getInputStream()))) { - String inputLine; - StringBuilder response = new StringBuilder(); - while ((inputLine = in.readLine()) != null) { - OllamaResponseModel ollamaResponseModel = gson.fromJson(inputLine, OllamaResponseModel.class); - if (!ollamaResponseModel.getDone()) { - response.append(ollamaResponseModel.getResponse()); - } - System.out.println("Streamed response line: " + ollamaResponseModel.getResponse()); - } - in.close(); - return response.toString(); - } - } else { - throw new OllamaBaseException(con.getResponseCode() + " - " + con.getResponseMessage()); - } - } - - public OllamaAsyncResultCallback runAsync(OllamaModel ollamaModel, String promptText) throws OllamaBaseException, IOException { - OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(ollamaModel.getModel(), promptText); - URL obj = new URL(this.host + "/api/generate"); - HttpURLConnection con = (HttpURLConnection) obj.openConnection(); - con.setRequestMethod("POST"); - con.setDoOutput(true); - con.setRequestProperty("Content-Type", "application/json"); - try (DataOutputStream wr = new DataOutputStream(con.getOutputStream())) { - wr.writeBytes(ollamaRequestModel.toString()); - } - OllamaAsyncResultCallback ollamaAsyncResultCallback = new OllamaAsyncResultCallback(con); - ollamaAsyncResultCallback.start(); - return ollamaAsyncResultCallback; - } -} - diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/OllamaModel.java b/src/main/java/io/github/amithkoujalgi/ollama4j/OllamaModel.java deleted file mode 100644 index ada052e..0000000 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/OllamaModel.java +++ /dev/null @@ -1,15 +0,0 @@ -package io.github.amithkoujalgi.ollama4j; - -public enum OllamaModel { - LLAMA2("llama2"), MISTRAL("mistral"), MEDLLAMA2("medllama2"), CODELLAMA("codellama"), VICUNA("vicuna"), ORCAMINI("orca-mini"), SQLCODER("sqlcoder"), WIZARDMATH("wizard-math"); - - private final String model; - - OllamaModel(String model) { - this.model = model; - } - - public String getModel() { - return model; - } -} diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java new file mode 100644 index 0000000..6b28b97 --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/OllamaAPI.java @@ -0,0 +1,199 @@ +package io.github.amithkoujalgi.ollama4j.core; + +import com.google.gson.Gson; +import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; +import io.github.amithkoujalgi.ollama4j.core.models.*; +import org.apache.hc.client5.http.classic.methods.HttpDelete; +import org.apache.hc.client5.http.classic.methods.HttpGet; +import org.apache.hc.client5.http.classic.methods.HttpPost; +import org.apache.hc.client5.http.impl.classic.CloseableHttpClient; +import org.apache.hc.client5.http.impl.classic.CloseableHttpResponse; +import org.apache.hc.client5.http.impl.classic.HttpClients; +import org.apache.hc.core5.http.HttpEntity; +import org.apache.hc.core5.http.ParseException; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.apache.hc.core5.http.io.entity.StringEntity; + +import java.io.BufferedReader; +import java.io.DataOutputStream; +import java.io.IOException; +import java.io.InputStreamReader; +import java.net.HttpURLConnection; +import java.net.URL; +import java.util.List; +import java.util.stream.Collectors; + +@SuppressWarnings({"DuplicatedCode", "ExtractMethodRecommender"}) +public class OllamaAPI { + + + private final String host; + + public OllamaAPI(String host) { + if (host.endsWith("/")) { + this.host = host.substring(0, host.length() - 1); + } else { + this.host = host; + } + } + + public List listModels() throws IOException, OllamaBaseException, ParseException { + String url = this.host + "/api/tags"; + final HttpGet httpGet = new HttpGet(url); + httpGet.setHeader("Accept", "application/json"); + httpGet.setHeader("Content-type", "application/json"); + try (CloseableHttpClient client = HttpClients.createDefault(); CloseableHttpResponse response = client.execute(httpGet)) { + final int statusCode = response.getCode(); + HttpEntity responseEntity = response.getEntity(); + String responseString = ""; + if (responseEntity != null) { + responseString = EntityUtils.toString(responseEntity, "UTF-8"); + } + if (statusCode == 200) { + Models m = new Gson().fromJson(responseString, Models.class); + return m.getModels(); + } else { + throw new OllamaBaseException(statusCode + " - " + responseString); + } + } + } + + public ModelDetail getModelDetails(Model model) throws IOException, OllamaBaseException, ParseException { + String url = this.host + "/api/show"; + String jsonData = String.format("{\"name\": \"%s\"}", model.getName()); + final HttpPost httpPost = new HttpPost(url); + final StringEntity entity = new StringEntity(jsonData); + httpPost.setEntity(entity); + httpPost.setHeader("Accept", "application/json"); + httpPost.setHeader("Content-type", "application/json"); + try (CloseableHttpClient client = HttpClients.createDefault(); CloseableHttpResponse response = client.execute(httpPost)) { + final int statusCode = response.getCode(); + HttpEntity responseEntity = response.getEntity(); + String responseString = ""; + if (responseEntity != null) { + responseString = EntityUtils.toString(responseEntity, "UTF-8"); + } + if (statusCode == 200) { + return new Gson().fromJson(responseString, ModelDetail.class); + } else { + throw new OllamaBaseException(statusCode + " - " + responseString); + } + } + } + + public void pullModel(String model) throws IOException, ParseException, OllamaBaseException { + List models = listModels().stream().filter(m -> m.getModelName().split(":")[0].equals(model)).collect(Collectors.toList()); + if (!models.isEmpty()) { + return; + } + String url = this.host + "/api/pull"; + String jsonData = String.format("{\"name\": \"%s\"}", model); + final HttpPost httpPost = new HttpPost(url); + final StringEntity entity = new StringEntity(jsonData); + httpPost.setEntity(entity); + httpPost.setHeader("Accept", "application/json"); + httpPost.setHeader("Content-type", "application/json"); + try (CloseableHttpClient client = HttpClients.createDefault(); CloseableHttpResponse response = client.execute(httpPost)) { + final int statusCode = response.getCode(); + HttpEntity responseEntity = response.getEntity(); + String responseString = ""; + if (responseEntity != null) { + responseString = EntityUtils.toString(responseEntity, "UTF-8"); + } + if (statusCode != 200) { + throw new OllamaBaseException(statusCode + " - " + responseString); + } + } + } + + public void createModel(String name, String modelFilePath) throws IOException, ParseException, OllamaBaseException { + String url = this.host + "/api/create"; + String jsonData = String.format("{\"name\": \"%s\", \"path\": \"%s\"}", name, modelFilePath); + final HttpPost httpPost = new HttpPost(url); + final StringEntity entity = new StringEntity(jsonData); + httpPost.setEntity(entity); + httpPost.setHeader("Accept", "application/json"); + httpPost.setHeader("Content-type", "application/json"); + try (CloseableHttpClient client = HttpClients.createDefault(); CloseableHttpResponse response = client.execute(httpPost)) { + final int statusCode = response.getCode(); + HttpEntity responseEntity = response.getEntity(); + String responseString = ""; + if (responseEntity != null) { + responseString = EntityUtils.toString(responseEntity, "UTF-8"); + } + if (statusCode != 200) { + throw new OllamaBaseException(statusCode + " - " + responseString); + } + } + } + + public void deleteModel(String name, boolean ignoreIfNotPresent) throws IOException, ParseException, OllamaBaseException { + String url = this.host + "/api/delete"; + String jsonData = String.format("{\"name\": \"%s\"}", name); + final HttpDelete httpDelete = new HttpDelete(url); + final StringEntity entity = new StringEntity(jsonData); + httpDelete.setEntity(entity); + httpDelete.setHeader("Accept", "application/json"); + httpDelete.setHeader("Content-type", "application/json"); + try (CloseableHttpClient client = HttpClients.createDefault(); CloseableHttpResponse response = client.execute(httpDelete)) { + final int statusCode = response.getCode(); + HttpEntity responseEntity = response.getEntity(); + String responseString = ""; + if (responseEntity != null) { + responseString = EntityUtils.toString(responseEntity, "UTF-8"); + } + if (statusCode == 404 && responseString.contains("model") && responseString.contains("not found")) { + return; + } + if (statusCode != 200) { + throw new OllamaBaseException(statusCode + " - " + responseString); + } + } + } + + + public String ask(String ollamaModelType, String promptText) throws OllamaBaseException, IOException { + Gson gson = new Gson(); + OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(ollamaModelType, promptText); + URL obj = new URL(this.host + "/api/generate"); + HttpURLConnection con = (HttpURLConnection) obj.openConnection(); + con.setRequestMethod("POST"); + con.setDoOutput(true); + con.setRequestProperty("Content-Type", "application/json"); + try (DataOutputStream wr = new DataOutputStream(con.getOutputStream())) { + wr.writeBytes(ollamaRequestModel.toString()); + } + int responseCode = con.getResponseCode(); + if (responseCode == HttpURLConnection.HTTP_OK) { + try (BufferedReader in = new BufferedReader(new InputStreamReader(con.getInputStream()))) { + String inputLine; + StringBuilder response = new StringBuilder(); + while ((inputLine = in.readLine()) != null) { + OllamaResponseModel ollamaResponseModel = gson.fromJson(inputLine, OllamaResponseModel.class); + if (!ollamaResponseModel.getDone()) { + response.append(ollamaResponseModel.getResponse()); + } + } + in.close(); + return response.toString(); + } + } else { + throw new OllamaBaseException(con.getResponseCode() + " - " + con.getResponseMessage()); + } + } + + public OllamaAsyncResultCallback askAsync(String ollamaModelType, String promptText) throws IOException { + OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(ollamaModelType, promptText); + URL obj = new URL(this.host + "/api/generate"); + HttpURLConnection con = (HttpURLConnection) obj.openConnection(); + con.setRequestMethod("POST"); + con.setDoOutput(true); + con.setRequestProperty("Content-Type", "application/json"); + try (DataOutputStream wr = new DataOutputStream(con.getOutputStream())) { + wr.writeBytes(ollamaRequestModel.toString()); + } + OllamaAsyncResultCallback ollamaAsyncResultCallback = new OllamaAsyncResultCallback(con); + ollamaAsyncResultCallback.start(); + return ollamaAsyncResultCallback; + } +} diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/OllamaBaseException.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/exceptions/OllamaBaseException.java similarity index 68% rename from src/main/java/io/github/amithkoujalgi/ollama4j/OllamaBaseException.java rename to src/main/java/io/github/amithkoujalgi/ollama4j/core/exceptions/OllamaBaseException.java index 91c41c4..7c8612f 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/OllamaBaseException.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/exceptions/OllamaBaseException.java @@ -1,4 +1,4 @@ -package io.github.amithkoujalgi.ollama4j; +package io.github.amithkoujalgi.ollama4j.core.exceptions; public class OllamaBaseException extends Exception { diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/Model.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/Model.java new file mode 100644 index 0000000..6f4493e --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/Model.java @@ -0,0 +1,47 @@ +package io.github.amithkoujalgi.ollama4j.core.models; + +public class Model { + private String name, modified_at, digest; + private Long size; + + public String getName() { + return name; + } + + public String getModelName() { + return name.split(":")[0]; + } + + public String getModelVersion() { + return name.split(":")[1]; + } + + public void setName(String name) { + this.name = name; + } + + public String getModified_at() { + return modified_at; + } + + public void setModified_at(String modified_at) { + this.modified_at = modified_at; + } + + public String getDigest() { + return digest; + } + + public void setDigest(String digest) { + this.digest = digest; + } + + public Long getSize() { + return size; + } + + public void setSize(Long size) { + this.size = size; + } + +} diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/ModelDetail.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/ModelDetail.java new file mode 100644 index 0000000..201ebed --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/ModelDetail.java @@ -0,0 +1,37 @@ +package io.github.amithkoujalgi.ollama4j.core.models; + +public class ModelDetail { + private String license, modelfile, parameters, template; + + public String getLicense() { + return license; + } + + public void setLicense(String license) { + this.license = license; + } + + public String getModelfile() { + return modelfile; + } + + public void setModelfile(String modelfile) { + this.modelfile = modelfile; + } + + public String getParameters() { + return parameters; + } + + public void setParameters(String parameters) { + this.parameters = parameters; + } + + public String getTemplate() { + return template; + } + + public void setTemplate(String template) { + this.template = template; + } +} diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/Models.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/Models.java new file mode 100644 index 0000000..8520d90 --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/Models.java @@ -0,0 +1,15 @@ +package io.github.amithkoujalgi.ollama4j.core.models; + +import java.util.List; + +public class Models { + private List models; + + public List getModels() { + return models; + } + + public void setModels(List models) { + this.models = models; + } +} diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/OllamaAsyncResultCallback.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java similarity index 70% rename from src/main/java/io/github/amithkoujalgi/ollama4j/OllamaAsyncResultCallback.java rename to src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java index f4d8e5d..0354e3c 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/OllamaAsyncResultCallback.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaAsyncResultCallback.java @@ -1,6 +1,7 @@ -package io.github.amithkoujalgi.ollama4j; +package io.github.amithkoujalgi.ollama4j.core.models; import com.google.gson.Gson; +import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; import java.io.BufferedReader; import java.io.IOException; @@ -25,22 +26,26 @@ public class OllamaAsyncResultCallback extends Thread { try { responseCode = this.connection.getResponseCode(); if (responseCode == HttpURLConnection.HTTP_OK) { - try (BufferedReader in = new BufferedReader(new InputStreamReader(this.connection.getInputStream()))) { + try (BufferedReader in = + new BufferedReader(new InputStreamReader(this.connection.getInputStream()))) { String inputLine; StringBuilder response = new StringBuilder(); while ((inputLine = in.readLine()) != null) { - OllamaResponseModel ollamaResponseModel = gson.fromJson(inputLine, OllamaResponseModel.class); + OllamaResponseModel ollamaResponseModel = + gson.fromJson(inputLine, OllamaResponseModel.class); if (!ollamaResponseModel.getDone()) { response.append(ollamaResponseModel.getResponse()); } -// System.out.println("Streamed response line: " + responseModel.getResponse()); + // System.out.println("Streamed response line: " + + // responseModel.getResponse()); } in.close(); this.isDone = true; this.result = response.toString(); } } else { - throw new OllamaBaseException(connection.getResponseCode() + " - " + connection.getResponseMessage()); + throw new OllamaBaseException( + connection.getResponseCode() + " - " + connection.getResponseMessage()); } } catch (IOException | OllamaBaseException e) { this.isDone = true; diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/OllamaRequestModel.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaRequestModel.java similarity index 91% rename from src/main/java/io/github/amithkoujalgi/ollama4j/OllamaRequestModel.java rename to src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaRequestModel.java index 6924e7a..0f1454f 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/OllamaRequestModel.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaRequestModel.java @@ -1,4 +1,4 @@ -package io.github.amithkoujalgi.ollama4j; +package io.github.amithkoujalgi.ollama4j.core.models; import com.google.gson.Gson; diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/OllamaResponseModel.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaResponseModel.java similarity index 92% rename from src/main/java/io/github/amithkoujalgi/ollama4j/OllamaResponseModel.java rename to src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaResponseModel.java index b77f17c..082d890 100644 --- a/src/main/java/io/github/amithkoujalgi/ollama4j/OllamaResponseModel.java +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/models/OllamaResponseModel.java @@ -1,9 +1,8 @@ -package io.github.amithkoujalgi.ollama4j; +package io.github.amithkoujalgi.ollama4j.core.models; import java.util.List; -public -class OllamaResponseModel { +public class OllamaResponseModel { private String model; private String created_at; private String response; diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/types/OllamaModelType.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/types/OllamaModelType.java new file mode 100644 index 0000000..d4dbbb2 --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/types/OllamaModelType.java @@ -0,0 +1,12 @@ +package io.github.amithkoujalgi.ollama4j.core.types; + +public class OllamaModelType { + public static String LLAMA2 = "llama2"; + public static String MISTRAL = "mistral"; + public static String MEDLLAMA2 = "medllama2"; + public static String CODELLAMA = "codellama"; + public static String VICUNA = "vicuna"; + public static String ORCAMINI = "orca-mini"; + public static String SQLCODER = "sqlcoder"; + public static String WIZARDMATH = "wizard-math"; +} diff --git a/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/SamplePrompts.java b/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/SamplePrompts.java new file mode 100644 index 0000000..1e5dfdc --- /dev/null +++ b/src/main/java/io/github/amithkoujalgi/ollama4j/core/utils/SamplePrompts.java @@ -0,0 +1,25 @@ +package io.github.amithkoujalgi.ollama4j.core.utils; + +import io.github.amithkoujalgi.ollama4j.core.OllamaAPI; + +import java.io.InputStream; +import java.util.Scanner; + +public class SamplePrompts { + public static String getSampleDatabasePromptWithQuestion(String question) throws Exception { + ClassLoader classLoader = OllamaAPI.class.getClassLoader(); + InputStream inputStream = classLoader.getResourceAsStream("sample-db-prompt-template.txt"); + if (inputStream != null) { + Scanner scanner = new Scanner(inputStream); + StringBuilder stringBuffer = new StringBuilder(); + while (scanner.hasNextLine()) { + stringBuffer.append(scanner.nextLine()).append("\n"); + } + scanner.close(); + return stringBuffer.toString().replaceAll("", question); + } else { + throw new Exception("Sample database question file not found."); + } + } + +} diff --git a/src/main/resources/sample-db-prompt-template.txt b/src/main/resources/sample-db-prompt-template.txt new file mode 100644 index 0000000..177f648 --- /dev/null +++ b/src/main/resources/sample-db-prompt-template.txt @@ -0,0 +1,61 @@ +""" +Following is the database schema. + +DROP TABLE IF EXISTS product_categories; +CREATE TABLE IF NOT EXISTS product_categories +( + category_id INTEGER PRIMARY KEY, -- Unique ID for each category + name VARCHAR(50), -- Name of the category + parent INTEGER NULL, -- Parent category - for hierarchical categories + FOREIGN KEY (parent) REFERENCES product_categories (category_id) +); +DROP TABLE IF EXISTS products; +CREATE TABLE IF NOT EXISTS products +( + product_id INTEGER PRIMARY KEY, -- Unique ID for each product + name VARCHAR(50), -- Name of the product + price DECIMAL(10, 2), -- Price of each unit of the product + quantity INTEGER, -- Current quantity in stock + category_id INTEGER, -- Unique ID for each product + FOREIGN KEY (category_id) REFERENCES product_categories (category_id) +); +DROP TABLE IF EXISTS customers; +CREATE TABLE IF NOT EXISTS customers +( + customer_id INTEGER PRIMARY KEY, -- Unique ID for each customer + name VARCHAR(50), -- Name of the customer + address VARCHAR(100) -- Mailing address of the customer +); +DROP TABLE IF EXISTS salespeople; +CREATE TABLE IF NOT EXISTS salespeople +( + salesperson_id INTEGER PRIMARY KEY, -- Unique ID for each salesperson + name VARCHAR(50), -- Name of the salesperson + region VARCHAR(50) -- Geographic sales region +); +DROP TABLE IF EXISTS sales; +CREATE TABLE IF NOT EXISTS sales +( + sale_id INTEGER PRIMARY KEY, -- Unique ID for each sale + product_id INTEGER, -- ID of product sold + customer_id INTEGER, -- ID of customer who made the purchase + salesperson_id INTEGER, -- ID of salesperson who made the sale + sale_date DATE, -- Date the sale occurred + quantity INTEGER, -- Quantity of product sold + FOREIGN KEY (product_id) REFERENCES products (product_id), + FOREIGN KEY (customer_id) REFERENCES customers (customer_id), + FOREIGN KEY (salesperson_id) REFERENCES salespeople (salesperson_id) +); +DROP TABLE IF EXISTS product_suppliers; +CREATE TABLE IF NOT EXISTS product_suppliers +( + supplier_id INTEGER PRIMARY KEY, -- Unique ID for each supplier + product_id INTEGER, -- Product ID supplied + supply_price DECIMAL(10, 2), -- Unit price charged by supplier + FOREIGN KEY (product_id) REFERENCES products (product_id) +); + + +Generate only a valid (syntactically correct) executable Postgres SQL query (without any explanation of the query) for the following question: +``: +""" \ No newline at end of file diff --git a/src/test/java/io/github/amithkoujalgi/ollama4j/TestMockedAPIs.java b/src/test/java/io/github/amithkoujalgi/ollama4j/TestMockedAPIs.java index 22880f2..9d34bd4 100644 --- a/src/test/java/io/github/amithkoujalgi/ollama4j/TestMockedAPIs.java +++ b/src/test/java/io/github/amithkoujalgi/ollama4j/TestMockedAPIs.java @@ -1,5 +1,8 @@ package io.github.amithkoujalgi.ollama4j; +import io.github.amithkoujalgi.ollama4j.core.OllamaAPI; +import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; +import io.github.amithkoujalgi.ollama4j.core.types.OllamaModelType; import org.apache.hc.core5.http.ParseException; import org.junit.jupiter.api.Test; import org.mockito.Mockito; @@ -12,7 +15,7 @@ public class TestMockedAPIs { @Test public void testMockPullModel() { OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); - OllamaModel model = OllamaModel.LLAMA2; + String model = OllamaModelType.LLAMA2; try { doNothing().when(ollamaAPI).pullModel(model); ollamaAPI.pullModel(model);