1515import time
1616import unittest
1717from contextlib import contextmanager
18+ from dataclasses import dataclass
1819from datetime import datetime
1920from os .path import join
2021from typing import Callable , Generator , Optional
@@ -828,6 +829,62 @@ def test_close_twice(self) -> None:
828829 self .scheduler .close ()
829830 # nothing to validate just make sure no errors are raised
830831
832+ def test_get_gpu_count (self ) -> None :
833+ @dataclass
834+ class ProcResult :
835+ stdout : bytes
836+
837+ nvidia_smi_out = (
838+ "GPU 0: Tesla V100-SXM2-16GB (UUID: GPU-196a22c5-717b-66db-0acc-58cde6f3df85)\n "
839+ "GPU 1: Tesla V100-SXM2-16GB (UUID: GPU-45e9165d-4f7e-d954-7ff5-481bc2c0ec7b)\n "
840+ "GPU 2: Tesla V100-SXM2-16GB (UUID: GPU-26e22503-5fd5-8f55-d068-e1714fbb6fd6)\n "
841+ "GPU 3: Tesla V100-SXM2-16GB (UUID: GPU-ebfc20c7-5f1a-1bc9-0d98-601cbe21fc2d)\n "
842+ )
843+
844+ stdout = nvidia_smi_out .encode ()
845+ result = ProcResult (stdout )
846+ with patch ("subprocess.run" , return_value = result ):
847+ gpu_count = self .scheduler ._get_gpu_count ()
848+ self .assertEqual (4 , gpu_count )
849+
850+ def test_get_gpu_count_error (self ) -> None :
851+ with patch ("subprocess.run" , side_effect = Exception ("test error" )):
852+ gpu_count = self .scheduler ._get_gpu_count ()
853+ self .assertEqual (0 , gpu_count )
854+
855+ def test_get_cuda_devices (self ) -> None :
856+ with patch .object (self .scheduler , "_get_gpu_count" , return_value = 8 ):
857+ self .assertEqual ("0,1,2,3" , self .scheduler ._get_cuda_devices (0 , 2 ))
858+ self .assertEqual ("4,5,6,7" , self .scheduler ._get_cuda_devices (1 , 2 ))
859+ with patch .object (self .scheduler , "_get_gpu_count" , return_value = 4 ):
860+ self .assertEqual ("0" , self .scheduler ._get_cuda_devices (0 , 4 ))
861+ self .assertEqual ("1" , self .scheduler ._get_cuda_devices (1 , 4 ))
862+ self .assertEqual ("2" , self .scheduler ._get_cuda_devices (2 , 4 ))
863+ self .assertEqual ("3" , self .scheduler ._get_cuda_devices (3 , 4 ))
864+
865+ def test_get_cuda_devices_is_set (self ) -> None :
866+ with patch .object (self .scheduler , "_get_gpu_count" , return_value = 8 ):
867+ sleep_60sec = AppDef (
868+ name = "sleep" ,
869+ roles = [
870+ Role (
871+ name = "sleep" ,
872+ image = self .test_dir ,
873+ entrypoint = "sleep.sh" ,
874+ args = ["60" ],
875+ num_replicas = 4 ,
876+ )
877+ ],
878+ )
879+
880+ popen_req = self .scheduler ._to_popen_request (sleep_60sec , {})
881+ role_params = popen_req .role_params ["sleep" ]
882+ self .assertEqual (4 , len (role_params ))
883+ self .assertEqual ("0,1" , role_params [0 ].env ["CUDA_VISIBLE_DEVICES" ])
884+ self .assertEqual ("2,3" , role_params [1 ].env ["CUDA_VISIBLE_DEVICES" ])
885+ self .assertEqual ("4,5" , role_params [2 ].env ["CUDA_VISIBLE_DEVICES" ])
886+ self .assertEqual ("6,7" , role_params [3 ].env ["CUDA_VISIBLE_DEVICES" ])
887+
    def test_no_orphan_process_function(self) -> None:
        # Thin entry point: delegates to the shared orphan-process workflow
        # helper defined below in this class. Presumably a sibling test drives
        # the same helper through a different entrypoint variant — confirm
        # against the rest of the file.
        self._test_orphan_workflow()
833890
@@ -839,6 +896,9 @@ def _test_orphan_workflow(self) -> None:
839896 target = start_sleep_processes , args = (self .test_dir , mp_queue , child_nproc )
840897 )
841898 proc .start ()
899+ # Before querying the queue we need to wait
900+ # Otherwise we will get `FileNotFoundError: [Errno 2] No such file or directory` error
901+ time .sleep (10 )
842902 total_processes = child_nproc + 1
843903 pids = []
844904 for _ in range (total_processes ):
0 commit comments