Fixes to listModels API

This commit is contained in:
Amith Koujalgi 2023-12-14 17:55:17 +05:30
parent b2d76970dc
commit 1df8622b32
3 changed files with 134 additions and 146 deletions

View File

@ -3,14 +3,19 @@ package io.github.amithkoujalgi.ollama4j.core.models;
import com.fasterxml.jackson.annotation.JsonProperty;
public class Model {
private String name;
@JsonProperty("modified_at")
private String modifiedAt;
private String digest;
private long size;
@JsonProperty("details")
private ModelMeta modelMeta;
/**
* Returns the model's tag. This includes model name and its version separated by a colon character `:`
* Returns the model's tag. This includes model name and its version separated by a colon
* character `:`
*
* @return model tag
*/
public String getName() {
@ -23,6 +28,7 @@ public class Model {
/**
* Returns the model name without its version
*
* @return model name
*/
public String getModelName() {
@ -31,6 +37,7 @@ public class Model {
/**
* Returns the model version without its name
*
* @return model version
*/
public String getModelVersion() {
@ -61,4 +68,7 @@ public class Model {
this.size = size;
}
public ModelMeta getModelMeta() {
return modelMeta;
}
}

View File

@ -0,0 +1,61 @@
package io.github.amithkoujalgi.ollama4j.core.models;
import com.fasterxml.jackson.annotation.JsonProperty;
public class ModelMeta {
@JsonProperty("format")
private String format;
@JsonProperty("family")
private String family;
@JsonProperty("families")
private String[] families;
@JsonProperty("parameter_size")
private String parameterSize;
@JsonProperty("quantization_level")
private String quantizationLevel;
public String getFormat() {
return format;
}
public void setFormat(String format) {
this.format = format;
}
public String getFamily() {
return family;
}
public void setFamily(String family) {
this.family = family;
}
public String[] getFamilies() {
return families;
}
public void setFamilies(String[] families) {
this.families = families;
}
public String getParameterSize() {
return parameterSize;
}
public void setParameterSize(String parameterSize) {
this.parameterSize = parameterSize;
}
public String getQuantizationLevel() {
return quantizationLevel;
}
public void setQuantizationLevel(String quantizationLevel) {
this.quantizationLevel = quantizationLevel;
}
}

View File

@ -6,6 +6,7 @@ import io.github.amithkoujalgi.ollama4j.core.models.ModelDetail;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaAsyncResultCallback;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
import io.github.amithkoujalgi.ollama4j.core.types.OllamaModelType;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
@ -13,109 +14,25 @@ import java.io.IOException;
import java.net.URISyntaxException;
import java.util.ArrayList;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.mockito.Mockito.*;
public class TestRealAPIs {
class TestRealAPIs {
OllamaAPI ollamaAPI;
@BeforeEach
void setUp() {
String ollamaHost = "http://localhost:11434";
ollamaAPI = new OllamaAPI(ollamaHost);
}
@Test
public void testMockPullModel() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
void testListModels() {
try {
doNothing().when(ollamaAPI).pullModel(model);
ollamaAPI.pullModel(model);
verify(ollamaAPI, times(1)).pullModel(model);
assertNotNull(ollamaAPI.listModels());
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e);
}
}
@Test
public void testListModels() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
try {
when(ollamaAPI.listModels()).thenReturn(new ArrayList<>());
ollamaAPI.listModels();
verify(ollamaAPI, times(1)).listModels();
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e);
}
}
@Test
public void testCreateModel() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
String modelFilePath = "/somemodel";
try {
doNothing().when(ollamaAPI).createModel(model, modelFilePath);
ollamaAPI.createModel(model, modelFilePath);
verify(ollamaAPI, times(1)).createModel(model, modelFilePath);
} catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e);
}
}
@Test
public void testDeleteModel() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
try {
doNothing().when(ollamaAPI).deleteModel(model, true);
ollamaAPI.deleteModel(model, true);
verify(ollamaAPI, times(1)).deleteModel(model, true);
} catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e);
}
}
@Test
public void testGetModelDetails() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
try {
when(ollamaAPI.getModelDetails(model)).thenReturn(new ModelDetail());
ollamaAPI.getModelDetails(model);
verify(ollamaAPI, times(1)).getModelDetails(model);
} catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e);
}
}
@Test
public void testGenerateEmbeddings() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text";
try {
when(ollamaAPI.generateEmbeddings(model, prompt)).thenReturn(new ArrayList<>());
ollamaAPI.generateEmbeddings(model, prompt);
verify(ollamaAPI, times(1)).generateEmbeddings(model, prompt);
} catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e);
}
}
@Test
public void testAsk() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text";
try {
when(ollamaAPI.ask(model, prompt)).thenReturn(new OllamaResult("", 0,200));
ollamaAPI.ask(model, prompt);
verify(ollamaAPI, times(1)).ask(model, prompt);
} catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e);
}
}
@Test
public void testAskAsync() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text";
when(ollamaAPI.askAsync(model, prompt)).thenReturn(new OllamaAsyncResultCallback(null, null, null));
ollamaAPI.askAsync(model, prompt);
verify(ollamaAPI, times(1)).askAsync(model, prompt);
}
}