@@ -853,18 +853,12 @@ def test_validate_smdataparallel_args_raises():
853853 smdataparallel_enabled = {"smdistributed" : {"dataparallel" : {"enabled" : True }}}
854854
855855 # Cases {PT|TF2}
856- # 1. None instance type
857- # 2. incorrect instance type
858- # 3. incorrect python version
859- # 4. incorrect framework version
856+ # 1. incorrect python version
857+ # 2. incorrect framework version
860858
861859 bad_args = [
862- (None , "tensorflow" , "2.3.1" , "py3" , smdataparallel_enabled ),
863- ("ml.p3.2xlarge" , "tensorflow" , "2.3.1" , "py3" , smdataparallel_enabled ),
864860 ("ml.p3dn.24xlarge" , "tensorflow" , "2.3.1" , "py2" , smdataparallel_enabled ),
865861 ("ml.p3.16xlarge" , "tensorflow" , "1.3.1" , "py3" , smdataparallel_enabled ),
866- (None , "pytorch" , "1.6.0" , "py3" , smdataparallel_enabled ),
867- ("ml.p3.2xlarge" , "pytorch" , "1.6.0" , "py3" , smdataparallel_enabled ),
868862 ("ml.p3dn.24xlarge" , "pytorch" , "1.6.0" , "py2" , smdataparallel_enabled ),
869863 ("ml.p3.16xlarge" , "pytorch" , "1.5.0" , "py3" , smdataparallel_enabled ),
870864 ]
@@ -966,74 +960,6 @@ def test_validate_smdataparallel_args_not_raises():
966960 )
967961
968962
969- def test_validate_pytorchddp_not_raises ():
970- # Case 1: Framework is not PyTorch
971- fw_utils .validate_pytorch_distribution (
972- distribution = None ,
973- framework_name = "tensorflow" ,
974- framework_version = "2.9.1" ,
975- py_version = "py3" ,
976- image_uri = "custom-container" ,
977- )
978- # Case 2: Framework is PyTorch, but distribution is not PyTorchDDP
979- pytorchddp_disabled = {"pytorchddp" : {"enabled" : False }}
980- fw_utils .validate_pytorch_distribution (
981- distribution = pytorchddp_disabled ,
982- framework_name = "pytorch" ,
983- framework_version = "1.10" ,
984- py_version = "py3" ,
985- image_uri = "custom-container" ,
986- )
987- # Case 3: Framework is PyTorch, Distribution is PyTorchDDP enabled, supported framework and py versions
988- pytorchddp_enabled = {"pytorchddp" : {"enabled" : True }}
989- pytorchddp_supported_fw_versions = [
990- "1.10" ,
991- "1.10.0" ,
992- "1.10.2" ,
993- "1.11" ,
994- "1.11.0" ,
995- "1.12" ,
996- "1.12.0" ,
997- "1.12.1" ,
998- "1.13.1" ,
999- "2.0.0" ,
1000- "2.0.1" ,
1001- "2.1.0" ,
1002- "2.2.0" ,
1003- ]
1004- for framework_version in pytorchddp_supported_fw_versions :
1005- fw_utils .validate_pytorch_distribution (
1006- distribution = pytorchddp_enabled ,
1007- framework_name = "pytorch" ,
1008- framework_version = framework_version ,
1009- py_version = "py3" ,
1010- image_uri = "custom-container" ,
1011- )
1012-
1013-
1014- def test_validate_pytorchddp_raises ():
1015- pytorchddp_enabled = {"pytorchddp" : {"enabled" : True }}
1016- # Case 1: Unsupported framework version
1017- with pytest .raises (ValueError ):
1018- fw_utils .validate_pytorch_distribution (
1019- distribution = pytorchddp_enabled ,
1020- framework_name = "pytorch" ,
1021- framework_version = "1.8" ,
1022- py_version = "py3" ,
1023- image_uri = None ,
1024- )
1025-
1026- # Case 2: Unsupported Py version
1027- with pytest .raises (ValueError ):
1028- fw_utils .validate_pytorch_distribution (
1029- distribution = pytorchddp_enabled ,
1030- framework_name = "pytorch" ,
1031- framework_version = "1.10" ,
1032- py_version = "py2" ,
1033- image_uri = None ,
1034- )
1035-
1036-
1037963def test_validate_torch_distributed_not_raises ():
1038964 # Case 1: Framework is PyTorch, but torch_distributed is not enabled
1039965 torch_distributed_disabled = {"torch_distributed" : {"enabled" : False }}
0 commit comments