forked from Mirror/ollama4j
		
	Compare commits
	
		
			13 Commits
		
	
	
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|   | 2a887f5015 | ||
|   | 7e3dddf1bb | ||
|   | fe95a7df2a | ||
|   | 98f6a30c6b | ||
|   | 00288053bf | ||
|   | 6a7feb98bd | ||
|   | 770d511067 | ||
|   | b57fc1f818 | ||
|   | 01c5a8f07f | ||
|   | 243b8a3747 | ||
|   | 987fce7f07 | ||
|   | 657593be09 | ||
|   | 0afba7e3e3 | 
							
								
								
									
										4
									
								
								.github/workflows/maven-publish.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								.github/workflows/maven-publish.yml
									
									
									
									
										vendored
									
									
								
							| @@ -49,6 +49,10 @@ jobs: | |||||||
|             ${{ runner.os }}-maven- |             ${{ runner.os }}-maven- | ||||||
|       - name: Build |       - name: Build | ||||||
|         run: mvn -B -ntp clean install |         run: mvn -B -ntp clean install | ||||||
|  |       - name: Upload coverage reports to Codecov | ||||||
|  |         uses: codecov/codecov-action@v3 | ||||||
|  |         env: | ||||||
|  |           CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} | ||||||
|       - name: Publish to GitHub Packages Apache Maven |       - name: Publish to GitHub Packages Apache Maven | ||||||
|         #        if: > |         #        if: > | ||||||
|         #          github.event_name != 'pull_request' && |         #          github.event_name != 'pull_request' && | ||||||
|   | |||||||
							
								
								
									
										56
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										56
									
								
								README.md
									
									
									
									
									
								
							| @@ -2,8 +2,33 @@ | |||||||
|  |  | ||||||
| <img src='https://raw.githubusercontent.com/amithkoujalgi/ollama4j/65a9d526150da8fcd98e2af6a164f055572bf722/ollama4j.jpeg' width='100' alt="ollama4j-icon"> | <img src='https://raw.githubusercontent.com/amithkoujalgi/ollama4j/65a9d526150da8fcd98e2af6a164f055572bf722/ollama4j.jpeg' width='100' alt="ollama4j-icon"> | ||||||
|  |  | ||||||
| A Java library (wrapper/binding) | A Java library (wrapper/binding) for [Ollama](https://ollama.ai/) server. | ||||||
| for [Ollama](https://github.com/jmorganca/ollama/blob/main/docs/api.md) APIs. |  | ||||||
|  | Find more details on the [website](https://amithkoujalgi.github.io/ollama4j/). | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  | ## Table of Contents | ||||||
|  |  | ||||||
|  | - [How does it work?](#how-does-it-work) | ||||||
|  | - [Requirements](#requirements) | ||||||
|  | - [Installation](#installation) | ||||||
|  | - [API Spec](#api-spec) | ||||||
|  | - [Demo APIs](#try-out-the-apis-with-ollama-server) | ||||||
|  | - [Development](#development) | ||||||
|  | - [Contributions](#get-involved) | ||||||
|  | - [References](#references) | ||||||
|  |  | ||||||
|  | #### How does it work? | ||||||
|  |  | ||||||
| ```mermaid | ```mermaid | ||||||
|   flowchart LR |   flowchart LR | ||||||
| @@ -17,26 +42,6 @@ for [Ollama](https://github.com/jmorganca/ollama/blob/main/docs/api.md) APIs. | |||||||
|     end |     end | ||||||
| ``` | ``` | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| ## Table of Contents |  | ||||||
|  |  | ||||||
| - [Requirements](#requirements) |  | ||||||
| - [Installation](#installation) |  | ||||||
| - [API Spec](#api-spec) |  | ||||||
| - [Demo APIs](#try-out-the-apis-with-ollama-server) |  | ||||||
| - [Development](#development) |  | ||||||
| - [Contributions](#get-involved) |  | ||||||
|  |  | ||||||
| #### Requirements | #### Requirements | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -76,7 +81,7 @@ Latest release: | |||||||
|  |  | ||||||
| #### API Spec | #### API Spec | ||||||
|  |  | ||||||
| Find the full `Javadoc` (API specifications) [here](https://amithkoujalgi.github.io/ollama4j/). | Find the full API specifications on the [website](https://amithkoujalgi.github.io/ollama4j/). | ||||||
|  |  | ||||||
| #### Development | #### Development | ||||||
|  |  | ||||||
| @@ -117,6 +122,7 @@ Actions CI workflow. | |||||||
| - [x] Use lombok | - [x] Use lombok | ||||||
| - [x] Update request body creation with Java objects | - [x] Update request body creation with Java objects | ||||||
| - [ ] Async APIs for images | - [ ] Async APIs for images | ||||||
|  | - [ ] Add custom headers to requests | ||||||
| - [ ] Add additional params for `ask` APIs such as: | - [ ] Add additional params for `ask` APIs such as: | ||||||
|     - `options`: additional model parameters for the Modelfile such as `temperature` |     - `options`: additional model parameters for the Modelfile such as `temperature` | ||||||
|     - `system`: system prompt to (overrides what is defined in the Modelfile) |     - `system`: system prompt to (overrides what is defined in the Modelfile) | ||||||
| @@ -138,3 +144,7 @@ of contribution is much appreciated. | |||||||
|  |  | ||||||
| The nomenclature and the icon have been adopted from the incredible [Ollama](https://ollama.ai/) | The nomenclature and the icon have been adopted from the incredible [Ollama](https://ollama.ai/) | ||||||
| project. | project. | ||||||
|  |  | ||||||
|  | ### References | ||||||
|  |  | ||||||
|  | - [Ollama REST APIs](https://github.com/jmorganca/ollama/blob/main/docs/api.md) | ||||||
							
								
								
									
										24
									
								
								docs/docs/apis-extras/basic-auth.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										24
									
								
								docs/docs/apis-extras/basic-auth.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,24 @@ | |||||||
|  | --- | ||||||
|  | sidebar_position: 2 | ||||||
|  | --- | ||||||
|  |  | ||||||
|  | # Set Basic Authentication | ||||||
|  |  | ||||||
|  | This API lets you set the basic authentication for the Ollama client. This would help in scenarios where | ||||||
|  | Ollama server would be setup behind a gateway/reverse proxy with basic auth. | ||||||
|  |  | ||||||
|  | After configuring basic authentication, all subsequent requests will include the Basic Auth header. | ||||||
|  |  | ||||||
|  | ```java | ||||||
|  | public class Main { | ||||||
|  |  | ||||||
|  |     public static void main(String[] args) { | ||||||
|  |  | ||||||
|  |         String host = "http://localhost:11434/"; | ||||||
|  |  | ||||||
|  |         OllamaAPI ollamaAPI = new OllamaAPI(host); | ||||||
|  |  | ||||||
|  |         ollamaAPI.setBasicAuth("username", "password"); | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | ``` | ||||||
							
								
								
									
										49
									
								
								pom.xml
									
									
									
									
									
								
							
							
						
						
									
										49
									
								
								pom.xml
									
									
									
									
									
								
							| @@ -4,7 +4,7 @@ | |||||||
|  |  | ||||||
|     <groupId>io.github.amithkoujalgi</groupId> |     <groupId>io.github.amithkoujalgi</groupId> | ||||||
|     <artifactId>ollama4j</artifactId> |     <artifactId>ollama4j</artifactId> | ||||||
|     <version>1.0.31</version> |     <version>1.0.35</version> | ||||||
|  |  | ||||||
|   <name>Ollama4j</name> |   <name>Ollama4j</name> | ||||||
|   <description>Java library for interacting with Ollama API.</description> |   <description>Java library for interacting with Ollama API.</description> | ||||||
| @@ -39,7 +39,7 @@ | |||||||
|     <connection>scm:git:git@github.com:amithkoujalgi/ollama4j.git</connection> |     <connection>scm:git:git@github.com:amithkoujalgi/ollama4j.git</connection> | ||||||
|     <developerConnection>scm:git:https://github.com/amithkoujalgi/ollama4j.git</developerConnection> |     <developerConnection>scm:git:https://github.com/amithkoujalgi/ollama4j.git</developerConnection> | ||||||
|     <url>https://github.com/amithkoujalgi/ollama4j</url> |     <url>https://github.com/amithkoujalgi/ollama4j</url> | ||||||
|         <tag>v1.0.31</tag> |     <tag>v1.0.35</tag> | ||||||
|   </scm> |   </scm> | ||||||
|  |  | ||||||
|   <build> |   <build> | ||||||
| @@ -154,7 +154,7 @@ | |||||||
|     <dependency> |     <dependency> | ||||||
|       <groupId>ch.qos.logback</groupId> |       <groupId>ch.qos.logback</groupId> | ||||||
|       <artifactId>logback-classic</artifactId> |       <artifactId>logback-classic</artifactId> | ||||||
|             <version>1.3.11</version> |       <version>1.4.12</version> | ||||||
|       <scope>test</scope> |       <scope>test</scope> | ||||||
|     </dependency> |     </dependency> | ||||||
|     <dependency> |     <dependency> | ||||||
| @@ -198,6 +198,29 @@ | |||||||
|       <activation> |       <activation> | ||||||
|         <activeByDefault>true</activeByDefault> |         <activeByDefault>true</activeByDefault> | ||||||
|       </activation> |       </activation> | ||||||
|  |       <build> | ||||||
|  |         <plugins> | ||||||
|  |           <plugin> | ||||||
|  |             <groupId>org.jacoco</groupId> | ||||||
|  |             <artifactId>jacoco-maven-plugin</artifactId> | ||||||
|  |             <version>0.8.7</version> | ||||||
|  |             <executions> | ||||||
|  |               <execution> | ||||||
|  |                 <goals> | ||||||
|  |                   <goal>prepare-agent</goal> | ||||||
|  |                 </goals> | ||||||
|  |               </execution> | ||||||
|  |               <execution> | ||||||
|  |                 <id>report</id> | ||||||
|  |                 <phase>test</phase> | ||||||
|  |                 <goals> | ||||||
|  |                   <goal>report</goal> | ||||||
|  |                 </goals> | ||||||
|  |               </execution> | ||||||
|  |             </executions> | ||||||
|  |           </plugin> | ||||||
|  |         </plugins> | ||||||
|  |       </build> | ||||||
|     </profile> |     </profile> | ||||||
|     <profile> |     <profile> | ||||||
|       <id>integration-tests</id> |       <id>integration-tests</id> | ||||||
| @@ -249,6 +272,26 @@ | |||||||
|               <autoReleaseAfterClose>true</autoReleaseAfterClose> |               <autoReleaseAfterClose>true</autoReleaseAfterClose> | ||||||
|             </configuration> |             </configuration> | ||||||
|           </plugin> |           </plugin> | ||||||
|  |  | ||||||
|  |           <plugin> | ||||||
|  |             <groupId>org.jacoco</groupId> | ||||||
|  |             <artifactId>jacoco-maven-plugin</artifactId> | ||||||
|  |             <version>0.8.7</version> | ||||||
|  |             <executions> | ||||||
|  |               <execution> | ||||||
|  |                 <goals> | ||||||
|  |                   <goal>prepare-agent</goal> | ||||||
|  |                 </goals> | ||||||
|  |               </execution> | ||||||
|  |               <execution> | ||||||
|  |                 <id>report</id> | ||||||
|  |                 <phase>test</phase> | ||||||
|  |                 <goals> | ||||||
|  |                   <goal>report</goal> | ||||||
|  |                 </goals> | ||||||
|  |               </execution> | ||||||
|  |             </executions> | ||||||
|  |           </plugin> | ||||||
|         </plugins> |         </plugins> | ||||||
|       </build> |       </build> | ||||||
|     </profile> |     </profile> | ||||||
|   | |||||||
| @@ -37,8 +37,7 @@ public class OllamaAPI { | |||||||
|   private final String host; |   private final String host; | ||||||
|   private long requestTimeoutSeconds = 3; |   private long requestTimeoutSeconds = 3; | ||||||
|   private boolean verbose = true; |   private boolean verbose = true; | ||||||
|   private String username; |   private BasicAuth basicAuth; | ||||||
|   private String password; |  | ||||||
|  |  | ||||||
|   /** |   /** | ||||||
|    * Instantiates the Ollama API. |    * Instantiates the Ollama API. | ||||||
| @@ -53,6 +52,11 @@ public class OllamaAPI { | |||||||
|     } |     } | ||||||
|   } |   } | ||||||
|  |  | ||||||
|  |   /** | ||||||
|  |    * Set request timeout in seconds. Default is 3 seconds. | ||||||
|  |    * | ||||||
|  |    * @param requestTimeoutSeconds the request timeout in seconds | ||||||
|  |    */ | ||||||
|   public void setRequestTimeoutSeconds(long requestTimeoutSeconds) { |   public void setRequestTimeoutSeconds(long requestTimeoutSeconds) { | ||||||
|     this.requestTimeoutSeconds = requestTimeoutSeconds; |     this.requestTimeoutSeconds = requestTimeoutSeconds; | ||||||
|   } |   } | ||||||
| @@ -67,11 +71,13 @@ public class OllamaAPI { | |||||||
|   } |   } | ||||||
|  |  | ||||||
|   /** |   /** | ||||||
|  |    * Set basic authentication for accessing Ollama server that's behind a reverse-proxy/gateway. | ||||||
|    * |    * | ||||||
|  |    * @param username the username | ||||||
|  |    * @param password the password | ||||||
|    */ |    */ | ||||||
|   public void setBasicAuth(String username, String password) { |   public void setBasicAuth(String username, String password) { | ||||||
|     this.username = username; |     this.basicAuth = new BasicAuth(username, password); | ||||||
|     this.password = password; |  | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   /** |   /** | ||||||
| @@ -85,11 +91,9 @@ public class OllamaAPI { | |||||||
|     HttpRequest httpRequest = null; |     HttpRequest httpRequest = null; | ||||||
|     try { |     try { | ||||||
|       httpRequest = |       httpRequest = | ||||||
|           HttpRequest.newBuilder() |           getRequestBuilderDefault(new URI(url)) | ||||||
|               .uri(new URI(url)) |  | ||||||
|               .header("Accept", "application/json") |               .header("Accept", "application/json") | ||||||
|               .header("Content-type", "application/json") |               .header("Content-type", "application/json") | ||||||
|               .timeout(Duration.ofSeconds(requestTimeoutSeconds)) |  | ||||||
|               .GET() |               .GET() | ||||||
|               .build(); |               .build(); | ||||||
|     } catch (URISyntaxException e) { |     } catch (URISyntaxException e) { | ||||||
| @@ -117,11 +121,9 @@ public class OllamaAPI { | |||||||
|     String url = this.host + "/api/tags"; |     String url = this.host + "/api/tags"; | ||||||
|     HttpClient httpClient = HttpClient.newHttpClient(); |     HttpClient httpClient = HttpClient.newHttpClient(); | ||||||
|     HttpRequest httpRequest = |     HttpRequest httpRequest = | ||||||
|         HttpRequest.newBuilder() |         getRequestBuilderDefault(new URI(url)) | ||||||
|             .uri(new URI(url)) |  | ||||||
|             .header("Accept", "application/json") |             .header("Accept", "application/json") | ||||||
|             .header("Content-type", "application/json") |             .header("Content-type", "application/json") | ||||||
|             .timeout(Duration.ofSeconds(requestTimeoutSeconds)) |  | ||||||
|             .GET() |             .GET() | ||||||
|             .build(); |             .build(); | ||||||
|     HttpResponse<String> response = |     HttpResponse<String> response = | ||||||
| @@ -148,12 +150,10 @@ public class OllamaAPI { | |||||||
|     String url = this.host + "/api/pull"; |     String url = this.host + "/api/pull"; | ||||||
|     String jsonData = new ModelRequest(modelName).toString(); |     String jsonData = new ModelRequest(modelName).toString(); | ||||||
|     HttpRequest request = |     HttpRequest request = | ||||||
|         HttpRequest.newBuilder() |         getRequestBuilderDefault(new URI(url)) | ||||||
|             .uri(new URI(url)) |  | ||||||
|             .POST(HttpRequest.BodyPublishers.ofString(jsonData)) |             .POST(HttpRequest.BodyPublishers.ofString(jsonData)) | ||||||
|             .header("Accept", "application/json") |             .header("Accept", "application/json") | ||||||
|             .header("Content-type", "application/json") |             .header("Content-type", "application/json") | ||||||
|             .timeout(Duration.ofSeconds(requestTimeoutSeconds)) |  | ||||||
|             .build(); |             .build(); | ||||||
|     HttpClient client = HttpClient.newHttpClient(); |     HttpClient client = HttpClient.newHttpClient(); | ||||||
|     HttpResponse<InputStream> response = |     HttpResponse<InputStream> response = | ||||||
| @@ -184,15 +184,13 @@ public class OllamaAPI { | |||||||
|    * @return the model details |    * @return the model details | ||||||
|    */ |    */ | ||||||
|   public ModelDetail getModelDetails(String modelName) |   public ModelDetail getModelDetails(String modelName) | ||||||
|       throws IOException, OllamaBaseException, InterruptedException { |       throws IOException, OllamaBaseException, InterruptedException, URISyntaxException { | ||||||
|     String url = this.host + "/api/show"; |     String url = this.host + "/api/show"; | ||||||
|     String jsonData = new ModelRequest(modelName).toString(); |     String jsonData = new ModelRequest(modelName).toString(); | ||||||
|     HttpRequest request = |     HttpRequest request = | ||||||
|         HttpRequest.newBuilder() |         getRequestBuilderDefault(new URI(url)) | ||||||
|             .uri(URI.create(url)) |  | ||||||
|             .header("Accept", "application/json") |             .header("Accept", "application/json") | ||||||
|             .header("Content-type", "application/json") |             .header("Content-type", "application/json") | ||||||
|             .timeout(Duration.ofSeconds(requestTimeoutSeconds)) |  | ||||||
|             .POST(HttpRequest.BodyPublishers.ofString(jsonData)) |             .POST(HttpRequest.BodyPublishers.ofString(jsonData)) | ||||||
|             .build(); |             .build(); | ||||||
|     HttpClient client = HttpClient.newHttpClient(); |     HttpClient client = HttpClient.newHttpClient(); | ||||||
| @@ -214,15 +212,13 @@ public class OllamaAPI { | |||||||
|    * @param modelFilePath the path to model file that exists on the Ollama server. |    * @param modelFilePath the path to model file that exists on the Ollama server. | ||||||
|    */ |    */ | ||||||
|   public void createModelWithFilePath(String modelName, String modelFilePath) |   public void createModelWithFilePath(String modelName, String modelFilePath) | ||||||
|       throws IOException, InterruptedException, OllamaBaseException { |       throws IOException, InterruptedException, OllamaBaseException, URISyntaxException { | ||||||
|     String url = this.host + "/api/create"; |     String url = this.host + "/api/create"; | ||||||
|     String jsonData = new CustomModelFilePathRequest(modelName, modelFilePath).toString(); |     String jsonData = new CustomModelFilePathRequest(modelName, modelFilePath).toString(); | ||||||
|     HttpRequest request = |     HttpRequest request = | ||||||
|         HttpRequest.newBuilder() |         getRequestBuilderDefault(new URI(url)) | ||||||
|             .uri(URI.create(url)) |  | ||||||
|             .header("Accept", "application/json") |             .header("Accept", "application/json") | ||||||
|             .header("Content-Type", "application/json") |             .header("Content-Type", "application/json") | ||||||
|             .timeout(Duration.ofSeconds(requestTimeoutSeconds)) |  | ||||||
|             .POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)) |             .POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)) | ||||||
|             .build(); |             .build(); | ||||||
|     HttpClient client = HttpClient.newHttpClient(); |     HttpClient client = HttpClient.newHttpClient(); | ||||||
| @@ -250,15 +246,13 @@ public class OllamaAPI { | |||||||
|    * @param modelFileContents the path to model file that exists on the Ollama server. |    * @param modelFileContents the path to model file that exists on the Ollama server. | ||||||
|    */ |    */ | ||||||
|   public void createModelWithModelFileContents(String modelName, String modelFileContents) |   public void createModelWithModelFileContents(String modelName, String modelFileContents) | ||||||
|       throws IOException, InterruptedException, OllamaBaseException { |       throws IOException, InterruptedException, OllamaBaseException, URISyntaxException { | ||||||
|     String url = this.host + "/api/create"; |     String url = this.host + "/api/create"; | ||||||
|     String jsonData = new CustomModelFileContentsRequest(modelName, modelFileContents).toString(); |     String jsonData = new CustomModelFileContentsRequest(modelName, modelFileContents).toString(); | ||||||
|     HttpRequest request = |     HttpRequest request = | ||||||
|         HttpRequest.newBuilder() |         getRequestBuilderDefault(new URI(url)) | ||||||
|             .uri(URI.create(url)) |  | ||||||
|             .header("Accept", "application/json") |             .header("Accept", "application/json") | ||||||
|             .header("Content-Type", "application/json") |             .header("Content-Type", "application/json") | ||||||
|             .timeout(Duration.ofSeconds(requestTimeoutSeconds)) |  | ||||||
|             .POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)) |             .POST(HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)) | ||||||
|             .build(); |             .build(); | ||||||
|     HttpClient client = HttpClient.newHttpClient(); |     HttpClient client = HttpClient.newHttpClient(); | ||||||
| @@ -280,20 +274,17 @@ public class OllamaAPI { | |||||||
|    * Delete a model from Ollama server. |    * Delete a model from Ollama server. | ||||||
|    * |    * | ||||||
|    * @param modelName the name of the model to be deleted. |    * @param modelName the name of the model to be deleted. | ||||||
|    * @param ignoreIfNotPresent - ignore errors if the specified model is not present on Ollama |    * @param ignoreIfNotPresent ignore errors if the specified model is not present on Ollama server. | ||||||
|    *     server. |  | ||||||
|    */ |    */ | ||||||
|   public void deleteModel(String modelName, boolean ignoreIfNotPresent) |   public void deleteModel(String modelName, boolean ignoreIfNotPresent) | ||||||
|       throws IOException, InterruptedException, OllamaBaseException { |       throws IOException, InterruptedException, OllamaBaseException, URISyntaxException { | ||||||
|     String url = this.host + "/api/delete"; |     String url = this.host + "/api/delete"; | ||||||
|     String jsonData = new ModelRequest(modelName).toString(); |     String jsonData = new ModelRequest(modelName).toString(); | ||||||
|     HttpRequest request = |     HttpRequest request = | ||||||
|         HttpRequest.newBuilder() |         getRequestBuilderDefault(new URI(url)) | ||||||
|             .uri(URI.create(url)) |  | ||||||
|             .method("DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)) |             .method("DELETE", HttpRequest.BodyPublishers.ofString(jsonData, StandardCharsets.UTF_8)) | ||||||
|             .header("Accept", "application/json") |             .header("Accept", "application/json") | ||||||
|             .header("Content-type", "application/json") |             .header("Content-type", "application/json") | ||||||
|             .timeout(Duration.ofSeconds(requestTimeoutSeconds)) |  | ||||||
|             .build(); |             .build(); | ||||||
|     HttpClient client = HttpClient.newHttpClient(); |     HttpClient client = HttpClient.newHttpClient(); | ||||||
|     HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString()); |     HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString()); | ||||||
| @@ -319,7 +310,8 @@ public class OllamaAPI { | |||||||
|     URI uri = URI.create(this.host + "/api/embeddings"); |     URI uri = URI.create(this.host + "/api/embeddings"); | ||||||
|     String jsonData = new ModelEmbeddingsRequest(model, prompt).toString(); |     String jsonData = new ModelEmbeddingsRequest(model, prompt).toString(); | ||||||
|     HttpClient httpClient = HttpClient.newHttpClient(); |     HttpClient httpClient = HttpClient.newHttpClient(); | ||||||
|     HttpRequest.Builder requestBuilder = getRequestBuilderDefault(uri) |     HttpRequest.Builder requestBuilder = | ||||||
|  |         getRequestBuilderDefault(uri) | ||||||
|             .header("Accept", "application/json") |             .header("Accept", "application/json") | ||||||
|             .POST(HttpRequest.BodyPublishers.ofString(jsonData)); |             .POST(HttpRequest.BodyPublishers.ofString(jsonData)); | ||||||
|     HttpRequest request = requestBuilder.build(); |     HttpRequest request = requestBuilder.build(); | ||||||
| @@ -339,12 +331,12 @@ public class OllamaAPI { | |||||||
|    * Ask a question to a model running on Ollama server. This is a sync/blocking call. |    * Ask a question to a model running on Ollama server. This is a sync/blocking call. | ||||||
|    * |    * | ||||||
|    * @param model the ollama model to ask the question to |    * @param model the ollama model to ask the question to | ||||||
|    * @param promptText the prompt/question text |    * @param prompt the prompt/question text | ||||||
|    * @return OllamaResult - that includes response text and time taken for response |    * @return OllamaResult that includes response text and time taken for response | ||||||
|    */ |    */ | ||||||
|   public OllamaResult ask(String model, String promptText) |   public OllamaResult ask(String model, String prompt) | ||||||
|       throws OllamaBaseException, IOException, InterruptedException { |       throws OllamaBaseException, IOException, InterruptedException { | ||||||
|     OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, promptText); |     OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt); | ||||||
|     return askSync(ollamaRequestModel); |     return askSync(ollamaRequestModel); | ||||||
|   } |   } | ||||||
|  |  | ||||||
| @@ -354,15 +346,16 @@ public class OllamaAPI { | |||||||
|    * async/non-blocking call. |    * async/non-blocking call. | ||||||
|    * |    * | ||||||
|    * @param model the ollama model to ask the question to |    * @param model the ollama model to ask the question to | ||||||
|    * @param promptText the prompt/question text |    * @param prompt the prompt/question text | ||||||
|    * @return the ollama async result callback handle |    * @return the ollama async result callback handle | ||||||
|    */ |    */ | ||||||
|   public OllamaAsyncResultCallback askAsync(String model, String promptText) { |   public OllamaAsyncResultCallback askAsync(String model, String prompt) { | ||||||
|     OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, promptText); |     OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt); | ||||||
|     HttpClient httpClient = HttpClient.newHttpClient(); |  | ||||||
|     URI uri = URI.create(this.host + "/api/generate"); |     URI uri = URI.create(this.host + "/api/generate"); | ||||||
|     OllamaAsyncResultCallback ollamaAsyncResultCallback = |     OllamaAsyncResultCallback ollamaAsyncResultCallback = | ||||||
|         new OllamaAsyncResultCallback(httpClient, uri, ollamaRequestModel, requestTimeoutSeconds); |         new OllamaAsyncResultCallback( | ||||||
|  |             getRequestBuilderDefault(uri), ollamaRequestModel, requestTimeoutSeconds); | ||||||
|     ollamaAsyncResultCallback.start(); |     ollamaAsyncResultCallback.start(); | ||||||
|     return ollamaAsyncResultCallback; |     return ollamaAsyncResultCallback; | ||||||
|   } |   } | ||||||
| @@ -372,17 +365,17 @@ public class OllamaAPI { | |||||||
|    * sync/blocking call. |    * sync/blocking call. | ||||||
|    * |    * | ||||||
|    * @param model the ollama model to ask the question to |    * @param model the ollama model to ask the question to | ||||||
|    * @param promptText the prompt/question text |    * @param prompt the prompt/question text | ||||||
|    * @param imageFiles the list of image files to use for the question |    * @param imageFiles the list of image files to use for the question | ||||||
|    * @return OllamaResult - that includes response text and time taken for response |    * @return OllamaResult that includes response text and time taken for response | ||||||
|    */ |    */ | ||||||
|   public OllamaResult askWithImageFiles(String model, String promptText, List<File> imageFiles) |   public OllamaResult askWithImageFiles(String model, String prompt, List<File> imageFiles) | ||||||
|       throws OllamaBaseException, IOException, InterruptedException { |       throws OllamaBaseException, IOException, InterruptedException { | ||||||
|     List<String> images = new ArrayList<>(); |     List<String> images = new ArrayList<>(); | ||||||
|     for (File imageFile : imageFiles) { |     for (File imageFile : imageFiles) { | ||||||
|       images.add(encodeFileToBase64(imageFile)); |       images.add(encodeFileToBase64(imageFile)); | ||||||
|     } |     } | ||||||
|     OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, promptText, images); |     OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt, images); | ||||||
|     return askSync(ollamaRequestModel); |     return askSync(ollamaRequestModel); | ||||||
|   } |   } | ||||||
|  |  | ||||||
| @@ -391,17 +384,17 @@ public class OllamaAPI { | |||||||
|    * sync/blocking call. |    * sync/blocking call. | ||||||
|    * |    * | ||||||
|    * @param model the ollama model to ask the question to |    * @param model the ollama model to ask the question to | ||||||
|    * @param promptText the prompt/question text |    * @param prompt the prompt/question text | ||||||
|    * @param imageURLs the list of image URLs to use for the question |    * @param imageURLs the list of image URLs to use for the question | ||||||
|    * @return OllamaResult - that includes response text and time taken for response |    * @return OllamaResult that includes response text and time taken for response | ||||||
|    */ |    */ | ||||||
|   public OllamaResult askWithImageURLs(String model, String promptText, List<String> imageURLs) |   public OllamaResult askWithImageURLs(String model, String prompt, List<String> imageURLs) | ||||||
|       throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { |       throws OllamaBaseException, IOException, InterruptedException, URISyntaxException { | ||||||
|     List<String> images = new ArrayList<>(); |     List<String> images = new ArrayList<>(); | ||||||
|     for (String imageURL : imageURLs) { |     for (String imageURL : imageURLs) { | ||||||
|       images.add(encodeByteArrayToBase64(loadImageBytesFromUrl(imageURL))); |       images.add(encodeByteArrayToBase64(loadImageBytesFromUrl(imageURL))); | ||||||
|     } |     } | ||||||
|     OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, promptText, images); |     OllamaRequestModel ollamaRequestModel = new OllamaRequestModel(model, prompt, images); | ||||||
|     return askSync(ollamaRequestModel); |     return askSync(ollamaRequestModel); | ||||||
|   } |   } | ||||||
|  |  | ||||||
| @@ -432,7 +425,8 @@ public class OllamaAPI { | |||||||
|     long startTime = System.currentTimeMillis(); |     long startTime = System.currentTimeMillis(); | ||||||
|     HttpClient httpClient = HttpClient.newHttpClient(); |     HttpClient httpClient = HttpClient.newHttpClient(); | ||||||
|     URI uri = URI.create(this.host + "/api/generate"); |     URI uri = URI.create(this.host + "/api/generate"); | ||||||
|     HttpRequest.Builder requestBuilder = getRequestBuilderDefault(uri) |     HttpRequest.Builder requestBuilder = | ||||||
|  |         getRequestBuilderDefault(uri) | ||||||
|             .POST( |             .POST( | ||||||
|                 HttpRequest.BodyPublishers.ofString( |                 HttpRequest.BodyPublishers.ofString( | ||||||
|                     Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))); |                     Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))); | ||||||
| @@ -455,9 +449,10 @@ public class OllamaAPI { | |||||||
|         } else if (statusCode == 401) { |         } else if (statusCode == 401) { | ||||||
|           logger.warn("Status code: 401 (Unauthorized)"); |           logger.warn("Status code: 401 (Unauthorized)"); | ||||||
|           OllamaErrorResponseModel ollamaResponseModel = |           OllamaErrorResponseModel ollamaResponseModel = | ||||||
|                   Utils.getObjectMapper().readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponseModel.class); |               Utils.getObjectMapper() | ||||||
|  |                   .readValue("{\"error\":\"Unauthorized\"}", OllamaErrorResponseModel.class); | ||||||
|           responseBuffer.append(ollamaResponseModel.getError()); |           responseBuffer.append(ollamaResponseModel.getError()); | ||||||
|         }else { |         } else { | ||||||
|           OllamaResponseModel ollamaResponseModel = |           OllamaResponseModel ollamaResponseModel = | ||||||
|               Utils.getObjectMapper().readValue(line, OllamaResponseModel.class); |               Utils.getObjectMapper().readValue(line, OllamaResponseModel.class); | ||||||
|           if (!ollamaResponseModel.isDone()) { |           if (!ollamaResponseModel.isDone()) { | ||||||
| @@ -467,7 +462,7 @@ public class OllamaAPI { | |||||||
|       } |       } | ||||||
|     } |     } | ||||||
|     if (statusCode != 200) { |     if (statusCode != 200) { | ||||||
|       logger.error("Status code " + statusCode + " instead 200"); |       logger.error("Status code " + statusCode); | ||||||
|       throw new OllamaBaseException(responseBuffer.toString()); |       throw new OllamaBaseException(responseBuffer.toString()); | ||||||
|     } else { |     } else { | ||||||
|       long endTime = System.currentTimeMillis(); |       long endTime = System.currentTimeMillis(); | ||||||
| @@ -476,35 +471,38 @@ public class OllamaAPI { | |||||||
|   } |   } | ||||||
|  |  | ||||||
|   /** |   /** | ||||||
|  |    * Get default request builder. | ||||||
|    * |    * | ||||||
|  |    * @param uri URI to get a HttpRequest.Builder | ||||||
|  |    * @return HttpRequest.Builder | ||||||
|    */ |    */ | ||||||
|   private HttpRequest.Builder getRequestBuilderDefault(URI uri) { |   private HttpRequest.Builder getRequestBuilderDefault(URI uri) { | ||||||
|     HttpRequest.Builder requestBuilder = |     HttpRequest.Builder requestBuilder = | ||||||
|         HttpRequest.newBuilder(uri) |         HttpRequest.newBuilder(uri) | ||||||
|             .header("Content-Type", "application/json") |             .header("Content-Type", "application/json") | ||||||
|             .timeout(Duration.ofSeconds(requestTimeoutSeconds)); |             .timeout(Duration.ofSeconds(requestTimeoutSeconds)); | ||||||
|     if (basicAuthCredentialsSet()) { |     if (isBasicAuthCredentialsSet()) { | ||||||
|       requestBuilder.header("Authorization", getBasicAuthHeaderValue()); |       requestBuilder.header("Authorization", getBasicAuthHeaderValue()); | ||||||
|     } |     } | ||||||
|     return requestBuilder; |     return requestBuilder; | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   /** |   /** | ||||||
|  |    * Get basic authentication header value. | ||||||
|  |    * | ||||||
|    * @return basic authentication header value (encoded credentials) |    * @return basic authentication header value (encoded credentials) | ||||||
|    */ |    */ | ||||||
|   private String getBasicAuthHeaderValue() { |   private String getBasicAuthHeaderValue() { | ||||||
|     String credentialsToEncode = username + ":" + password; |     String credentialsToEncode = basicAuth.getUsername() + ":" + basicAuth.getPassword(); | ||||||
|     return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes()); |     return "Basic " + Base64.getEncoder().encodeToString(credentialsToEncode.getBytes()); | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   /** |   /** | ||||||
|  |    * Check if Basic Auth credentials set. | ||||||
|  |    * | ||||||
|    * @return true when Basic Auth credentials set |    * @return true when Basic Auth credentials set | ||||||
|    */ |    */ | ||||||
|   private boolean basicAuthCredentialsSet() { |   private boolean isBasicAuthCredentialsSet() { | ||||||
|     if (username != null && password !=  null) { |     return basicAuth != null; | ||||||
|       return true; |  | ||||||
|     } else { |  | ||||||
|       return false; |  | ||||||
|     } |  | ||||||
|   } |   } | ||||||
| } | } | ||||||
|   | |||||||
| @@ -0,0 +1,13 @@ | |||||||
|  | package io.github.amithkoujalgi.ollama4j.core.models; | ||||||
|  |  | ||||||
|  | import lombok.AllArgsConstructor; | ||||||
|  | import lombok.Data; | ||||||
|  | import lombok.NoArgsConstructor; | ||||||
|  |  | ||||||
|  | @Data | ||||||
|  | @NoArgsConstructor | ||||||
|  | @AllArgsConstructor | ||||||
|  | public class BasicAuth { | ||||||
|  |   private String username; | ||||||
|  |   private String password; | ||||||
|  | } | ||||||
| @@ -6,7 +6,6 @@ import java.io.BufferedReader; | |||||||
| import java.io.IOException; | import java.io.IOException; | ||||||
| import java.io.InputStream; | import java.io.InputStream; | ||||||
| import java.io.InputStreamReader; | import java.io.InputStreamReader; | ||||||
| import java.net.URI; |  | ||||||
| import java.net.http.HttpClient; | import java.net.http.HttpClient; | ||||||
| import java.net.http.HttpRequest; | import java.net.http.HttpRequest; | ||||||
| import java.net.http.HttpResponse; | import java.net.http.HttpResponse; | ||||||
| @@ -14,30 +13,44 @@ import java.nio.charset.StandardCharsets; | |||||||
| import java.time.Duration; | import java.time.Duration; | ||||||
| import java.util.LinkedList; | import java.util.LinkedList; | ||||||
| import java.util.Queue; | import java.util.Queue; | ||||||
|  | import lombok.Data; | ||||||
|  | import lombok.EqualsAndHashCode; | ||||||
|  | import lombok.Getter; | ||||||
|  |  | ||||||
|  | @Data | ||||||
|  | @EqualsAndHashCode(callSuper = true) | ||||||
| @SuppressWarnings("unused") | @SuppressWarnings("unused") | ||||||
| public class OllamaAsyncResultCallback extends Thread { | public class OllamaAsyncResultCallback extends Thread { | ||||||
|   private final HttpClient client; |   private final HttpRequest.Builder requestBuilder; | ||||||
|   private final URI uri; |  | ||||||
|   private final OllamaRequestModel ollamaRequestModel; |   private final OllamaRequestModel ollamaRequestModel; | ||||||
|   private final Queue<String> queue = new LinkedList<>(); |   private final Queue<String> queue = new LinkedList<>(); | ||||||
|   private String result; |   private String result; | ||||||
|   private boolean isDone; |   private boolean isDone; | ||||||
|   private boolean succeeded; |  | ||||||
|  |   /** | ||||||
|  |    * -- GETTER -- Returns the status of the request. Indicates if the request was successful or a | ||||||
|  |    * failure. If the request was a failure, the `getResponse()` method will return the error | ||||||
|  |    * message. | ||||||
|  |    */ | ||||||
|  |   @Getter private boolean succeeded; | ||||||
|  |  | ||||||
|   private long requestTimeoutSeconds; |   private long requestTimeoutSeconds; | ||||||
|  |  | ||||||
|   private int httpStatusCode; |   /** | ||||||
|   private long responseTime = 0; |    * -- GETTER -- Returns the HTTP response status code for the request that was made to Ollama | ||||||
|  |    * server. | ||||||
|  |    */ | ||||||
|  |   @Getter private int httpStatusCode; | ||||||
|  |  | ||||||
|  |   /** -- GETTER -- Returns the response time in milliseconds. */ | ||||||
|  |   @Getter private long responseTime = 0; | ||||||
|  |  | ||||||
|   public OllamaAsyncResultCallback( |   public OllamaAsyncResultCallback( | ||||||
|       HttpClient client, |       HttpRequest.Builder requestBuilder, | ||||||
|       URI uri, |  | ||||||
|       OllamaRequestModel ollamaRequestModel, |       OllamaRequestModel ollamaRequestModel, | ||||||
|       long requestTimeoutSeconds) { |       long requestTimeoutSeconds) { | ||||||
|     this.client = client; |     this.requestBuilder = requestBuilder; | ||||||
|     this.ollamaRequestModel = ollamaRequestModel; |     this.ollamaRequestModel = ollamaRequestModel; | ||||||
|     this.uri = uri; |  | ||||||
|     this.isDone = false; |     this.isDone = false; | ||||||
|     this.result = ""; |     this.result = ""; | ||||||
|     this.queue.add(""); |     this.queue.add(""); | ||||||
| @@ -46,10 +59,11 @@ public class OllamaAsyncResultCallback extends Thread { | |||||||
|  |  | ||||||
|   @Override |   @Override | ||||||
|   public void run() { |   public void run() { | ||||||
|  |     HttpClient httpClient = HttpClient.newHttpClient(); | ||||||
|     try { |     try { | ||||||
|       long startTime = System.currentTimeMillis(); |       long startTime = System.currentTimeMillis(); | ||||||
|       HttpRequest request = |       HttpRequest request = | ||||||
|           HttpRequest.newBuilder(uri) |           requestBuilder | ||||||
|               .POST( |               .POST( | ||||||
|                   HttpRequest.BodyPublishers.ofString( |                   HttpRequest.BodyPublishers.ofString( | ||||||
|                       Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))) |                       Utils.getObjectMapper().writeValueAsString(ollamaRequestModel))) | ||||||
| @@ -57,7 +71,7 @@ public class OllamaAsyncResultCallback extends Thread { | |||||||
|               .timeout(Duration.ofSeconds(requestTimeoutSeconds)) |               .timeout(Duration.ofSeconds(requestTimeoutSeconds)) | ||||||
|               .build(); |               .build(); | ||||||
|       HttpResponse<InputStream> response = |       HttpResponse<InputStream> response = | ||||||
|           client.send(request, HttpResponse.BodyHandlers.ofInputStream()); |           httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream()); | ||||||
|       int statusCode = response.statusCode(); |       int statusCode = response.statusCode(); | ||||||
|       this.httpStatusCode = statusCode; |       this.httpStatusCode = statusCode; | ||||||
|  |  | ||||||
| @@ -108,25 +122,6 @@ public class OllamaAsyncResultCallback extends Thread { | |||||||
|     return isDone; |     return isDone; | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   /** |  | ||||||
|    * Returns the HTTP response status code for the request that was made to Ollama server. |  | ||||||
|    * |  | ||||||
|    * @return int - the status code for the request |  | ||||||
|    */ |  | ||||||
|   public int getHttpStatusCode() { |  | ||||||
|     return httpStatusCode; |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   /** |  | ||||||
|    * Returns the status of the request. Indicates if the request was successful or a failure. If the |  | ||||||
|    * request was a failure, the `getResponse()` method will return the error message. |  | ||||||
|    * |  | ||||||
|    * @return boolean - status |  | ||||||
|    */ |  | ||||||
|   public boolean isSucceeded() { |  | ||||||
|     return succeeded; |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   /** |   /** | ||||||
|    * Returns the final response when the execution completes. Does not return intermediate results. |    * Returns the final response when the execution completes. Does not return intermediate results. | ||||||
|    * |    * | ||||||
| @@ -140,15 +135,6 @@ public class OllamaAsyncResultCallback extends Thread { | |||||||
|     return queue; |     return queue; | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   /** |  | ||||||
|    * Returns the response time in milliseconds. |  | ||||||
|    * |  | ||||||
|    * @return long - response time in milliseconds. |  | ||||||
|    */ |  | ||||||
|   public long getResponseTime() { |  | ||||||
|     return responseTime; |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   public void setRequestTimeoutSeconds(long requestTimeoutSeconds) { |   public void setRequestTimeoutSeconds(long requestTimeoutSeconds) { | ||||||
|     this.requestTimeoutSeconds = requestTimeoutSeconds; |     this.requestTimeoutSeconds = requestTimeoutSeconds; | ||||||
|   } |   } | ||||||
|   | |||||||
| @@ -11,12 +11,13 @@ import io.github.amithkoujalgi.ollama4j.core.types.OllamaModelType; | |||||||
| import java.io.IOException; | import java.io.IOException; | ||||||
| import java.net.URISyntaxException; | import java.net.URISyntaxException; | ||||||
| import java.util.ArrayList; | import java.util.ArrayList; | ||||||
|  | import java.util.Collections; | ||||||
| import org.junit.jupiter.api.Test; | import org.junit.jupiter.api.Test; | ||||||
| import org.mockito.Mockito; | import org.mockito.Mockito; | ||||||
|  |  | ||||||
| class TestMockedAPIs { | class TestMockedAPIs { | ||||||
|   @Test |   @Test | ||||||
|   void testMockPullModel() { |   void testPullModel() { | ||||||
|     OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); |     OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); | ||||||
|     String model = OllamaModelType.LLAMA2; |     String model = OllamaModelType.LLAMA2; | ||||||
|     try { |     try { | ||||||
| @@ -49,7 +50,7 @@ class TestMockedAPIs { | |||||||
|       doNothing().when(ollamaAPI).createModelWithModelFileContents(model, modelFilePath); |       doNothing().when(ollamaAPI).createModelWithModelFileContents(model, modelFilePath); | ||||||
|       ollamaAPI.createModelWithModelFileContents(model, modelFilePath); |       ollamaAPI.createModelWithModelFileContents(model, modelFilePath); | ||||||
|       verify(ollamaAPI, times(1)).createModelWithModelFileContents(model, modelFilePath); |       verify(ollamaAPI, times(1)).createModelWithModelFileContents(model, modelFilePath); | ||||||
|     } catch (IOException | OllamaBaseException | InterruptedException e) { |     } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { | ||||||
|       throw new RuntimeException(e); |       throw new RuntimeException(e); | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| @@ -62,7 +63,7 @@ class TestMockedAPIs { | |||||||
|       doNothing().when(ollamaAPI).deleteModel(model, true); |       doNothing().when(ollamaAPI).deleteModel(model, true); | ||||||
|       ollamaAPI.deleteModel(model, true); |       ollamaAPI.deleteModel(model, true); | ||||||
|       verify(ollamaAPI, times(1)).deleteModel(model, true); |       verify(ollamaAPI, times(1)).deleteModel(model, true); | ||||||
|     } catch (IOException | OllamaBaseException | InterruptedException e) { |     } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { | ||||||
|       throw new RuntimeException(e); |       throw new RuntimeException(e); | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| @@ -75,7 +76,7 @@ class TestMockedAPIs { | |||||||
|       when(ollamaAPI.getModelDetails(model)).thenReturn(new ModelDetail()); |       when(ollamaAPI.getModelDetails(model)).thenReturn(new ModelDetail()); | ||||||
|       ollamaAPI.getModelDetails(model); |       ollamaAPI.getModelDetails(model); | ||||||
|       verify(ollamaAPI, times(1)).getModelDetails(model); |       verify(ollamaAPI, times(1)).getModelDetails(model); | ||||||
|     } catch (IOException | OllamaBaseException | InterruptedException e) { |     } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { | ||||||
|       throw new RuntimeException(e); |       throw new RuntimeException(e); | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| @@ -108,13 +109,43 @@ class TestMockedAPIs { | |||||||
|     } |     } | ||||||
|   } |   } | ||||||
|  |  | ||||||
|  |   @Test | ||||||
|  |   void testAskWithImageFiles() { | ||||||
|  |     OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); | ||||||
|  |     String model = OllamaModelType.LLAMA2; | ||||||
|  |     String prompt = "some prompt text"; | ||||||
|  |     try { | ||||||
|  |       when(ollamaAPI.askWithImageFiles(model, prompt, Collections.emptyList())) | ||||||
|  |           .thenReturn(new OllamaResult("", 0, 200)); | ||||||
|  |       ollamaAPI.askWithImageFiles(model, prompt, Collections.emptyList()); | ||||||
|  |       verify(ollamaAPI, times(1)).askWithImageFiles(model, prompt, Collections.emptyList()); | ||||||
|  |     } catch (IOException | OllamaBaseException | InterruptedException e) { | ||||||
|  |       throw new RuntimeException(e); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   @Test | ||||||
|  |   void testAskWithImageURLs() { | ||||||
|  |     OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); | ||||||
|  |     String model = OllamaModelType.LLAMA2; | ||||||
|  |     String prompt = "some prompt text"; | ||||||
|  |     try { | ||||||
|  |       when(ollamaAPI.askWithImageURLs(model, prompt, Collections.emptyList())) | ||||||
|  |           .thenReturn(new OllamaResult("", 0, 200)); | ||||||
|  |       ollamaAPI.askWithImageURLs(model, prompt, Collections.emptyList()); | ||||||
|  |       verify(ollamaAPI, times(1)).askWithImageURLs(model, prompt, Collections.emptyList()); | ||||||
|  |     } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { | ||||||
|  |       throw new RuntimeException(e); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |  | ||||||
|   @Test |   @Test | ||||||
|   void testAskAsync() { |   void testAskAsync() { | ||||||
|     OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); |     OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); | ||||||
|     String model = OllamaModelType.LLAMA2; |     String model = OllamaModelType.LLAMA2; | ||||||
|     String prompt = "some prompt text"; |     String prompt = "some prompt text"; | ||||||
|     when(ollamaAPI.askAsync(model, prompt)) |     when(ollamaAPI.askAsync(model, prompt)) | ||||||
|         .thenReturn(new OllamaAsyncResultCallback(null, null, null, 3)); |         .thenReturn(new OllamaAsyncResultCallback(null, null, 3)); | ||||||
|     ollamaAPI.askAsync(model, prompt); |     ollamaAPI.askAsync(model, prompt); | ||||||
|     verify(ollamaAPI, times(1)).askAsync(model, prompt); |     verify(ollamaAPI, times(1)).askAsync(model, prompt); | ||||||
|   } |   } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user