This commit is contained in:
Amith Koujalgi 2023-11-07 20:52:42 +05:30
parent c1615a2005
commit 153f516d9f
15 changed files with 434 additions and 128 deletions

View File

@ -0,0 +1,20 @@
package io.github.amithkoujalgi.ollama4j;
import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
import io.github.amithkoujalgi.ollama4j.core.types.OllamaModelType;
import io.github.amithkoujalgi.ollama4j.core.utils.SamplePrompts;
public class Main {
public static void main(String[] args) throws Exception {
String host = "http://localhost:11434/";
OllamaAPI ollamaAPI = new OllamaAPI(host);
String prompt1 = SamplePrompts.getSampleDatabasePromptWithQuestion("List all customer names who have bought one or more products");
String response1 = ollamaAPI.ask(OllamaModelType.LLAMA2, prompt1);
System.out.println(response1);
String prompt2 = "Give me a list of world cup cricket teams.";
String response2 = ollamaAPI.ask(OllamaModelType.LLAMA2, prompt2);
System.out.println(response2);
}
}

View File

@ -1,102 +0,0 @@
package io.github.amithkoujalgi.ollama4j;
import com.google.gson.Gson;
import org.apache.hc.client5.http.classic.methods.HttpPost;
import org.apache.hc.client5.http.impl.classic.CloseableHttpClient;
import org.apache.hc.client5.http.impl.classic.CloseableHttpResponse;
import org.apache.hc.client5.http.impl.classic.HttpClients;
import org.apache.hc.core5.http.HttpEntity;
import org.apache.hc.core5.http.ParseException;
import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.apache.hc.core5.http.io.entity.StringEntity;
import java.io.BufferedReader;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.HttpURLConnection;
import java.net.URL;
@SuppressWarnings("deprecation")
public class OllamaAPI {
private final String host;
public OllamaAPI(String host) {
if (host.endsWith("/")) {
this.host = host.substring(0, host.length() - 1);
} else {
this.host = host;
}
}
public void pullModel(OllamaModel model) throws IOException, ParseException, OllamaBaseException {
String url = this.host + "/api/pull";
String jsonData = String.format("{\"name\": \"%s\"}", model.getModel());
final HttpPost httpPost = new HttpPost(url);
final StringEntity entity = new StringEntity(jsonData);
httpPost.setEntity(entity);
httpPost.setHeader("Accept", "application/json");
httpPost.setHeader("Content-type", "application/json");
try (CloseableHttpClient client = HttpClients.createDefault();
CloseableHttpResponse response = client.execute(httpPost)) {
final int statusCode = response.getCode();
HttpEntity responseEntity = response.getEntity();
String responseString = "";
if (responseEntity != null) {
responseString = EntityUtils.toString(responseEntity, "UTF-8");
}
if (statusCode == 200) {
System.out.println(responseString);
} else {
throw new OllamaBaseException(statusCode + " - " + responseString);
}
}
}
public String runSync(OllamaModel ollamaModel, String promptText) throws OllamaBaseException, IOException {
Gson gson = new Gson();
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(ollamaModel.getModel(), promptText);
URL obj = new URL(this.host + "/api/generate");
HttpURLConnection con = (HttpURLConnection) obj.openConnection();
con.setRequestMethod("POST");
con.setDoOutput(true);
con.setRequestProperty("Content-Type", "application/json");
try (DataOutputStream wr = new DataOutputStream(con.getOutputStream())) {
wr.writeBytes(ollamaRequestModel.toString());
}
int responseCode = con.getResponseCode();
if (responseCode == HttpURLConnection.HTTP_OK) {
try (BufferedReader in = new BufferedReader(new InputStreamReader(con.getInputStream()))) {
String inputLine;
StringBuilder response = new StringBuilder();
while ((inputLine = in.readLine()) != null) {
OllamaResponseModel ollamaResponseModel = gson.fromJson(inputLine, OllamaResponseModel.class);
if (!ollamaResponseModel.getDone()) {
response.append(ollamaResponseModel.getResponse());
}
System.out.println("Streamed response line: " + ollamaResponseModel.getResponse());
}
in.close();
return response.toString();
}
} else {
throw new OllamaBaseException(con.getResponseCode() + " - " + con.getResponseMessage());
}
}
public OllamaAsyncResultCallback runAsync(OllamaModel ollamaModel, String promptText) throws OllamaBaseException, IOException {
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(ollamaModel.getModel(), promptText);
URL obj = new URL(this.host + "/api/generate");
HttpURLConnection con = (HttpURLConnection) obj.openConnection();
con.setRequestMethod("POST");
con.setDoOutput(true);
con.setRequestProperty("Content-Type", "application/json");
try (DataOutputStream wr = new DataOutputStream(con.getOutputStream())) {
wr.writeBytes(ollamaRequestModel.toString());
}
OllamaAsyncResultCallback ollamaAsyncResultCallback = new OllamaAsyncResultCallback(con);
ollamaAsyncResultCallback.start();
return ollamaAsyncResultCallback;
}
}

View File

@ -1,15 +0,0 @@
package io.github.amithkoujalgi.ollama4j;
public enum OllamaModel {
LLAMA2("llama2"), MISTRAL("mistral"), MEDLLAMA2("medllama2"), CODELLAMA("codellama"), VICUNA("vicuna"), ORCAMINI("orca-mini"), SQLCODER("sqlcoder"), WIZARDMATH("wizard-math");
private final String model;
OllamaModel(String model) {
this.model = model;
}
public String getModel() {
return model;
}
}

View File

@ -0,0 +1,199 @@
package io.github.amithkoujalgi.ollama4j.core;
import com.google.gson.Gson;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.*;
import org.apache.hc.client5.http.classic.methods.HttpDelete;
import org.apache.hc.client5.http.classic.methods.HttpGet;
import org.apache.hc.client5.http.classic.methods.HttpPost;
import org.apache.hc.client5.http.impl.classic.CloseableHttpClient;
import org.apache.hc.client5.http.impl.classic.CloseableHttpResponse;
import org.apache.hc.client5.http.impl.classic.HttpClients;
import org.apache.hc.core5.http.HttpEntity;
import org.apache.hc.core5.http.ParseException;
import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.apache.hc.core5.http.io.entity.StringEntity;
import java.io.BufferedReader;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.HttpURLConnection;
import java.net.URL;
import java.util.List;
import java.util.stream.Collectors;
@SuppressWarnings({"DuplicatedCode", "ExtractMethodRecommender"})
public class OllamaAPI {
private final String host;
public OllamaAPI(String host) {
if (host.endsWith("/")) {
this.host = host.substring(0, host.length() - 1);
} else {
this.host = host;
}
}
public List<Model> listModels() throws IOException, OllamaBaseException, ParseException {
String url = this.host + "/api/tags";
final HttpGet httpGet = new HttpGet(url);
httpGet.setHeader("Accept", "application/json");
httpGet.setHeader("Content-type", "application/json");
try (CloseableHttpClient client = HttpClients.createDefault(); CloseableHttpResponse response = client.execute(httpGet)) {
final int statusCode = response.getCode();
HttpEntity responseEntity = response.getEntity();
String responseString = "";
if (responseEntity != null) {
responseString = EntityUtils.toString(responseEntity, "UTF-8");
}
if (statusCode == 200) {
Models m = new Gson().fromJson(responseString, Models.class);
return m.getModels();
} else {
throw new OllamaBaseException(statusCode + " - " + responseString);
}
}
}
public ModelDetail getModelDetails(Model model) throws IOException, OllamaBaseException, ParseException {
String url = this.host + "/api/show";
String jsonData = String.format("{\"name\": \"%s\"}", model.getName());
final HttpPost httpPost = new HttpPost(url);
final StringEntity entity = new StringEntity(jsonData);
httpPost.setEntity(entity);
httpPost.setHeader("Accept", "application/json");
httpPost.setHeader("Content-type", "application/json");
try (CloseableHttpClient client = HttpClients.createDefault(); CloseableHttpResponse response = client.execute(httpPost)) {
final int statusCode = response.getCode();
HttpEntity responseEntity = response.getEntity();
String responseString = "";
if (responseEntity != null) {
responseString = EntityUtils.toString(responseEntity, "UTF-8");
}
if (statusCode == 200) {
return new Gson().fromJson(responseString, ModelDetail.class);
} else {
throw new OllamaBaseException(statusCode + " - " + responseString);
}
}
}
public void pullModel(String model) throws IOException, ParseException, OllamaBaseException {
List<Model> models = listModels().stream().filter(m -> m.getModelName().split(":")[0].equals(model)).collect(Collectors.toList());
if (!models.isEmpty()) {
return;
}
String url = this.host + "/api/pull";
String jsonData = String.format("{\"name\": \"%s\"}", model);
final HttpPost httpPost = new HttpPost(url);
final StringEntity entity = new StringEntity(jsonData);
httpPost.setEntity(entity);
httpPost.setHeader("Accept", "application/json");
httpPost.setHeader("Content-type", "application/json");
try (CloseableHttpClient client = HttpClients.createDefault(); CloseableHttpResponse response = client.execute(httpPost)) {
final int statusCode = response.getCode();
HttpEntity responseEntity = response.getEntity();
String responseString = "";
if (responseEntity != null) {
responseString = EntityUtils.toString(responseEntity, "UTF-8");
}
if (statusCode != 200) {
throw new OllamaBaseException(statusCode + " - " + responseString);
}
}
}
public void createModel(String name, String modelFilePath) throws IOException, ParseException, OllamaBaseException {
String url = this.host + "/api/create";
String jsonData = String.format("{\"name\": \"%s\", \"path\": \"%s\"}", name, modelFilePath);
final HttpPost httpPost = new HttpPost(url);
final StringEntity entity = new StringEntity(jsonData);
httpPost.setEntity(entity);
httpPost.setHeader("Accept", "application/json");
httpPost.setHeader("Content-type", "application/json");
try (CloseableHttpClient client = HttpClients.createDefault(); CloseableHttpResponse response = client.execute(httpPost)) {
final int statusCode = response.getCode();
HttpEntity responseEntity = response.getEntity();
String responseString = "";
if (responseEntity != null) {
responseString = EntityUtils.toString(responseEntity, "UTF-8");
}
if (statusCode != 200) {
throw new OllamaBaseException(statusCode + " - " + responseString);
}
}
}
public void deleteModel(String name, boolean ignoreIfNotPresent) throws IOException, ParseException, OllamaBaseException {
String url = this.host + "/api/delete";
String jsonData = String.format("{\"name\": \"%s\"}", name);
final HttpDelete httpDelete = new HttpDelete(url);
final StringEntity entity = new StringEntity(jsonData);
httpDelete.setEntity(entity);
httpDelete.setHeader("Accept", "application/json");
httpDelete.setHeader("Content-type", "application/json");
try (CloseableHttpClient client = HttpClients.createDefault(); CloseableHttpResponse response = client.execute(httpDelete)) {
final int statusCode = response.getCode();
HttpEntity responseEntity = response.getEntity();
String responseString = "";
if (responseEntity != null) {
responseString = EntityUtils.toString(responseEntity, "UTF-8");
}
if (statusCode == 404 && responseString.contains("model") && responseString.contains("not found")) {
return;
}
if (statusCode != 200) {
throw new OllamaBaseException(statusCode + " - " + responseString);
}
}
}
public String ask(String ollamaModelType, String promptText) throws OllamaBaseException, IOException {
Gson gson = new Gson();
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(ollamaModelType, promptText);
URL obj = new URL(this.host + "/api/generate");
HttpURLConnection con = (HttpURLConnection) obj.openConnection();
con.setRequestMethod("POST");
con.setDoOutput(true);
con.setRequestProperty("Content-Type", "application/json");
try (DataOutputStream wr = new DataOutputStream(con.getOutputStream())) {
wr.writeBytes(ollamaRequestModel.toString());
}
int responseCode = con.getResponseCode();
if (responseCode == HttpURLConnection.HTTP_OK) {
try (BufferedReader in = new BufferedReader(new InputStreamReader(con.getInputStream()))) {
String inputLine;
StringBuilder response = new StringBuilder();
while ((inputLine = in.readLine()) != null) {
OllamaResponseModel ollamaResponseModel = gson.fromJson(inputLine, OllamaResponseModel.class);
if (!ollamaResponseModel.getDone()) {
response.append(ollamaResponseModel.getResponse());
}
}
in.close();
return response.toString();
}
} else {
throw new OllamaBaseException(con.getResponseCode() + " - " + con.getResponseMessage());
}
}
public OllamaAsyncResultCallback askAsync(String ollamaModelType, String promptText) throws IOException {
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(ollamaModelType, promptText);
URL obj = new URL(this.host + "/api/generate");
HttpURLConnection con = (HttpURLConnection) obj.openConnection();
con.setRequestMethod("POST");
con.setDoOutput(true);
con.setRequestProperty("Content-Type", "application/json");
try (DataOutputStream wr = new DataOutputStream(con.getOutputStream())) {
wr.writeBytes(ollamaRequestModel.toString());
}
OllamaAsyncResultCallback ollamaAsyncResultCallback = new OllamaAsyncResultCallback(con);
ollamaAsyncResultCallback.start();
return ollamaAsyncResultCallback;
}
}

View File

@ -1,4 +1,4 @@
package io.github.amithkoujalgi.ollama4j;
package io.github.amithkoujalgi.ollama4j.core.exceptions;
public class OllamaBaseException extends Exception {

View File

@ -0,0 +1,47 @@
package io.github.amithkoujalgi.ollama4j.core.models;
public class Model {
private String name, modified_at, digest;
private Long size;
public String getName() {
return name;
}
public String getModelName() {
return name.split(":")[0];
}
public String getModelVersion() {
return name.split(":")[1];
}
public void setName(String name) {
this.name = name;
}
public String getModified_at() {
return modified_at;
}
public void setModified_at(String modified_at) {
this.modified_at = modified_at;
}
public String getDigest() {
return digest;
}
public void setDigest(String digest) {
this.digest = digest;
}
public Long getSize() {
return size;
}
public void setSize(Long size) {
this.size = size;
}
}

View File

@ -0,0 +1,37 @@
package io.github.amithkoujalgi.ollama4j.core.models;
public class ModelDetail {
private String license, modelfile, parameters, template;
public String getLicense() {
return license;
}
public void setLicense(String license) {
this.license = license;
}
public String getModelfile() {
return modelfile;
}
public void setModelfile(String modelfile) {
this.modelfile = modelfile;
}
public String getParameters() {
return parameters;
}
public void setParameters(String parameters) {
this.parameters = parameters;
}
public String getTemplate() {
return template;
}
public void setTemplate(String template) {
this.template = template;
}
}

View File

@ -0,0 +1,15 @@
package io.github.amithkoujalgi.ollama4j.core.models;
import java.util.List;
public class Models {
private List<Model> models;
public List<Model> getModels() {
return models;
}
public void setModels(List<Model> models) {
this.models = models;
}
}

View File

@ -1,6 +1,7 @@
package io.github.amithkoujalgi.ollama4j;
package io.github.amithkoujalgi.ollama4j.core.models;
import com.google.gson.Gson;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import java.io.BufferedReader;
import java.io.IOException;
@ -25,22 +26,26 @@ public class OllamaAsyncResultCallback extends Thread {
try {
responseCode = this.connection.getResponseCode();
if (responseCode == HttpURLConnection.HTTP_OK) {
try (BufferedReader in = new BufferedReader(new InputStreamReader(this.connection.getInputStream()))) {
try (BufferedReader in =
new BufferedReader(new InputStreamReader(this.connection.getInputStream()))) {
String inputLine;
StringBuilder response = new StringBuilder();
while ((inputLine = in.readLine()) != null) {
OllamaResponseModel ollamaResponseModel = gson.fromJson(inputLine, OllamaResponseModel.class);
OllamaResponseModel ollamaResponseModel =
gson.fromJson(inputLine, OllamaResponseModel.class);
if (!ollamaResponseModel.getDone()) {
response.append(ollamaResponseModel.getResponse());
}
// System.out.println("Streamed response line: " + responseModel.getResponse());
// System.out.println("Streamed response line: " +
// responseModel.getResponse());
}
in.close();
this.isDone = true;
this.result = response.toString();
}
} else {
throw new OllamaBaseException(connection.getResponseCode() + " - " + connection.getResponseMessage());
throw new OllamaBaseException(
connection.getResponseCode() + " - " + connection.getResponseMessage());
}
} catch (IOException | OllamaBaseException e) {
this.isDone = true;

View File

@ -1,4 +1,4 @@
package io.github.amithkoujalgi.ollama4j;
package io.github.amithkoujalgi.ollama4j.core.models;
import com.google.gson.Gson;

View File

@ -1,9 +1,8 @@
package io.github.amithkoujalgi.ollama4j;
package io.github.amithkoujalgi.ollama4j.core.models;
import java.util.List;
public
class OllamaResponseModel {
public class OllamaResponseModel {
private String model;
private String created_at;
private String response;

View File

@ -0,0 +1,12 @@
package io.github.amithkoujalgi.ollama4j.core.types;
public class OllamaModelType {
public static String LLAMA2 = "llama2";
public static String MISTRAL = "mistral";
public static String MEDLLAMA2 = "medllama2";
public static String CODELLAMA = "codellama";
public static String VICUNA = "vicuna";
public static String ORCAMINI = "orca-mini";
public static String SQLCODER = "sqlcoder";
public static String WIZARDMATH = "wizard-math";
}

View File

@ -0,0 +1,25 @@
package io.github.amithkoujalgi.ollama4j.core.utils;
import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
import java.io.InputStream;
import java.util.Scanner;
public class SamplePrompts {
public static String getSampleDatabasePromptWithQuestion(String question) throws Exception {
ClassLoader classLoader = OllamaAPI.class.getClassLoader();
InputStream inputStream = classLoader.getResourceAsStream("sample-db-prompt-template.txt");
if (inputStream != null) {
Scanner scanner = new Scanner(inputStream);
StringBuilder stringBuffer = new StringBuilder();
while (scanner.hasNextLine()) {
stringBuffer.append(scanner.nextLine()).append("\n");
}
scanner.close();
return stringBuffer.toString().replaceAll("<question>", question);
} else {
throw new Exception("Sample database question file not found.");
}
}
}

View File

@ -0,0 +1,61 @@
"""
Following is the database schema.
DROP TABLE IF EXISTS product_categories;
CREATE TABLE IF NOT EXISTS product_categories
(
category_id INTEGER PRIMARY KEY, -- Unique ID for each category
name VARCHAR(50), -- Name of the category
parent INTEGER NULL, -- Parent category - for hierarchical categories
FOREIGN KEY (parent) REFERENCES product_categories (category_id)
);
DROP TABLE IF EXISTS products;
CREATE TABLE IF NOT EXISTS products
(
product_id INTEGER PRIMARY KEY, -- Unique ID for each product
name VARCHAR(50), -- Name of the product
price DECIMAL(10, 2), -- Price of each unit of the product
quantity INTEGER, -- Current quantity in stock
category_id INTEGER, -- Unique ID for each product
FOREIGN KEY (category_id) REFERENCES product_categories (category_id)
);
DROP TABLE IF EXISTS customers;
CREATE TABLE IF NOT EXISTS customers
(
customer_id INTEGER PRIMARY KEY, -- Unique ID for each customer
name VARCHAR(50), -- Name of the customer
address VARCHAR(100) -- Mailing address of the customer
);
DROP TABLE IF EXISTS salespeople;
CREATE TABLE IF NOT EXISTS salespeople
(
salesperson_id INTEGER PRIMARY KEY, -- Unique ID for each salesperson
name VARCHAR(50), -- Name of the salesperson
region VARCHAR(50) -- Geographic sales region
);
DROP TABLE IF EXISTS sales;
CREATE TABLE IF NOT EXISTS sales
(
sale_id INTEGER PRIMARY KEY, -- Unique ID for each sale
product_id INTEGER, -- ID of product sold
customer_id INTEGER, -- ID of customer who made the purchase
salesperson_id INTEGER, -- ID of salesperson who made the sale
sale_date DATE, -- Date the sale occurred
quantity INTEGER, -- Quantity of product sold
FOREIGN KEY (product_id) REFERENCES products (product_id),
FOREIGN KEY (customer_id) REFERENCES customers (customer_id),
FOREIGN KEY (salesperson_id) REFERENCES salespeople (salesperson_id)
);
DROP TABLE IF EXISTS product_suppliers;
CREATE TABLE IF NOT EXISTS product_suppliers
(
supplier_id INTEGER PRIMARY KEY, -- Unique ID for each supplier
product_id INTEGER, -- Product ID supplied
supply_price DECIMAL(10, 2), -- Unit price charged by supplier
FOREIGN KEY (product_id) REFERENCES products (product_id)
);
Generate only a valid (syntactically correct) executable Postgres SQL query (without any explanation of the query) for the following question:
`<question>`:
"""

View File

@ -1,5 +1,8 @@
package io.github.amithkoujalgi.ollama4j;
import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.types.OllamaModelType;
import org.apache.hc.core5.http.ParseException;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
@ -12,7 +15,7 @@ public class TestMockedAPIs {
@Test
public void testMockPullModel() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
OllamaModel model = OllamaModel.LLAMA2;
String model = OllamaModelType.LLAMA2;
try {
doNothing().when(ollamaAPI).pullModel(model);
ollamaAPI.pullModel(model);