Fixes to tests

This commit is contained in:
Amith Koujalgi 2023-12-14 16:47:38 +05:30
parent 4e4a5d2996
commit d52427fb68

View File

@ -1,121 +1,121 @@
package io.github.amithkoujalgi.ollama4j; package io.github.amithkoujalgi.ollama4j;
import static org.mockito.Mockito.*;
import io.github.amithkoujalgi.ollama4j.core.OllamaAPI; import io.github.amithkoujalgi.ollama4j.core.OllamaAPI;
import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException; import io.github.amithkoujalgi.ollama4j.core.exceptions.OllamaBaseException;
import io.github.amithkoujalgi.ollama4j.core.models.ModelDetail; import io.github.amithkoujalgi.ollama4j.core.models.ModelDetail;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaAsyncResultCallback; import io.github.amithkoujalgi.ollama4j.core.models.OllamaAsyncResultCallback;
import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult; import io.github.amithkoujalgi.ollama4j.core.models.OllamaResult;
import io.github.amithkoujalgi.ollama4j.core.types.OllamaModelType; import io.github.amithkoujalgi.ollama4j.core.types.OllamaModelType;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import java.io.IOException; import java.io.IOException;
import java.net.URISyntaxException; import java.net.URISyntaxException;
import java.util.ArrayList; import java.util.ArrayList;
import org.junit.jupiter.api.Test;
import static org.mockito.Mockito.*; import org.mockito.Mockito;
public class TestMockedAPIs { public class TestMockedAPIs {
@Test @Test
public void testMockPullModel() { public void testMockPullModel() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2; String model = OllamaModelType.LLAMA2;
try { try {
doNothing().when(ollamaAPI).pullModel(model); doNothing().when(ollamaAPI).pullModel(model);
ollamaAPI.pullModel(model); ollamaAPI.pullModel(model);
verify(ollamaAPI, times(1)).pullModel(model); verify(ollamaAPI, times(1)).pullModel(model);
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
}
} }
}
@Test @Test
public void testListModels() { public void testListModels() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
try { try {
when(ollamaAPI.listModels()).thenReturn(new ArrayList<>()); when(ollamaAPI.listModels()).thenReturn(new ArrayList<>());
ollamaAPI.listModels(); ollamaAPI.listModels();
verify(ollamaAPI, times(1)).listModels(); verify(ollamaAPI, times(1)).listModels();
} catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) { } catch (IOException | OllamaBaseException | InterruptedException | URISyntaxException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
}
} }
}
@Test @Test
public void testCreateModel() { public void testCreateModel() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2; String model = OllamaModelType.LLAMA2;
String modelFilePath = "/somemodel"; String modelFilePath = "/somemodel";
try { try {
doNothing().when(ollamaAPI).createModel(model, modelFilePath); doNothing().when(ollamaAPI).createModel(model, modelFilePath);
ollamaAPI.createModel(model, modelFilePath); ollamaAPI.createModel(model, modelFilePath);
verify(ollamaAPI, times(1)).createModel(model, modelFilePath); verify(ollamaAPI, times(1)).createModel(model, modelFilePath);
} catch (IOException | OllamaBaseException | InterruptedException e) { } catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
}
} }
}
@Test @Test
public void testDeleteModel() { public void testDeleteModel() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2; String model = OllamaModelType.LLAMA2;
try { try {
doNothing().when(ollamaAPI).deleteModel(model, true); doNothing().when(ollamaAPI).deleteModel(model, true);
ollamaAPI.deleteModel(model, true); ollamaAPI.deleteModel(model, true);
verify(ollamaAPI, times(1)).deleteModel(model, true); verify(ollamaAPI, times(1)).deleteModel(model, true);
} catch (IOException | OllamaBaseException | InterruptedException e) { } catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
}
} }
}
@Test @Test
public void testGetModelDetails() { public void testGetModelDetails() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2; String model = OllamaModelType.LLAMA2;
try { try {
when(ollamaAPI.getModelDetails(model)).thenReturn(new ModelDetail()); when(ollamaAPI.getModelDetails(model)).thenReturn(new ModelDetail());
ollamaAPI.getModelDetails(model); ollamaAPI.getModelDetails(model);
verify(ollamaAPI, times(1)).getModelDetails(model); verify(ollamaAPI, times(1)).getModelDetails(model);
} catch (IOException | OllamaBaseException | InterruptedException e) { } catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
}
} }
}
@Test @Test
public void testGenerateEmbeddings() { public void testGenerateEmbeddings() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2; String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text"; String prompt = "some prompt text";
try { try {
when(ollamaAPI.generateEmbeddings(model, prompt)).thenReturn(new ArrayList<>()); when(ollamaAPI.generateEmbeddings(model, prompt)).thenReturn(new ArrayList<>());
ollamaAPI.generateEmbeddings(model, prompt); ollamaAPI.generateEmbeddings(model, prompt);
verify(ollamaAPI, times(1)).generateEmbeddings(model, prompt); verify(ollamaAPI, times(1)).generateEmbeddings(model, prompt);
} catch (IOException | OllamaBaseException | InterruptedException e) { } catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
}
} }
}
@Test @Test
public void testAsk() { public void testAsk() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2; String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text"; String prompt = "some prompt text";
try { try {
when(ollamaAPI.ask(model, prompt)).thenReturn(new OllamaResult("", 0)); when(ollamaAPI.ask(model, prompt)).thenReturn(new OllamaResult("", 0, 200));
ollamaAPI.ask(model, prompt); ollamaAPI.ask(model, prompt);
verify(ollamaAPI, times(1)).ask(model, prompt); verify(ollamaAPI, times(1)).ask(model, prompt);
} catch (IOException | OllamaBaseException | InterruptedException e) { } catch (IOException | OllamaBaseException | InterruptedException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
}
} }
}
@Test @Test
public void testAskAsync() { public void testAskAsync() {
OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class); OllamaAPI ollamaAPI = Mockito.mock(OllamaAPI.class);
String model = OllamaModelType.LLAMA2; String model = OllamaModelType.LLAMA2;
String prompt = "some prompt text"; String prompt = "some prompt text";
when(ollamaAPI.askAsync(model, prompt)).thenReturn(new OllamaAsyncResultCallback(null, null, null)); when(ollamaAPI.askAsync(model, prompt))
ollamaAPI.askAsync(model, prompt); .thenReturn(new OllamaAsyncResultCallback(null, null, null));
verify(ollamaAPI, times(1)).askAsync(model, prompt); ollamaAPI.askAsync(model, prompt);
} verify(ollamaAPI, times(1)).askAsync(model, prompt);
}
} }