- replaced GSON with Jackson

- Updated readme
- general cleanup
This commit is contained in:
Amith Koujalgi 2023-11-09 12:56:45 +05:30
parent 1f28e61234
commit 6678cd3f69
7 changed files with 73 additions and 22 deletions

View File

@ -279,6 +279,28 @@ FROM sales
GROUP BY customers.name; GROUP BY customers.name;
``` ```
#### Async API with streaming response
```java
public class Main {
public static void main(String[] args) throws Exception {
String host = "http://localhost:11434/";
OllamaAPI ollamaAPI = new OllamaAPI(host);
String prompt = "List all cricket world cup teams of 2019.";
OllamaAsyncResultCallback callback = ollamaAPI.askAsync(OllamaModelType.LLAMA2, prompt);
while (!callback.isComplete() || !callback.getStream().isEmpty()) {
// poll for data from the response stream
String response = callback.getStream().poll();
if (response != null) {
System.out.print(response);
}
Thread.sleep(1000);
}
}
}
```
#### API Spec #### API Spec
Find the full `Javadoc` (API specifications) [here](https://amithkoujalgi.github.io/ollama4j/). Find the full `Javadoc` (API specifications) [here](https://amithkoujalgi.github.io/ollama4j/).

View File

@ -102,9 +102,9 @@
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>com.google.code.gson</groupId> <groupId>com.fasterxml.jackson.core</groupId>
<artifactId>gson</artifactId> <artifactId>jackson-databind</artifactId>
<version>2.10.1</version> <version>2.15.3</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>ch.qos.logback</groupId> <groupId>ch.qos.logback</groupId>

View File

@ -1,6 +1,6 @@
package io.github.amithkoujalgi.ollama4j.core; package io.github.amithkoujalgi.ollama4j.core;
import com.google.gson.Gson; import com.fasterxml.jackson.databind.ObjectMapper;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.*; import io.github.amithkoujalgi.ollama4j.core.models.*;
import org.apache.hc.client5.http.classic.methods.HttpDelete; import org.apache.hc.client5.http.classic.methods.HttpDelete;
@ -30,10 +30,13 @@ import java.util.stream.Collectors;
*/ */
@SuppressWarnings({"DuplicatedCode", "ExtractMethodRecommender"}) @SuppressWarnings({"DuplicatedCode", "ExtractMethodRecommender"})
public class OllamaAPI { public class OllamaAPI {
private static final Logger logger = LoggerFactory.getLogger(OllamaAPI.class); private static final Logger logger = LoggerFactory.getLogger(OllamaAPI.class);
private final String host; private final String host;
private boolean verbose = false; private boolean verbose = false;
private final ObjectMapper objectMapper = new ObjectMapper();
/** /**
* Instantiates the Ollama API. * Instantiates the Ollama API.
* *
@ -76,8 +79,7 @@ public class OllamaAPI {
responseString = EntityUtils.toString(responseEntity, "UTF-8"); responseString = EntityUtils.toString(responseEntity, "UTF-8");
} }
if (statusCode == 200) { if (statusCode == 200) {
Models m = new Gson().fromJson(responseString, Models.class); return objectMapper.readValue(responseString, ListModelsResponse.class).getModels();
return m.getModels();
} else { } else {
throw new OllamaBaseException(statusCode + " - " + responseString); throw new OllamaBaseException(statusCode + " - " + responseString);
} }
@ -109,7 +111,7 @@ public class OllamaAPI {
responseString = EntityUtils.toString(responseEntity, "UTF-8"); responseString = EntityUtils.toString(responseEntity, "UTF-8");
} }
if (statusCode == 200) { if (statusCode == 200) {
return new Gson().fromJson(responseString, ModelDetail.class); return objectMapper.readValue(responseString, ModelDetail.class);
} else { } else {
throw new OllamaBaseException(statusCode + " - " + responseString); throw new OllamaBaseException(statusCode + " - " + responseString);
} }
@ -234,7 +236,6 @@ public class OllamaAPI {
* @throws IOException * @throws IOException
*/ */
public String ask(String ollamaModelType, String promptText) throws OllamaBaseException, IOException { public String ask(String ollamaModelType, String promptText) throws OllamaBaseException, IOException {
Gson gson = new Gson();
OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(ollamaModelType, promptText); OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(ollamaModelType, promptText);
URL obj = new URL(this.host + "/api/generate"); URL obj = new URL(this.host + "/api/generate");
HttpURLConnection con = (HttpURLConnection) obj.openConnection(); HttpURLConnection con = (HttpURLConnection) obj.openConnection();
@ -250,7 +251,7 @@ public class OllamaAPI {
String inputLine; String inputLine;
StringBuilder response = new StringBuilder(); StringBuilder response = new StringBuilder();
while ((inputLine = in.readLine()) != null) { while ((inputLine = in.readLine()) != null) {
OllamaResponseModel ollamaResponseModel = gson.fromJson(inputLine, OllamaResponseModel.class); OllamaResponseModel ollamaResponseModel = objectMapper.readValue(inputLine, OllamaResponseModel.class);
if (!ollamaResponseModel.getDone()) { if (!ollamaResponseModel.getDone()) {
response.append(ollamaResponseModel.getResponse()); response.append(ollamaResponseModel.getResponse());
} }

View File

@ -2,7 +2,7 @@ package io.github.amithkoujalgi.ollama4j.core.models;
import java.util.List; import java.util.List;
public class Models { public class ListModelsResponse {
private List<Model> models; private List<Model> models;
public List<Model> getModels() { public List<Model> getModels() {

View File

@ -1,6 +1,7 @@
package io.github.amithkoujalgi.ollama4j.core.models; package io.github.amithkoujalgi.ollama4j.core.models;
import com.google.gson.GsonBuilder; import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
public class ModelDetail { public class ModelDetail {
private String license, modelfile, parameters, template; private String license, modelfile, parameters, template;
@ -39,6 +40,13 @@ public class ModelDetail {
@Override @Override
public String toString() { public String toString() {
return new GsonBuilder().setPrettyPrinting().create().toJson(this); try {
return new ObjectMapper()
.writer()
.withDefaultPrettyPrinter()
.writeValueAsString(this);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
} }
} }

View File

@ -1,38 +1,42 @@
package io.github.amithkoujalgi.ollama4j.core.models; package io.github.amithkoujalgi.ollama4j.core.models;
import com.google.gson.Gson; import com.fasterxml.jackson.databind.ObjectMapper;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import java.io.BufferedReader; import java.io.BufferedReader;
import java.io.IOException; import java.io.IOException;
import java.io.InputStreamReader; import java.io.InputStreamReader;
import java.net.HttpURLConnection; import java.net.HttpURLConnection;
import java.util.LinkedList;
import java.util.Queue;
@SuppressWarnings("DuplicatedCode")
public class OllamaAsyncResultCallback extends Thread { public class OllamaAsyncResultCallback extends Thread {
private final HttpURLConnection connection; private final HttpURLConnection connection;
private String result; private String result;
private boolean isDone; private boolean isDone;
private final ObjectMapper objectMapper = new ObjectMapper();
private final Queue<String> queue = new LinkedList<>();
public OllamaAsyncResultCallback(HttpURLConnection connection) { public OllamaAsyncResultCallback(HttpURLConnection connection) {
this.connection = connection; this.connection = connection;
this.isDone = false; this.isDone = false;
this.result = ""; this.result = "";
this.queue.add("");
} }
@Override @Override
public void run() { public void run() {
Gson gson = new Gson();
int responseCode = 0; int responseCode = 0;
try { try {
responseCode = this.connection.getResponseCode(); responseCode = this.connection.getResponseCode();
if (responseCode == HttpURLConnection.HTTP_OK) { if (responseCode == HttpURLConnection.HTTP_OK) {
try (BufferedReader in = try (BufferedReader in = new BufferedReader(new InputStreamReader(this.connection.getInputStream()))) {
new BufferedReader(new InputStreamReader(this.connection.getInputStream()))) {
String inputLine; String inputLine;
StringBuilder response = new StringBuilder(); StringBuilder response = new StringBuilder();
while ((inputLine = in.readLine()) != null) { while ((inputLine = in.readLine()) != null) {
OllamaResponseModel ollamaResponseModel = OllamaResponseModel ollamaResponseModel = objectMapper.readValue(inputLine, OllamaResponseModel.class);
gson.fromJson(inputLine, OllamaResponseModel.class); queue.add(ollamaResponseModel.getResponse());
if (!ollamaResponseModel.getDone()) { if (!ollamaResponseModel.getDone()) {
response.append(ollamaResponseModel.getResponse()); response.append(ollamaResponseModel.getResponse());
} }
@ -42,8 +46,7 @@ public class OllamaAsyncResultCallback extends Thread {
this.result = response.toString(); this.result = response.toString();
} }
} else { } else {
throw new OllamaBaseException( throw new OllamaBaseException(connection.getResponseCode() + " - " + connection.getResponseMessage());
connection.getResponseCode() + " - " + connection.getResponseMessage());
} }
} catch (IOException | OllamaBaseException e) { } catch (IOException | OllamaBaseException e) {
this.isDone = true; this.isDone = true;
@ -55,7 +58,15 @@ public class OllamaAsyncResultCallback extends Thread {
return isDone; return isDone;
} }
/**
* Returns the final response when the execution completes. Does not return intermediate results.
* @return response text
*/
public String getResponse() { public String getResponse() {
return result; return result;
} }
public Queue<String> getStream() {
return queue;
}
} }

View File

@ -1,6 +1,8 @@
package io.github.amithkoujalgi.ollama4j.core.models; package io.github.amithkoujalgi.ollama4j.core.models;
import com.google.gson.Gson;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
public class OllamaRequestModel { public class OllamaRequestModel {
private String model; private String model;
@ -29,6 +31,13 @@ public class OllamaRequestModel {
@Override @Override
public String toString() { public String toString() {
return new Gson().toJson(this); try {
return new ObjectMapper()
.writer()
.withDefaultPrettyPrinter()
.writeValueAsString(this);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
} }
} }