diff --git a/.gitignore b/.gitignore index d9058fca..61dc4d8d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,8 +1,10 @@ env .env +.env310 .vscode test_wrapper.py test_logging.py +/ignore # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/tests/test_api_wrapper/test_ctl_commands.py b/tests/test_api_wrapper/test_ctl_commands.py new file mode 100644 index 00000000..6c003c82 --- /dev/null +++ b/tests/test_api_wrapper/test_ctl_commands.py @@ -0,0 +1,124 @@ +''' Tests for ctl_commands.py ''' + +import unittest + +from unittest.mock import patch + +from runpod.api_wrapper import ctl_commands + +class TestCTL(unittest.TestCase): + ''' Tests for CTL Commands ''' + + def test_get_gpus(self): + ''' + Tests get_gpus + ''' + with patch("runpod.api_wrapper.graphql.requests.post") as patch_request: + patch_request.return_value.json.return_value = { + "data": { + "gpuTypes": [ + { + "id": "NVIDIA A100 80GB PCIe", + "displayName": "A100 80GB", + "memoryInGb": 80 + } + ] + } + } + + gpus = ctl_commands.get_gpus() + + self.assertEqual(len(gpus), 1) + self.assertEqual(gpus[0]["id"], "NVIDIA A100 80GB PCIe") + + def test_get_gpu(self): + ''' + Tests get_gpu_by_id + ''' + with patch("runpod.api_wrapper.graphql.requests.post") as patch_request: + patch_request.return_value.json.return_value = { + "data": { + "gpuTypes": [ + { + "id": "NVIDIA A100 80GB PCIe", + "displayName": "A100 80GB", + "memoryInGb": 80 + } + ] + } + } + + gpu = ctl_commands.get_gpu("NVIDIA A100 80GB PCIe") + + self.assertEqual(gpu["id"], "NVIDIA A100 80GB PCIe") + + def test_create_pod(self): + ''' + Tests create_pod + ''' + with patch("runpod.api_wrapper.graphql.requests.post") as patch_request: + patch_request.return_value.json.return_value = { + "data": { + "podFindAndDeployOnDemand": { + "id": "POD_ID" + } + } + } + + pod = ctl_commands.create_pod( + name="POD_NAME", + image_name="IMAGE_NAME", + gpu_type_id="NVIDIA A100 80GB PCIe") + + self.assertEqual(pod["id"], "POD_ID") + + + def test_stop_pod(self): + ''' + Test stop_pod + ''' + with patch("runpod.api_wrapper.graphql.requests.post") as patch_request: + patch_request.return_value.json.return_value = { + "data": { + "podStop": { + "id": "POD_ID" + } + } + } + + pod = ctl_commands.stop_pod( + pod_id="POD_ID") + + self.assertEqual(pod["id"], "POD_ID") + + def test_resume_pod(self): + ''' + Test resume_pod + ''' + with patch("runpod.api_wrapper.graphql.requests.post") as patch_request: + patch_request.return_value.json.return_value = { + "data": { + "podResume": { + "id": "POD_ID" + } + } + } + + pod = ctl_commands.resume_pod(pod_id="POD_ID", gpu_count=1) + + self.assertEqual(pod["id"], "POD_ID") + + def test_terminate_pod(self): + ''' + Test terminate_pod + ''' + with patch("runpod.api_wrapper.graphql.requests.post") as patch_request: + patch_request.return_value.json.return_value = { + "data": { + "podTerminate": { + "id": "POD_ID" + } + } + } + + self.assertIsNone(ctl_commands.terminate_pod(pod_id="POD_ID")) diff --git a/tests/test_api_wrapper/test_mutations_pods.py b/tests/test_api_wrapper/test_mutations_pods.py new file mode 100644 index 00000000..49136808 --- /dev/null +++ b/tests/test_api_wrapper/test_mutations_pods.py @@ -0,0 +1,60 @@ +''' Test API Wrapper Pod Mutations ''' + +import unittest + +from runpod.api_wrapper.mutations import pods + +class TestPodMutations(unittest.TestCase): + ''' Test API Wrapper Pod Mutations ''' + + def test_generate_pod_deployment_mutation(self): + ''' + Test generate_pod_deployment_mutation + ''' + result = pods.generate_pod_deployment_mutation( + name="test", + image_name="test_image", + gpu_type_id="1", + cloud_type="cloud", + data_center_id="1", + country_code="US", + gpu_count=1, + volume_in_gb=100, + container_disk_in_gb=10, + min_vcpu_count=1, + min_memory_in_gb=1, + docker_args="args", + ports="8080", + volume_mount_path="/path", + env={"ENV": "test"}, + support_public_ip=True) + + # Here you should check the correct structure of the result + self.assertIn("mutation", result) + + def test_generate_pod_stop_mutation(self): + ''' + Test generate_pod_stop_mutation + ''' + result = pods.generate_pod_stop_mutation("pod_id") + # Here you should check the correct structure of the result + self.assertIn("mutation", result) + + def test_generate_pod_resume_mutation(self): + ''' + Test generate_pod_resume_mutation + ''' + result = pods.generate_pod_resume_mutation("pod_id", 1) + # Here you should check the correct structure of the result + self.assertIn("mutation", result) + + def test_generate_pod_terminate_mutation(self): + ''' + Test generate_pod_terminate_mutation + ''' + result = pods.generate_pod_terminate_mutation("pod_id") + # Here you should check the correct structure of the result + self.assertIn("mutation", result) + +if __name__ == "__main__": + unittest.main()