mirror of
				https://github.com/amithkoujalgi/ollama4j.git
				synced 2025-11-04 02:20:50 +01:00 
			
		
		
		
	Adds options to EmbeddingsRequest
Additionally refactors the Embedding Models and Tests
This commit is contained in:
		
							
								
								
									
										12
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										12
									
								
								README.md
									
									
									
									
									
								
							@@ -67,7 +67,7 @@ In your Maven project, add this dependency:
 | 
			
		||||
<dependency>
 | 
			
		||||
    <groupId>io.github.amithkoujalgi</groupId>
 | 
			
		||||
    <artifactId>ollama4j</artifactId>
 | 
			
		||||
    <version>1.0.47</version>
 | 
			
		||||
    <version>1.0.57</version>
 | 
			
		||||
</dependency>
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
@@ -125,15 +125,15 @@ Actions CI workflow.
 | 
			
		||||
- [x] Update request body creation with Java objects
 | 
			
		||||
- [ ] Async APIs for images
 | 
			
		||||
- [ ] Add custom headers to requests
 | 
			
		||||
- [ ] Add additional params for `ask` APIs such as:
 | 
			
		||||
- [x] Add additional params for `ask` APIs such as:
 | 
			
		||||
    - [x] `options`: additional model parameters for the Modelfile such as `temperature` -
 | 
			
		||||
      Supported [params](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
 | 
			
		||||
    - [ ] `system`: system prompt to (overrides what is defined in the Modelfile)
 | 
			
		||||
    - [ ] `template`: the full prompt or prompt template (overrides what is defined in the Modelfile)
 | 
			
		||||
    - [ ] `context`: the context parameter returned from a previous request, which can be used to keep a
 | 
			
		||||
    - [x] `system`: system prompt to (overrides what is defined in the Modelfile)
 | 
			
		||||
    - [x] `template`: the full prompt or prompt template (overrides what is defined in the Modelfile)
 | 
			
		||||
    - [x] `context`: the context parameter returned from a previous request, which can be used to keep a
 | 
			
		||||
      short
 | 
			
		||||
      conversational memory
 | 
			
		||||
    - [ ] `stream`: Add support for streaming responses from the model
 | 
			
		||||
    - [x] `stream`: Add support for streaming responses from the model
 | 
			
		||||
- [ ] Add test cases
 | 
			
		||||
- [ ] Handle exceptions better (maybe throw more appropriate exceptions)
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										2
									
								
								pom.xml
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								pom.xml
									
									
									
									
									
								
							@@ -99,7 +99,7 @@
 | 
			
		||||
                <configuration>
 | 
			
		||||
                    <skipTests>${skipUnitTests}</skipTests>
 | 
			
		||||
                    <includes>
 | 
			
		||||
                        <include>**/unittests/*.java</include>
 | 
			
		||||
                        <include>**/unittests/**/*.java</include>
 | 
			
		||||
                    </includes>
 | 
			
		||||
                </configuration>
 | 
			
		||||
            </plugin>
 | 
			
		||||
 
 | 
			
		||||
@@ -6,10 +6,11 @@ import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessage;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResult;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingResponseModel;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestModel;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFileContentsRequest;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.models.request.CustomModelFilePathRequest;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.models.request.ModelEmbeddingsRequest;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.models.request.ModelRequest;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.models.request.OllamaChatEndpointCaller;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.models.request.OllamaGenerateEndpointCaller;
 | 
			
		||||
@@ -313,8 +314,18 @@ public class OllamaAPI {
 | 
			
		||||
   */
 | 
			
		||||
  public List<Double> generateEmbeddings(String model, String prompt)
 | 
			
		||||
      throws IOException, InterruptedException, OllamaBaseException {
 | 
			
		||||
        return generateEmbeddings(new OllamaEmbeddingsRequestModel(model, prompt));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
   * Generate embeddings using a {@link OllamaEmbeddingsRequestModel}.
 | 
			
		||||
   *
 | 
			
		||||
   * @param modelRequest request for '/api/embeddings' endpoint
 | 
			
		||||
   * @return embeddings
 | 
			
		||||
   */
 | 
			
		||||
  public List<Double> generateEmbeddings(OllamaEmbeddingsRequestModel modelRequest) throws IOException, InterruptedException, OllamaBaseException{
 | 
			
		||||
    URI uri = URI.create(this.host + "/api/embeddings");
 | 
			
		||||
    String jsonData = new ModelEmbeddingsRequest(model, prompt).toString();
 | 
			
		||||
    String jsonData = modelRequest.toString();
 | 
			
		||||
    HttpClient httpClient = HttpClient.newHttpClient();
 | 
			
		||||
    HttpRequest.Builder requestBuilder =
 | 
			
		||||
        getRequestBuilderDefault(uri)
 | 
			
		||||
@@ -325,8 +336,8 @@ public class OllamaAPI {
 | 
			
		||||
    int statusCode = response.statusCode();
 | 
			
		||||
    String responseBody = response.body();
 | 
			
		||||
    if (statusCode == 200) {
 | 
			
		||||
      EmbeddingResponse embeddingResponse =
 | 
			
		||||
          Utils.getObjectMapper().readValue(responseBody, EmbeddingResponse.class);
 | 
			
		||||
      OllamaEmbeddingResponseModel embeddingResponse =
 | 
			
		||||
          Utils.getObjectMapper().readValue(responseBody, OllamaEmbeddingResponseModel.class);
 | 
			
		||||
      return embeddingResponse.getEmbedding();
 | 
			
		||||
    } else {
 | 
			
		||||
      throw new OllamaBaseException(statusCode + " - " + responseBody);
 | 
			
		||||
 
 | 
			
		||||
@@ -1,4 +1,4 @@
 | 
			
		||||
package io.github.amithkoujalgi.ollama4j.core.models;
 | 
			
		||||
package io.github.amithkoujalgi.ollama4j.core.models.embeddings;
 | 
			
		||||
 | 
			
		||||
import com.fasterxml.jackson.annotation.JsonProperty;
 | 
			
		||||
 | 
			
		||||
@@ -7,7 +7,7 @@ import lombok.Data;
 | 
			
		||||
 | 
			
		||||
@SuppressWarnings("unused")
 | 
			
		||||
@Data
 | 
			
		||||
public class EmbeddingResponse {
 | 
			
		||||
public class OllamaEmbeddingResponseModel {
 | 
			
		||||
    @JsonProperty("embedding")
 | 
			
		||||
    private List<Double> embedding;
 | 
			
		||||
}
 | 
			
		||||
@@ -0,0 +1,31 @@
 | 
			
		||||
package io.github.amithkoujalgi.ollama4j.core.models.embeddings;
 | 
			
		||||
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.utils.Options;
 | 
			
		||||
 | 
			
		||||
public class OllamaEmbeddingsRequestBuilder {
 | 
			
		||||
 | 
			
		||||
    private OllamaEmbeddingsRequestBuilder(String model, String prompt){
 | 
			
		||||
        request = new OllamaEmbeddingsRequestModel(model, prompt);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private OllamaEmbeddingsRequestModel request;
 | 
			
		||||
 | 
			
		||||
    public static OllamaEmbeddingsRequestBuilder getInstance(String model, String prompt){
 | 
			
		||||
        return new OllamaEmbeddingsRequestBuilder(model, prompt);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public OllamaEmbeddingsRequestModel build(){
 | 
			
		||||
        return request;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public OllamaEmbeddingsRequestBuilder withOptions(Options options){
 | 
			
		||||
        this.request.setOptions(options.getOptionsMap());
 | 
			
		||||
        return this;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public OllamaEmbeddingsRequestBuilder withKeepAlive(String keepAlive){
 | 
			
		||||
        this.request.setKeepAlive(keepAlive);
 | 
			
		||||
        return this;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
@@ -0,0 +1,33 @@
 | 
			
		||||
package io.github.amithkoujalgi.ollama4j.core.models.embeddings;
 | 
			
		||||
 | 
			
		||||
import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper;
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
import com.fasterxml.jackson.annotation.JsonProperty;
 | 
			
		||||
import com.fasterxml.jackson.core.JsonProcessingException;
 | 
			
		||||
import lombok.Data;
 | 
			
		||||
import lombok.NoArgsConstructor;
 | 
			
		||||
import lombok.NonNull;
 | 
			
		||||
import lombok.RequiredArgsConstructor;
 | 
			
		||||
 | 
			
		||||
@Data
 | 
			
		||||
@RequiredArgsConstructor
 | 
			
		||||
@NoArgsConstructor
 | 
			
		||||
public class OllamaEmbeddingsRequestModel {
 | 
			
		||||
  @NonNull
 | 
			
		||||
  private String model;
 | 
			
		||||
  @NonNull
 | 
			
		||||
  private String prompt;
 | 
			
		||||
 | 
			
		||||
  protected Map<String, Object> options;
 | 
			
		||||
  @JsonProperty(value = "keep_alive")
 | 
			
		||||
  private String keepAlive;
 | 
			
		||||
 | 
			
		||||
  @Override
 | 
			
		||||
  public String toString() {
 | 
			
		||||
    try {
 | 
			
		||||
      return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
 | 
			
		||||
    } catch (JsonProcessingException e) {
 | 
			
		||||
      throw new RuntimeException(e);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
@@ -1,23 +0,0 @@
 | 
			
		||||
package io.github.amithkoujalgi.ollama4j.core.models.request;
 | 
			
		||||
 | 
			
		||||
import static io.github.amithkoujalgi.ollama4j.core.utils.Utils.getObjectMapper;
 | 
			
		||||
 | 
			
		||||
import com.fasterxml.jackson.core.JsonProcessingException;
 | 
			
		||||
import lombok.AllArgsConstructor;
 | 
			
		||||
import lombok.Data;
 | 
			
		||||
 | 
			
		||||
@Data
 | 
			
		||||
@AllArgsConstructor
 | 
			
		||||
public class ModelEmbeddingsRequest {
 | 
			
		||||
  private String model;
 | 
			
		||||
  private String prompt;
 | 
			
		||||
 | 
			
		||||
  @Override
 | 
			
		||||
  public String toString() {
 | 
			
		||||
    try {
 | 
			
		||||
      return getObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this);
 | 
			
		||||
    } catch (JsonProcessingException e) {
 | 
			
		||||
      throw new RuntimeException(e);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
@@ -10,6 +10,8 @@ import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessageRole;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatResult;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestModel;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestBuilder;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
 | 
			
		||||
import java.io.File;
 | 
			
		||||
import java.io.IOException;
 | 
			
		||||
@@ -61,7 +63,7 @@ class TestRealAPIs {
 | 
			
		||||
    } catch (HttpConnectTimeoutException e) {
 | 
			
		||||
      fail(e.getMessage());
 | 
			
		||||
    } catch (Exception e) {
 | 
			
		||||
      throw new RuntimeException(e);
 | 
			
		||||
      fail(e);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@@ -73,7 +75,7 @@ class TestRealAPIs {
 | 
			
		||||
      assertNotNull(ollamaAPI.listModels());
 | 
			
		||||
      ollamaAPI.listModels().forEach(System.out::println);
 | 
			
		||||
    } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
 | 
			
		||||
      throw new RuntimeException(e);
 | 
			
		||||
      fail(e);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@@ -88,7 +90,7 @@ class TestRealAPIs {
 | 
			
		||||
              .anyMatch(model -> model.getModel().equalsIgnoreCase(config.getModel()));
 | 
			
		||||
      assertTrue(found);
 | 
			
		||||
    } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
 | 
			
		||||
      throw new RuntimeException(e);
 | 
			
		||||
      fail(e);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@@ -101,7 +103,7 @@ class TestRealAPIs {
 | 
			
		||||
      assertNotNull(modelDetails);
 | 
			
		||||
      System.out.println(modelDetails);
 | 
			
		||||
    } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
 | 
			
		||||
      throw new RuntimeException(e);
 | 
			
		||||
      fail(e);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@@ -119,7 +121,7 @@ class TestRealAPIs {
 | 
			
		||||
      assertNotNull(result.getResponse());
 | 
			
		||||
      assertFalse(result.getResponse().isEmpty());
 | 
			
		||||
    } catch (IOException | OllamaBaseException | InterruptedException e) {
 | 
			
		||||
      throw new RuntimeException(e);
 | 
			
		||||
      fail(e);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@@ -145,7 +147,7 @@ class TestRealAPIs {
 | 
			
		||||
      assertFalse(result.getResponse().isEmpty());
 | 
			
		||||
      assertEquals(sb.toString().trim(), result.getResponse().trim());
 | 
			
		||||
    } catch (IOException | OllamaBaseException | InterruptedException e) {
 | 
			
		||||
      throw new RuntimeException(e);
 | 
			
		||||
      fail(e);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@@ -163,7 +165,7 @@ class TestRealAPIs {
 | 
			
		||||
      assertNotNull(result.getResponse());
 | 
			
		||||
      assertFalse(result.getResponse().isEmpty());
 | 
			
		||||
    } catch (IOException | OllamaBaseException | InterruptedException e) {
 | 
			
		||||
      throw new RuntimeException(e);
 | 
			
		||||
      fail(e);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@@ -183,7 +185,7 @@ class TestRealAPIs {
 | 
			
		||||
      assertFalse(chatResult.getResponse().isBlank());
 | 
			
		||||
      assertEquals(4,chatResult.getChatHistory().size());
 | 
			
		||||
    } catch (IOException | OllamaBaseException | InterruptedException e) {
 | 
			
		||||
      throw new RuntimeException(e);
 | 
			
		||||
      fail(e);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@@ -205,7 +207,7 @@ class TestRealAPIs {
 | 
			
		||||
      assertTrue(chatResult.getResponse().startsWith("NI"));
 | 
			
		||||
      assertEquals(3, chatResult.getChatHistory().size());
 | 
			
		||||
    } catch (IOException | OllamaBaseException | InterruptedException e) {
 | 
			
		||||
      throw new RuntimeException(e);
 | 
			
		||||
      fail(e);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@@ -230,7 +232,7 @@ class TestRealAPIs {
 | 
			
		||||
      assertNotNull(chatResult);
 | 
			
		||||
      assertEquals(sb.toString().trim(), chatResult.getResponse().trim());
 | 
			
		||||
    } catch (IOException | OllamaBaseException | InterruptedException e) {
 | 
			
		||||
      throw new RuntimeException(e);
 | 
			
		||||
      fail(e);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@@ -261,7 +263,7 @@ class TestRealAPIs {
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    } catch (IOException | OllamaBaseException | InterruptedException e) {
 | 
			
		||||
      throw new RuntimeException(e);
 | 
			
		||||
      fail(e);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@@ -278,7 +280,7 @@ class TestRealAPIs {
 | 
			
		||||
      OllamaChatResult chatResult = ollamaAPI.chat(requestModel);
 | 
			
		||||
      assertNotNull(chatResult);
 | 
			
		||||
    } catch (IOException | OllamaBaseException | InterruptedException e) {
 | 
			
		||||
      throw new RuntimeException(e);
 | 
			
		||||
      fail(e);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@@ -298,7 +300,7 @@ class TestRealAPIs {
 | 
			
		||||
      assertNotNull(result.getResponse());
 | 
			
		||||
      assertFalse(result.getResponse().isEmpty());
 | 
			
		||||
    } catch (IOException | OllamaBaseException | InterruptedException e) {
 | 
			
		||||
      throw new RuntimeException(e);
 | 
			
		||||
      fail(e);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@@ -322,7 +324,7 @@ class TestRealAPIs {
 | 
			
		||||
      assertFalse(result.getResponse().isEmpty());
 | 
			
		||||
      assertEquals(sb.toString().trim(), result.getResponse().trim());
 | 
			
		||||
    } catch (IOException | OllamaBaseException | InterruptedException e) {
 | 
			
		||||
      throw new RuntimeException(e);
 | 
			
		||||
      fail(e);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@@ -342,7 +344,24 @@ class TestRealAPIs {
 | 
			
		||||
      assertNotNull(result.getResponse());
 | 
			
		||||
      assertFalse(result.getResponse().isEmpty());
 | 
			
		||||
    } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
 | 
			
		||||
      throw new RuntimeException(e);
 | 
			
		||||
      fail(e);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  @Test
 | 
			
		||||
  @Order(3)
 | 
			
		||||
  public void testEmbedding() {
 | 
			
		||||
    testEndpointReachability();
 | 
			
		||||
    try {
 | 
			
		||||
      OllamaEmbeddingsRequestModel request = OllamaEmbeddingsRequestBuilder
 | 
			
		||||
          .getInstance(config.getModel(), "What is the capital of France?").build();
 | 
			
		||||
 | 
			
		||||
      List<Double> embeddings = ollamaAPI.generateEmbeddings(request);
 | 
			
		||||
 | 
			
		||||
      assertNotNull(embeddings);
 | 
			
		||||
      assertFalse(embeddings.isEmpty());
 | 
			
		||||
    } catch (IOException | OllamaBaseException | InterruptedException e) {
 | 
			
		||||
      fail(e);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -0,0 +1,35 @@
 | 
			
		||||
package io.github.amithkoujalgi.ollama4j.unittests.jackson;
 | 
			
		||||
 | 
			
		||||
import static org.junit.jupiter.api.Assertions.assertEquals;
 | 
			
		||||
import static org.junit.jupiter.api.Assertions.fail;
 | 
			
		||||
import com.fasterxml.jackson.core.JsonProcessingException;
 | 
			
		||||
import com.fasterxml.jackson.databind.ObjectMapper;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
 | 
			
		||||
 | 
			
		||||
public abstract class AbstractRequestSerializationTest<T> {
 | 
			
		||||
 | 
			
		||||
    protected ObjectMapper mapper = Utils.getObjectMapper();
 | 
			
		||||
 | 
			
		||||
    protected String serializeRequest(T req) {
 | 
			
		||||
        try {
 | 
			
		||||
            return mapper.writeValueAsString(req);
 | 
			
		||||
        } catch (JsonProcessingException e) {
 | 
			
		||||
            fail("Could not serialize request!", e);
 | 
			
		||||
            return null;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    protected T deserializeRequest(String jsonRequest, Class<T> requestClass) {
 | 
			
		||||
        try {
 | 
			
		||||
            return mapper.readValue(jsonRequest, requestClass);
 | 
			
		||||
        } catch (JsonProcessingException e) {
 | 
			
		||||
            fail("Could not deserialize jsonRequest!", e);
 | 
			
		||||
            return null;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    protected void assertEqualsAfterUnmarshalling(T unmarshalledRequest,
 | 
			
		||||
        T req) {
 | 
			
		||||
        assertEquals(req, unmarshalledRequest);
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@@ -1,7 +1,6 @@
 | 
			
		||||
package io.github.amithkoujalgi.ollama4j.unittests.jackson;
 | 
			
		||||
 | 
			
		||||
import static org.junit.jupiter.api.Assertions.assertEquals;
 | 
			
		||||
import static org.junit.jupiter.api.Assertions.fail;
 | 
			
		||||
 | 
			
		||||
import java.io.File;
 | 
			
		||||
import java.util.List;
 | 
			
		||||
@@ -10,21 +9,15 @@ import org.json.JSONObject;
 | 
			
		||||
import org.junit.jupiter.api.BeforeEach;
 | 
			
		||||
import org.junit.jupiter.api.Test;
 | 
			
		||||
 | 
			
		||||
import com.fasterxml.jackson.core.JsonProcessingException;
 | 
			
		||||
import com.fasterxml.jackson.databind.ObjectMapper;
 | 
			
		||||
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatMessageRole;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestBuilder;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.models.chat.OllamaChatRequestModel;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
 | 
			
		||||
 | 
			
		||||
public class TestChatRequestSerialization {
 | 
			
		||||
public class TestChatRequestSerialization extends AbstractRequestSerializationTest<OllamaChatRequestModel>{
 | 
			
		||||
 | 
			
		||||
    private OllamaChatRequestBuilder builder;
 | 
			
		||||
 | 
			
		||||
    private ObjectMapper mapper = Utils.getObjectMapper();
 | 
			
		||||
 | 
			
		||||
    @BeforeEach
 | 
			
		||||
    public void init() {
 | 
			
		||||
        builder = OllamaChatRequestBuilder.getInstance("DummyModel");
 | 
			
		||||
@@ -32,10 +25,9 @@ public class TestChatRequestSerialization {
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void testRequestOnlyMandatoryFields() {
 | 
			
		||||
        OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt",
 | 
			
		||||
                List.of(new File("src/test/resources/dog-on-a-boat.jpg"))).build();
 | 
			
		||||
        OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt").build();
 | 
			
		||||
        String jsonRequest = serializeRequest(req);
 | 
			
		||||
        assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req);
 | 
			
		||||
        assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest,OllamaChatRequestModel.class), req);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
@@ -44,7 +36,7 @@ public class TestChatRequestSerialization {
 | 
			
		||||
        .withMessage(OllamaChatMessageRole.USER, "Some prompt")
 | 
			
		||||
        .build();
 | 
			
		||||
        String jsonRequest = serializeRequest(req);
 | 
			
		||||
        assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req);
 | 
			
		||||
        assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest,OllamaChatRequestModel.class), req);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
@@ -52,7 +44,7 @@ public class TestChatRequestSerialization {
 | 
			
		||||
        OllamaChatRequestModel req = builder.withMessage(OllamaChatMessageRole.USER, "Some prompt",
 | 
			
		||||
                List.of(new File("src/test/resources/dog-on-a-boat.jpg"))).build();
 | 
			
		||||
        String jsonRequest = serializeRequest(req);
 | 
			
		||||
        assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req);
 | 
			
		||||
        assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest,OllamaChatRequestModel.class), req);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
@@ -62,7 +54,7 @@ public class TestChatRequestSerialization {
 | 
			
		||||
                .withOptions(b.setMirostat(1).build()).build();
 | 
			
		||||
 | 
			
		||||
        String jsonRequest = serializeRequest(req);
 | 
			
		||||
        OllamaChatRequestModel deserializeRequest = deserializeRequest(jsonRequest);
 | 
			
		||||
        OllamaChatRequestModel deserializeRequest = deserializeRequest(jsonRequest,OllamaChatRequestModel.class);
 | 
			
		||||
        assertEqualsAfterUnmarshalling(deserializeRequest, req);
 | 
			
		||||
        assertEquals(1, deserializeRequest.getOptions().get("mirostat"));
 | 
			
		||||
    }
 | 
			
		||||
@@ -79,28 +71,4 @@ public class TestChatRequestSerialization {
 | 
			
		||||
        String requestFormatProperty = jsonObject.getString("format");
 | 
			
		||||
        assertEquals("json", requestFormatProperty);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private String serializeRequest(OllamaChatRequestModel req) {
 | 
			
		||||
        try {
 | 
			
		||||
            return mapper.writeValueAsString(req);
 | 
			
		||||
        } catch (JsonProcessingException e) {
 | 
			
		||||
            fail("Could not serialize request!", e);
 | 
			
		||||
            return null;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private OllamaChatRequestModel deserializeRequest(String jsonRequest) {
 | 
			
		||||
        try {
 | 
			
		||||
            return mapper.readValue(jsonRequest, OllamaChatRequestModel.class);
 | 
			
		||||
        } catch (JsonProcessingException e) {
 | 
			
		||||
            fail("Could not deserialize jsonRequest!", e);
 | 
			
		||||
            return null;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private void assertEqualsAfterUnmarshalling(OllamaChatRequestModel unmarshalledRequest,
 | 
			
		||||
            OllamaChatRequestModel req) {
 | 
			
		||||
        assertEquals(req, unmarshalledRequest);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -0,0 +1,37 @@
 | 
			
		||||
package io.github.amithkoujalgi.ollama4j.unittests.jackson;
 | 
			
		||||
 | 
			
		||||
import static org.junit.jupiter.api.Assertions.assertEquals;
 | 
			
		||||
import org.junit.jupiter.api.BeforeEach;
 | 
			
		||||
import org.junit.jupiter.api.Test;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestModel;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.models.embeddings.OllamaEmbeddingsRequestBuilder;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
 | 
			
		||||
 | 
			
		||||
public class TestEmbeddingsRequestSerialization extends AbstractRequestSerializationTest<OllamaEmbeddingsRequestModel>{
 | 
			
		||||
 | 
			
		||||
        private OllamaEmbeddingsRequestBuilder builder;
 | 
			
		||||
 | 
			
		||||
        @BeforeEach
 | 
			
		||||
        public void init() {
 | 
			
		||||
            builder = OllamaEmbeddingsRequestBuilder.getInstance("DummyModel","DummyPrompt");
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
            @Test
 | 
			
		||||
    public void testRequestOnlyMandatoryFields() {
 | 
			
		||||
        OllamaEmbeddingsRequestModel req = builder.build();
 | 
			
		||||
        String jsonRequest = serializeRequest(req);
 | 
			
		||||
        assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest,OllamaEmbeddingsRequestModel.class), req);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
        @Test
 | 
			
		||||
        public void testRequestWithOptions() {
 | 
			
		||||
            OptionsBuilder b = new OptionsBuilder();
 | 
			
		||||
            OllamaEmbeddingsRequestModel req = builder
 | 
			
		||||
                    .withOptions(b.setMirostat(1).build()).build();
 | 
			
		||||
 | 
			
		||||
            String jsonRequest = serializeRequest(req);
 | 
			
		||||
            OllamaEmbeddingsRequestModel deserializeRequest = deserializeRequest(jsonRequest,OllamaEmbeddingsRequestModel.class);
 | 
			
		||||
            assertEqualsAfterUnmarshalling(deserializeRequest, req);
 | 
			
		||||
            assertEquals(1, deserializeRequest.getOptions().get("mirostat"));
 | 
			
		||||
        }
 | 
			
		||||
}
 | 
			
		||||
@@ -1,26 +1,20 @@
 | 
			
		||||
package io.github.amithkoujalgi.ollama4j.unittests.jackson;
 | 
			
		||||
 | 
			
		||||
import static org.junit.jupiter.api.Assertions.assertEquals;
 | 
			
		||||
import static org.junit.jupiter.api.Assertions.fail;
 | 
			
		||||
 | 
			
		||||
import org.json.JSONObject;
 | 
			
		||||
import org.junit.jupiter.api.BeforeEach;
 | 
			
		||||
import org.junit.jupiter.api.Test;
 | 
			
		||||
 | 
			
		||||
import com.fasterxml.jackson.core.JsonProcessingException;
 | 
			
		||||
import com.fasterxml.jackson.databind.ObjectMapper;
 | 
			
		||||
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestBuilder;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.models.generate.OllamaGenerateRequestModel;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.utils.OptionsBuilder;
 | 
			
		||||
import io.github.amithkoujalgi.ollama4j.core.utils.Utils;
 | 
			
		||||
 | 
			
		||||
public class TestGenerateRequestSerialization {
 | 
			
		||||
public class TestGenerateRequestSerialization extends AbstractRequestSerializationTest<OllamaGenerateRequestModel>{
 | 
			
		||||
 | 
			
		||||
    private OllamaGenerateRequestBuilder builder;
 | 
			
		||||
 | 
			
		||||
    private ObjectMapper mapper = Utils.getObjectMapper();
 | 
			
		||||
 | 
			
		||||
    @BeforeEach
 | 
			
		||||
    public void init() {
 | 
			
		||||
        builder = OllamaGenerateRequestBuilder.getInstance("DummyModel");
 | 
			
		||||
@@ -31,7 +25,7 @@ public class TestGenerateRequestSerialization {
 | 
			
		||||
        OllamaGenerateRequestModel req = builder.withPrompt("Some prompt").build();
 | 
			
		||||
 | 
			
		||||
        String jsonRequest = serializeRequest(req);
 | 
			
		||||
        assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest), req);
 | 
			
		||||
        assertEqualsAfterUnmarshalling(deserializeRequest(jsonRequest, OllamaGenerateRequestModel.class), req);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
@@ -41,7 +35,7 @@ public class TestGenerateRequestSerialization {
 | 
			
		||||
                builder.withPrompt("Some prompt").withOptions(b.setMirostat(1).build()).build();
 | 
			
		||||
 | 
			
		||||
        String jsonRequest = serializeRequest(req);
 | 
			
		||||
        OllamaGenerateRequestModel deserializeRequest = deserializeRequest(jsonRequest);
 | 
			
		||||
        OllamaGenerateRequestModel deserializeRequest = deserializeRequest(jsonRequest, OllamaGenerateRequestModel.class);
 | 
			
		||||
        assertEqualsAfterUnmarshalling(deserializeRequest, req);
 | 
			
		||||
        assertEquals(1, deserializeRequest.getOptions().get("mirostat"));
 | 
			
		||||
    }
 | 
			
		||||
@@ -59,27 +53,4 @@ public class TestGenerateRequestSerialization {
 | 
			
		||||
        assertEquals("json", requestFormatProperty);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private String serializeRequest(OllamaGenerateRequestModel req) {
 | 
			
		||||
        try {
 | 
			
		||||
            return mapper.writeValueAsString(req);
 | 
			
		||||
        } catch (JsonProcessingException e) {
 | 
			
		||||
            fail("Could not serialize request!", e);
 | 
			
		||||
            return null;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private OllamaGenerateRequestModel deserializeRequest(String jsonRequest) {
 | 
			
		||||
        try {
 | 
			
		||||
            return mapper.readValue(jsonRequest, OllamaGenerateRequestModel.class);
 | 
			
		||||
        } catch (JsonProcessingException e) {
 | 
			
		||||
            fail("Could not deserialize jsonRequest!", e);
 | 
			
		||||
            return null;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private void assertEqualsAfterUnmarshalling(OllamaGenerateRequestModel unmarshalledRequest,
 | 
			
		||||
            OllamaGenerateRequestModel req) {
 | 
			
		||||
        assertEquals(req, unmarshalledRequest);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user