forked from Mirror/ollama4j
init
This commit is contained in:
parent
c1615a2005
commit
153f516d9f
20
src/main/java/io/github/amithkoujalgi/ollama4j/Main.java
Normal file
20
src/main/java/io/github/amithkoujalgi/ollama4j/Main.java
Normal file
@ -0,0 +1,20 @@
|
||||
package io.github.amithkoujalgi.ollama4j;
|
||||
|
||||
import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
|
||||
import io.github.amithkoujalgi.ollama4j.core.types.OllamaModelType;
|
||||
import io.github.amithkoujalgi.ollama4j.core.utils.SamplePrompts;
|
||||
|
||||
public class Main {
|
||||
public static void main(String[] args) throws Exception {
|
||||
String host = "http://localhost:11434/";
|
||||
OllamaAPI ollamaAPI = new OllamaAPI(host);
|
||||
|
||||
String prompt1 = SamplePrompts.getSampleDatabasePromptWithQuestion("List all customer names who have bought one or more products");
|
||||
String response1 = ollamaAPI.ask(OllamaModelType.LLAMA2, prompt1);
|
||||
System.out.println(response1);
|
||||
|
||||
String prompt2 = "Give me a list of world cup cricket teams.";
|
||||
String response2 = ollamaAPI.ask(OllamaModelType.LLAMA2, prompt2);
|
||||
System.out.println(response2);
|
||||
}
|
||||
}
|
@ -1,102 +0,0 @@
|
||||
package io.github.amithkoujalgi.ollama4j;
|
||||
|
||||
import com.google.gson.Gson;
|
||||
import org.apache.hc.client5.http.classic.methods.HttpPost;
|
||||
import org.apache.hc.client5.http.impl.classic.CloseableHttpClient;
|
||||
import org.apache.hc.client5.http.impl.classic.CloseableHttpResponse;
|
||||
import org.apache.hc.client5.http.impl.classic.HttpClients;
|
||||
import org.apache.hc.core5.http.HttpEntity;
|
||||
import org.apache.hc.core5.http.ParseException;
|
||||
import org.apache.hc.core5.http.io.entity.EntityUtils;
|
||||
import org.apache.hc.core5.http.io.entity.StringEntity;
|
||||
|
||||
import java.io.BufferedReader;
|
||||
import java.io.DataOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStreamReader;
|
||||
import java.net.HttpURLConnection;
|
||||
import java.net.URL;
|
||||
|
||||
@SuppressWarnings("deprecation")
|
||||
public class OllamaAPI {
|
||||
private final String host;
|
||||
|
||||
public OllamaAPI(String host) {
|
||||
if (host.endsWith("/")) {
|
||||
this.host = host.substring(0, host.length() - 1);
|
||||
} else {
|
||||
this.host = host;
|
||||
}
|
||||
}
|
||||
|
||||
public void pullModel(OllamaModel model) throws IOException, ParseException, OllamaBaseException {
|
||||
String url = this.host + "/api/pull";
|
||||
String jsonData = String.format("{\"name\": \"%s\"}", model.getModel());
|
||||
final HttpPost httpPost = new HttpPost(url);
|
||||
final StringEntity entity = new StringEntity(jsonData);
|
||||
httpPost.setEntity(entity);
|
||||
httpPost.setHeader("Accept", "application/json");
|
||||
httpPost.setHeader("Content-type", "application/json");
|
||||
try (CloseableHttpClient client = HttpClients.createDefault();
|
||||
CloseableHttpResponse response = client.execute(httpPost)) {
|
||||
final int statusCode = response.getCode();
|
||||
HttpEntity responseEntity = response.getEntity();
|
||||
String responseString = "";
|
||||
if (responseEntity != null) {
|
||||
responseString = EntityUtils.toString(responseEntity, "UTF-8");
|
||||
}
|
||||
if (statusCode == 200) {
|
||||
System.out.println(responseString);
|
||||
} else {
|
||||
throw new OllamaBaseException(statusCode + " - " + responseString);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public String runSync(OllamaModel ollamaModel, String promptText) throws OllamaBaseException, IOException {
|
||||
Gson gson = new Gson();
|
||||
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(ollamaModel.getModel(), promptText);
|
||||
URL obj = new URL(this.host + "/api/generate");
|
||||
HttpURLConnection con = (HttpURLConnection) obj.openConnection();
|
||||
con.setRequestMethod("POST");
|
||||
con.setDoOutput(true);
|
||||
con.setRequestProperty("Content-Type", "application/json");
|
||||
try (DataOutputStream wr = new DataOutputStream(con.getOutputStream())) {
|
||||
wr.writeBytes(ollamaRequestModel.toString());
|
||||
}
|
||||
int responseCode = con.getResponseCode();
|
||||
if (responseCode == HttpURLConnection.HTTP_OK) {
|
||||
try (BufferedReader in = new BufferedReader(new InputStreamReader(con.getInputStream()))) {
|
||||
String inputLine;
|
||||
StringBuilder response = new StringBuilder();
|
||||
while ((inputLine = in.readLine()) != null) {
|
||||
OllamaResponseModel ollamaResponseModel = gson.fromJson(inputLine, OllamaResponseModel.class);
|
||||
if (!ollamaResponseModel.getDone()) {
|
||||
response.append(ollamaResponseModel.getResponse());
|
||||
}
|
||||
System.out.println("Streamed response line: " + ollamaResponseModel.getResponse());
|
||||
}
|
||||
in.close();
|
||||
return response.toString();
|
||||
}
|
||||
} else {
|
||||
throw new OllamaBaseException(con.getResponseCode() + " - " + con.getResponseMessage());
|
||||
}
|
||||
}
|
||||
|
||||
public OllamaAsyncResultCallback runAsync(OllamaModel ollamaModel, String promptText) throws OllamaBaseException, IOException {
|
||||
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(ollamaModel.getModel(), promptText);
|
||||
URL obj = new URL(this.host + "/api/generate");
|
||||
HttpURLConnection con = (HttpURLConnection) obj.openConnection();
|
||||
con.setRequestMethod("POST");
|
||||
con.setDoOutput(true);
|
||||
con.setRequestProperty("Content-Type", "application/json");
|
||||
try (DataOutputStream wr = new DataOutputStream(con.getOutputStream())) {
|
||||
wr.writeBytes(ollamaRequestModel.toString());
|
||||
}
|
||||
OllamaAsyncResultCallback ollamaAsyncResultCallback = new OllamaAsyncResultCallback(con);
|
||||
ollamaAsyncResultCallback.start();
|
||||
return ollamaAsyncResultCallback;
|
||||
}
|
||||
}
|
||||
|
@ -1,15 +0,0 @@
|
||||
package io.github.amithkoujalgi.ollama4j;
|
||||
|
||||
public enum OllamaModel {
|
||||
LLAMA2("llama2"), MISTRAL("mistral"), MEDLLAMA2("medllama2"), CODELLAMA("codellama"), VICUNA("vicuna"), ORCAMINI("orca-mini"), SQLCODER("sqlcoder"), WIZARDMATH("wizard-math");
|
||||
|
||||
private final String model;
|
||||
|
||||
OllamaModel(String model) {
|
||||
this.model = model;
|
||||
}
|
||||
|
||||
public String getModel() {
|
||||
return model;
|
||||
}
|
||||
}
|
@ -0,0 +1,199 @@
|
||||
package io.github.amithkoujalgi.ollama4j.core;
|
||||
|
||||
import com.google.gson.Gson;
|
||||
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
|
||||
import io.github.amithkoujalgi.ollama4j.core.models.*;
|
||||
import org.apache.hc.client5.http.classic.methods.HttpDelete;
|
||||
import org.apache.hc.client5.http.classic.methods.HttpGet;
|
||||
import org.apache.hc.client5.http.classic.methods.HttpPost;
|
||||
import org.apache.hc.client5.http.impl.classic.CloseableHttpClient;
|
||||
import org.apache.hc.client5.http.impl.classic.CloseableHttpResponse;
|
||||
import org.apache.hc.client5.http.impl.classic.HttpClients;
|
||||
import org.apache.hc.core5.http.HttpEntity;
|
||||
import org.apache.hc.core5.http.ParseException;
|
||||
import org.apache.hc.core5.http.io.entity.EntityUtils;
|
||||
import org.apache.hc.core5.http.io.entity.StringEntity;
|
||||
|
||||
import java.io.BufferedReader;
|
||||
import java.io.DataOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStreamReader;
|
||||
import java.net.HttpURLConnection;
|
||||
import java.net.URL;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@SuppressWarnings({"DuplicatedCode", "ExtractMethodRecommender"})
|
||||
public class OllamaAPI {
|
||||
|
||||
|
||||
private final String host;
|
||||
|
||||
public OllamaAPI(String host) {
|
||||
if (host.endsWith("/")) {
|
||||
this.host = host.substring(0, host.length() - 1);
|
||||
} else {
|
||||
this.host = host;
|
||||
}
|
||||
}
|
||||
|
||||
public List<Model> listModels() throws IOException, OllamaBaseException, ParseException {
|
||||
String url = this.host + "/api/tags";
|
||||
final HttpGet httpGet = new HttpGet(url);
|
||||
httpGet.setHeader("Accept", "application/json");
|
||||
httpGet.setHeader("Content-type", "application/json");
|
||||
try (CloseableHttpClient client = HttpClients.createDefault(); CloseableHttpResponse response = client.execute(httpGet)) {
|
||||
final int statusCode = response.getCode();
|
||||
HttpEntity responseEntity = response.getEntity();
|
||||
String responseString = "";
|
||||
if (responseEntity != null) {
|
||||
responseString = EntityUtils.toString(responseEntity, "UTF-8");
|
||||
}
|
||||
if (statusCode == 200) {
|
||||
Models m = new Gson().fromJson(responseString, Models.class);
|
||||
return m.getModels();
|
||||
} else {
|
||||
throw new OllamaBaseException(statusCode + " - " + responseString);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public ModelDetail getModelDetails(Model model) throws IOException, OllamaBaseException, ParseException {
|
||||
String url = this.host + "/api/show";
|
||||
String jsonData = String.format("{\"name\": \"%s\"}", model.getName());
|
||||
final HttpPost httpPost = new HttpPost(url);
|
||||
final StringEntity entity = new StringEntity(jsonData);
|
||||
httpPost.setEntity(entity);
|
||||
httpPost.setHeader("Accept", "application/json");
|
||||
httpPost.setHeader("Content-type", "application/json");
|
||||
try (CloseableHttpClient client = HttpClients.createDefault(); CloseableHttpResponse response = client.execute(httpPost)) {
|
||||
final int statusCode = response.getCode();
|
||||
HttpEntity responseEntity = response.getEntity();
|
||||
String responseString = "";
|
||||
if (responseEntity != null) {
|
||||
responseString = EntityUtils.toString(responseEntity, "UTF-8");
|
||||
}
|
||||
if (statusCode == 200) {
|
||||
return new Gson().fromJson(responseString, ModelDetail.class);
|
||||
} else {
|
||||
throw new OllamaBaseException(statusCode + " - " + responseString);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void pullModel(String model) throws IOException, ParseException, OllamaBaseException {
|
||||
List<Model> models = listModels().stream().filter(m -> m.getModelName().split(":")[0].equals(model)).collect(Collectors.toList());
|
||||
if (!models.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
String url = this.host + "/api/pull";
|
||||
String jsonData = String.format("{\"name\": \"%s\"}", model);
|
||||
final HttpPost httpPost = new HttpPost(url);
|
||||
final StringEntity entity = new StringEntity(jsonData);
|
||||
httpPost.setEntity(entity);
|
||||
httpPost.setHeader("Accept", "application/json");
|
||||
httpPost.setHeader("Content-type", "application/json");
|
||||
try (CloseableHttpClient client = HttpClients.createDefault(); CloseableHttpResponse response = client.execute(httpPost)) {
|
||||
final int statusCode = response.getCode();
|
||||
HttpEntity responseEntity = response.getEntity();
|
||||
String responseString = "";
|
||||
if (responseEntity != null) {
|
||||
responseString = EntityUtils.toString(responseEntity, "UTF-8");
|
||||
}
|
||||
if (statusCode != 200) {
|
||||
throw new OllamaBaseException(statusCode + " - " + responseString);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void createModel(String name, String modelFilePath) throws IOException, ParseException, OllamaBaseException {
|
||||
String url = this.host + "/api/create";
|
||||
String jsonData = String.format("{\"name\": \"%s\", \"path\": \"%s\"}", name, modelFilePath);
|
||||
final HttpPost httpPost = new HttpPost(url);
|
||||
final StringEntity entity = new StringEntity(jsonData);
|
||||
httpPost.setEntity(entity);
|
||||
httpPost.setHeader("Accept", "application/json");
|
||||
httpPost.setHeader("Content-type", "application/json");
|
||||
try (CloseableHttpClient client = HttpClients.createDefault(); CloseableHttpResponse response = client.execute(httpPost)) {
|
||||
final int statusCode = response.getCode();
|
||||
HttpEntity responseEntity = response.getEntity();
|
||||
String responseString = "";
|
||||
if (responseEntity != null) {
|
||||
responseString = EntityUtils.toString(responseEntity, "UTF-8");
|
||||
}
|
||||
if (statusCode != 200) {
|
||||
throw new OllamaBaseException(statusCode + " - " + responseString);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void deleteModel(String name, boolean ignoreIfNotPresent) throws IOException, ParseException, OllamaBaseException {
|
||||
String url = this.host + "/api/delete";
|
||||
String jsonData = String.format("{\"name\": \"%s\"}", name);
|
||||
final HttpDelete httpDelete = new HttpDelete(url);
|
||||
final StringEntity entity = new StringEntity(jsonData);
|
||||
httpDelete.setEntity(entity);
|
||||
httpDelete.setHeader("Accept", "application/json");
|
||||
httpDelete.setHeader("Content-type", "application/json");
|
||||
try (CloseableHttpClient client = HttpClients.createDefault(); CloseableHttpResponse response = client.execute(httpDelete)) {
|
||||
final int statusCode = response.getCode();
|
||||
HttpEntity responseEntity = response.getEntity();
|
||||
String responseString = "";
|
||||
if (responseEntity != null) {
|
||||
responseString = EntityUtils.toString(responseEntity, "UTF-8");
|
||||
}
|
||||
if (statusCode == 404 && responseString.contains("model") && responseString.contains("not found")) {
|
||||
return;
|
||||
}
|
||||
if (statusCode != 200) {
|
||||
throw new OllamaBaseException(statusCode + " - " + responseString);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
public String ask(String ollamaModelType, String promptText) throws OllamaBaseException, IOException {
|
||||
Gson gson = new Gson();
|
||||
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(ollamaModelType, promptText);
|
||||
URL obj = new URL(this.host + "/api/generate");
|
||||
HttpURLConnection con = (HttpURLConnection) obj.openConnection();
|
||||
con.setRequestMethod("POST");
|
||||
con.setDoOutput(true);
|
||||
con.setRequestProperty("Content-Type", "application/json");
|
||||
try (DataOutputStream wr = new DataOutputStream(con.getOutputStream())) {
|
||||
wr.writeBytes(ollamaRequestModel.toString());
|
||||
}
|
||||
int responseCode = con.getResponseCode();
|
||||
if (responseCode == HttpURLConnection.HTTP_OK) {
|
||||
try (BufferedReader in = new BufferedReader(new InputStreamReader(con.getInputStream()))) {
|
||||
String inputLine;
|
||||
StringBuilder response = new StringBuilder();
|
||||
while ((inputLine = in.readLine()) != null) {
|
||||
OllamaResponseModel ollamaResponseModel = gson.fromJson(inputLine, OllamaResponseModel.class);
|
||||
if (!ollamaResponseModel.getDone()) {
|
||||
response.append(ollamaResponseModel.getResponse());
|
||||
}
|
||||
}
|
||||
in.close();
|
||||
return response.toString();
|
||||
}
|
||||
} else {
|
||||
throw new OllamaBaseException(con.getResponseCode() + " - " + con.getResponseMessage());
|
||||
}
|
||||
}
|
||||
|
||||
public OllamaAsyncResultCallback askAsync(String ollamaModelType, String promptText) throws IOException {
|
||||
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(ollamaModelType, promptText);
|
||||
URL obj = new URL(this.host + "/api/generate");
|
||||
HttpURLConnection con = (HttpURLConnection) obj.openConnection();
|
||||
con.setRequestMethod("POST");
|
||||
con.setDoOutput(true);
|
||||
con.setRequestProperty("Content-Type", "application/json");
|
||||
try (DataOutputStream wr = new DataOutputStream(con.getOutputStream())) {
|
||||
wr.writeBytes(ollamaRequestModel.toString());
|
||||
}
|
||||
OllamaAsyncResultCallback ollamaAsyncResultCallback = new OllamaAsyncResultCallback(con);
|
||||
ollamaAsyncResultCallback.start();
|
||||
return ollamaAsyncResultCallback;
|
||||
}
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package io.github.amithkoujalgi.ollama4j;
|
||||
package io.github.amithkoujalgi.ollama4j.core.exceptions;
|
||||
|
||||
public class OllamaBaseException extends Exception {
|
||||
|
@ -0,0 +1,47 @@
|
||||
package io.github.amithkoujalgi.ollama4j.core.models;
|
||||
|
||||
public class Model {
|
||||
private String name, modified_at, digest;
|
||||
private Long size;
|
||||
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
|
||||
public String getModelName() {
|
||||
return name.split(":")[0];
|
||||
}
|
||||
|
||||
public String getModelVersion() {
|
||||
return name.split(":")[1];
|
||||
}
|
||||
|
||||
public void setName(String name) {
|
||||
this.name = name;
|
||||
}
|
||||
|
||||
public String getModified_at() {
|
||||
return modified_at;
|
||||
}
|
||||
|
||||
public void setModified_at(String modified_at) {
|
||||
this.modified_at = modified_at;
|
||||
}
|
||||
|
||||
public String getDigest() {
|
||||
return digest;
|
||||
}
|
||||
|
||||
public void setDigest(String digest) {
|
||||
this.digest = digest;
|
||||
}
|
||||
|
||||
public Long getSize() {
|
||||
return size;
|
||||
}
|
||||
|
||||
public void setSize(Long size) {
|
||||
this.size = size;
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,37 @@
|
||||
package io.github.amithkoujalgi.ollama4j.core.models;
|
||||
|
||||
public class ModelDetail {
|
||||
private String license, modelfile, parameters, template;
|
||||
|
||||
public String getLicense() {
|
||||
return license;
|
||||
}
|
||||
|
||||
public void setLicense(String license) {
|
||||
this.license = license;
|
||||
}
|
||||
|
||||
public String getModelfile() {
|
||||
return modelfile;
|
||||
}
|
||||
|
||||
public void setModelfile(String modelfile) {
|
||||
this.modelfile = modelfile;
|
||||
}
|
||||
|
||||
public String getParameters() {
|
||||
return parameters;
|
||||
}
|
||||
|
||||
public void setParameters(String parameters) {
|
||||
this.parameters = parameters;
|
||||
}
|
||||
|
||||
public String getTemplate() {
|
||||
return template;
|
||||
}
|
||||
|
||||
public void setTemplate(String template) {
|
||||
this.template = template;
|
||||
}
|
||||
}
|
@ -0,0 +1,15 @@
|
||||
package io.github.amithkoujalgi.ollama4j.core.models;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class Models {
|
||||
private List<Model> models;
|
||||
|
||||
public List<Model> getModels() {
|
||||
return models;
|
||||
}
|
||||
|
||||
public void setModels(List<Model> models) {
|
||||
this.models = models;
|
||||
}
|
||||
}
|
@ -1,6 +1,7 @@
|
||||
package io.github.amithkoujalgi.ollama4j;
|
||||
package io.github.amithkoujalgi.ollama4j.core.models;
|
||||
|
||||
import com.google.gson.Gson;
|
||||
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
|
||||
|
||||
import java.io.BufferedReader;
|
||||
import java.io.IOException;
|
||||
@ -25,22 +26,26 @@ public class OllamaAsyncResultCallback extends Thread {
|
||||
try {
|
||||
responseCode = this.connection.getResponseCode();
|
||||
if (responseCode == HttpURLConnection.HTTP_OK) {
|
||||
try (BufferedReader in = new BufferedReader(new InputStreamReader(this.connection.getInputStream()))) {
|
||||
try (BufferedReader in =
|
||||
new BufferedReader(new InputStreamReader(this.connection.getInputStream()))) {
|
||||
String inputLine;
|
||||
StringBuilder response = new StringBuilder();
|
||||
while ((inputLine = in.readLine()) != null) {
|
||||
OllamaResponseModel ollamaResponseModel = gson.fromJson(inputLine, OllamaResponseModel.class);
|
||||
OllamaResponseModel ollamaResponseModel =
|
||||
gson.fromJson(inputLine, OllamaResponseModel.class);
|
||||
if (!ollamaResponseModel.getDone()) {
|
||||
response.append(ollamaResponseModel.getResponse());
|
||||
}
|
||||
// System.out.println("Streamed response line: " + responseModel.getResponse());
|
||||
// System.out.println("Streamed response line: " +
|
||||
// responseModel.getResponse());
|
||||
}
|
||||
in.close();
|
||||
this.isDone = true;
|
||||
this.result = response.toString();
|
||||
}
|
||||
} else {
|
||||
throw new OllamaBaseException(connection.getResponseCode() + " - " + connection.getResponseMessage());
|
||||
throw new OllamaBaseException(
|
||||
connection.getResponseCode() + " - " + connection.getResponseMessage());
|
||||
}
|
||||
} catch (IOException | OllamaBaseException e) {
|
||||
this.isDone = true;
|
@ -1,4 +1,4 @@
|
||||
package io.github.amithkoujalgi.ollama4j;
|
||||
package io.github.amithkoujalgi.ollama4j.core.models;
|
||||
|
||||
import com.google.gson.Gson;
|
||||
|
@ -1,9 +1,8 @@
|
||||
package io.github.amithkoujalgi.ollama4j;
|
||||
package io.github.amithkoujalgi.ollama4j.core.models;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public
|
||||
class OllamaResponseModel {
|
||||
public class OllamaResponseModel {
|
||||
private String model;
|
||||
private String created_at;
|
||||
private String response;
|
@ -0,0 +1,12 @@
|
||||
package io.github.amithkoujalgi.ollama4j.core.types;
|
||||
|
||||
public class OllamaModelType {
|
||||
public static String LLAMA2 = "llama2";
|
||||
public static String MISTRAL = "mistral";
|
||||
public static String MEDLLAMA2 = "medllama2";
|
||||
public static String CODELLAMA = "codellama";
|
||||
public static String VICUNA = "vicuna";
|
||||
public static String ORCAMINI = "orca-mini";
|
||||
public static String SQLCODER = "sqlcoder";
|
||||
public static String WIZARDMATH = "wizard-math";
|
||||
}
|
@ -0,0 +1,25 @@
|
||||
package io.github.amithkoujalgi.ollama4j.core.utils;
|
||||
|
||||
import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
|
||||
|
||||
import java.io.InputStream;
|
||||
import java.util.Scanner;
|
||||
|
||||
public class SamplePrompts {
|
||||
public static String getSampleDatabasePromptWithQuestion(String question) throws Exception {
|
||||
ClassLoader classLoader = OllamaAPI.class.getClassLoader();
|
||||
InputStream inputStream = classLoader.getResourceAsStream("sample-db-prompt-template.txt");
|
||||
if (inputStream != null) {
|
||||
Scanner scanner = new Scanner(inputStream);
|
||||
StringBuilder stringBuffer = new StringBuilder();
|
||||
while (scanner.hasNextLine()) {
|
||||
stringBuffer.append(scanner.nextLine()).append("\n");
|
||||
}
|
||||
scanner.close();
|
||||
return stringBuffer.toString().replaceAll("<question>", question);
|
||||
} else {
|
||||
throw new Exception("Sample database question file not found.");
|
||||
}
|
||||
}
|
||||
|
||||
}
|
61
src/main/resources/sample-db-prompt-template.txt
Normal file
61
src/main/resources/sample-db-prompt-template.txt
Normal file
@ -0,0 +1,61 @@
|
||||
"""
|
||||
Following is the database schema.
|
||||
|
||||
DROP TABLE IF EXISTS product_categories;
|
||||
CREATE TABLE IF NOT EXISTS product_categories
|
||||
(
|
||||
category_id INTEGER PRIMARY KEY, -- Unique ID for each category
|
||||
name VARCHAR(50), -- Name of the category
|
||||
parent INTEGER NULL, -- Parent category - for hierarchical categories
|
||||
FOREIGN KEY (parent) REFERENCES product_categories (category_id)
|
||||
);
|
||||
DROP TABLE IF EXISTS products;
|
||||
CREATE TABLE IF NOT EXISTS products
|
||||
(
|
||||
product_id INTEGER PRIMARY KEY, -- Unique ID for each product
|
||||
name VARCHAR(50), -- Name of the product
|
||||
price DECIMAL(10, 2), -- Price of each unit of the product
|
||||
quantity INTEGER, -- Current quantity in stock
|
||||
category_id INTEGER, -- Unique ID for each product
|
||||
FOREIGN KEY (category_id) REFERENCES product_categories (category_id)
|
||||
);
|
||||
DROP TABLE IF EXISTS customers;
|
||||
CREATE TABLE IF NOT EXISTS customers
|
||||
(
|
||||
customer_id INTEGER PRIMARY KEY, -- Unique ID for each customer
|
||||
name VARCHAR(50), -- Name of the customer
|
||||
address VARCHAR(100) -- Mailing address of the customer
|
||||
);
|
||||
DROP TABLE IF EXISTS salespeople;
|
||||
CREATE TABLE IF NOT EXISTS salespeople
|
||||
(
|
||||
salesperson_id INTEGER PRIMARY KEY, -- Unique ID for each salesperson
|
||||
name VARCHAR(50), -- Name of the salesperson
|
||||
region VARCHAR(50) -- Geographic sales region
|
||||
);
|
||||
DROP TABLE IF EXISTS sales;
|
||||
CREATE TABLE IF NOT EXISTS sales
|
||||
(
|
||||
sale_id INTEGER PRIMARY KEY, -- Unique ID for each sale
|
||||
product_id INTEGER, -- ID of product sold
|
||||
customer_id INTEGER, -- ID of customer who made the purchase
|
||||
salesperson_id INTEGER, -- ID of salesperson who made the sale
|
||||
sale_date DATE, -- Date the sale occurred
|
||||
quantity INTEGER, -- Quantity of product sold
|
||||
FOREIGN KEY (product_id) REFERENCES products (product_id),
|
||||
FOREIGN KEY (customer_id) REFERENCES customers (customer_id),
|
||||
FOREIGN KEY (salesperson_id) REFERENCES salespeople (salesperson_id)
|
||||
);
|
||||
DROP TABLE IF EXISTS product_suppliers;
|
||||
CREATE TABLE IF NOT EXISTS product_suppliers
|
||||
(
|
||||
supplier_id INTEGER PRIMARY KEY, -- Unique ID for each supplier
|
||||
product_id INTEGER, -- Product ID supplied
|
||||
supply_price DECIMAL(10, 2), -- Unit price charged by supplier
|
||||
FOREIGN KEY (product_id) REFERENCES products (product_id)
|
||||
);
|
||||
|
||||
|
||||
Generate only a valid (syntactically correct) executable Postgres SQL query (without any explanation of the query) for the following question:
|
||||
`<question>`:
|
||||
"""
|
@ -1,5 +1,8 @@
|
||||
package io.github.amithkoujalgi.ollama4j;
|
||||
|
||||
import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
|
||||
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
|
||||
import io.github.amithkoujalgi.ollama4j.core.types.OllamaModelType;
|
||||
import org.apache.hc.core5.http.ParseException;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.mockito.Mockito;
|
||||
@ -12,7 +15,7 @@ public class TestMockedAPIs {
|
||||
@Test
|
||||
public void testMockPullModel() {
|
||||
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
|
||||
OllamaModel model = OllamaModel.LLAMA2;
|
||||
String model = OllamaModelType.LLAMA2;
|
||||
try {
|
||||
doNothing().when(ollamaAPI).pullModel(model);
|
||||
ollamaAPI.pullModel(model);
|
||||
|
Loading…
x
Reference in New Issue
Block a user