Fixes to listModels API

This commit is contained in:
Amith Koujalgi 2023-12-14 17:55:17 +05:30
parent b2d76970dc
commit 1df8622b32
3 changed files with 134 additions and 146 deletions

View File

@ -3,14 +3,19 @@ package io.github.amithkoujalgi.ollama4j.core.models;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
public class Model { public class Model {
private String name; private String name;
@JsonProperty("modified_at") @JsonProperty("modified_at")
private String modifiedAt; private String modifiedAt;
private String digest; private String digest;
private long size; private long size;
@JsonProperty("details")
private ModelMeta modelMeta;
/** /**
* Returns the model's tag. This includes model name and its version separated by a colon character `:` * Returns the model's tag. This includes model name and its version separated by a colon
* character `:`
*
* @return model tag * @return model tag
*/ */
public String getName() { public String getName() {
@ -23,6 +28,7 @@ public class Model {
/** /**
* Returns the model name without its version * Returns the model name without its version
*
* @return model name * @return model name
*/ */
public String getModelName() { public String getModelName() {
@ -31,6 +37,7 @@ public class Model {
/** /**
* Returns the model version without its name * Returns the model version without its name
*
* @return model version * @return model version
*/ */
public String getModelVersion() { public String getModelVersion() {
@ -61,4 +68,7 @@ public class Model {
this.size = size; this.size = size;
} }
public ModelMeta getModelMeta() {
return modelMeta;
}
} }

View File

@ -0,0 +1,61 @@
package io.github.amithkoujalgi.ollama4j.core.models;
import com.fasterxml.jackson.annotation.JsonProperty;
public class ModelMeta {
@JsonProperty("format")
private String format;
@JsonProperty("family")
private String family;
@JsonProperty("families")
private String[] families;
@JsonProperty("parameter_size")
private String parameterSize;
@JsonProperty("quantization_level")
private String quantizationLevel;
public String getFormat() {
return format;
}
public void setFormat(String format) {
this.format = format;
}
public String getFamily() {
return family;
}
public void setFamily(String family) {
this.family = family;
}
public String[] getFamilies() {
return families;
}
public void setFamilies(String[] families) {
this.families = families;
}
public String getParameterSize() {
return parameterSize;
}
public void setParameterSize(String parameterSize) {
this.parameterSize = parameterSize;
}
public String getQuantizationLevel() {
return quantizationLevel;
}
public void setQuantizationLevel(String quantizationLevel) {
this.quantizationLevel = quantizationLevel;
}
}

View File

@ -6,6 +6,7 @@ import io.github.amithkoujalgi.ollama4j.core.models.ModelDetail;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaAsyncResultCallback; import io.github.amithkoujalgi.ollama4j.core.models.OllamaAsyncResultCallback;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult; import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
import io.github.amithkoujalgi.ollama4j.core.types.OllamaModelType; import io.github.amithkoujalgi.ollama4j.core.types.OllamaModelType;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.mockito.Mockito; import org.mockito.Mockito;
@ -13,109 +14,25 @@ import java.io.IOException;
import java.net.URISyntaxException; import java.net.URISyntaxException;
import java.util.ArrayList; import java.util.ArrayList;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.mockito.Mockito.*; import static org.mockito.Mockito.*;
public class TestRealAPIs { class TestRealAPIs {
OllamaAPI ollamaAPI;
@BeforeEach
void setUp() {
String ollamaHost = "http://localhost:11434";
ollamaAPI = new OllamaAPI(ollamaHost);
}
@Test @Test
public void testMockPullModel() { void testListModels() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
try { try {
doNothing().when(ollamaAPI).pullModel(model); assertNotNull(ollamaAPI.listModels());
ollamaAPI.pullModel(model);
verify(ollamaAPI, times(1)).pullModel(model);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
} }
@Test
public void testListModels() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
try {
when(ollamaAPI.listModels()).thenReturn(new ArrayList<>());
ollamaAPI.listModels();
verify(ollamaAPI, times(1)).listModels();
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e);
}
}
@Test
public void testCreateModel() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
String modelFilePath = "/somemodel";
try {
doNothing().when(ollamaAPI).createModel(model, modelFilePath);
ollamaAPI.createModel(model, modelFilePath);
verify(ollamaAPI, times(1)).createModel(model, modelFilePath);
} catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e);
}
}
@Test
public void testDeleteModel() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
try {
doNothing().when(ollamaAPI).deleteModel(model, true);
ollamaAPI.deleteModel(model, true);
verify(ollamaAPI, times(1)).deleteModel(model, true);
} catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e);
}
}
@Test
public void testGetModelDetails() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
try {
when(ollamaAPI.getModelDetails(model)).thenReturn(new ModelDetail());
ollamaAPI.getModelDetails(model);
verify(ollamaAPI, times(1)).getModelDetails(model);
} catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e);
}
}
@Test
public void testGenerateEmbeddings() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text";
try {
when(ollamaAPI.generateEmbeddings(model, prompt)).thenReturn(new ArrayList<>());
ollamaAPI.generateEmbeddings(model, prompt);
verify(ollamaAPI, times(1)).generateEmbeddings(model, prompt);
} catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e);
}
}
@Test
public void testAsk() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text";
try {
when(ollamaAPI.ask(model, prompt)).thenReturn(new OllamaResult("", 0,200));
ollamaAPI.ask(model, prompt);
verify(ollamaAPI, times(1)).ask(model, prompt);
} catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e);
}
}
@Test
public void testAskAsync() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text";
when(ollamaAPI.askAsync(model, prompt)).thenReturn(new OllamaAsyncResultCallback(null, null, null));
ollamaAPI.askAsync(model, prompt);
verify(ollamaAPI, times(1)).askAsync(model, prompt);
}
} }