145145 ],
146146}
147147
148- PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS = [
149- "1.10" ,
150- "1.10.0" ,
151- "1.10.2" ,
152- "1.11" ,
153- "1.11.0" ,
154- "1.12" ,
155- "1.12.0" ,
156- "1.12.1" ,
157- "1.13.1" ,
158- "2.0.0" ,
159- "2.0.1" ,
160- "2.1.0" ,
161- "2.2.0" ,
162- ]
163-
164148TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS = [
165149 "1.13.1" ,
166150 "2.0.0" ,
@@ -795,7 +779,6 @@ def _validate_smdataparallel_args(
795779
796780 Raises:
797781 ValueError: if
798- (`instance_type` is not in SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES or
799782 `py_version` is not python3 or
800783 `framework_version` is not in SM_DATAPARALLEL_SUPPORTED_FRAMEWORK_VERSION
801784 """
@@ -806,17 +789,10 @@ def _validate_smdataparallel_args(
806789 if not smdataparallel_enabled :
807790 return
808791
809- is_instance_type_supported = instance_type in SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES
810-
811792 err_msg = ""
812793
813- if not is_instance_type_supported :
814- # instance_type is required
815- err_msg += (
816- f"Provided instance_type { instance_type } is not supported by smdataparallel.\n "
817- "Please specify one of the supported instance types:"
818- f"{ SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES } \n "
819- )
794+ if not instance_type :
795+ err_msg += "Please specify an instance_type for smdataparallel.\n "
820796
821797 if not image_uri :
822798 # ignore framework_version & py_version if image_uri is set
@@ -928,13 +904,6 @@ def validate_distribution(
928904 )
929905 if framework_name and framework_name == "pytorch" :
930906 # We need to validate only for PyTorch framework
931- validate_pytorch_distribution (
932- distribution = validated_distribution ,
933- framework_name = framework_name ,
934- framework_version = framework_version ,
935- py_version = py_version ,
936- image_uri = image_uri ,
937- )
938907 validate_torch_distributed_distribution (
939908 instance_type = instance_type ,
940909 distribution = validated_distribution ,
@@ -968,13 +937,6 @@ def validate_distribution(
968937 )
969938 if framework_name and framework_name == "pytorch" :
970939 # We need to validate only for PyTorch framework
971- validate_pytorch_distribution (
972- distribution = validated_distribution ,
973- framework_name = framework_name ,
974- framework_version = framework_version ,
975- py_version = py_version ,
976- image_uri = image_uri ,
977- )
978940 validate_torch_distributed_distribution (
979941 instance_type = instance_type ,
980942 distribution = validated_distribution ,
@@ -1023,63 +985,6 @@ def validate_distribution_for_instance_type(instance_type, distribution):
1023985 raise ValueError (err_msg )
1024986
1025987
1026- def validate_pytorch_distribution (
1027- distribution , framework_name , framework_version , py_version , image_uri
1028- ):
1029- """Check if pytorch distribution strategy is correctly invoked by the user.
1030-
1031- Args:
1032- distribution (dict): A dictionary with information to enable distributed training.
1033- (Defaults to None if distributed training is not enabled.) For example:
1034-
1035- .. code:: python
1036-
1037- {
1038- "pytorchddp": {
1039- "enabled": True
1040- }
1041- }
1042- framework_name (str): A string representing the name of framework selected.
1043- framework_version (str): A string representing the framework version selected.
1044- py_version (str): A string representing the python version selected.
1045- image_uri (str): A string representing a Docker image URI.
1046-
1047- Raises:
1048- ValueError: if
1049- `py_version` is not python3 or
1050- `framework_version` is not in PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS
1051- """
1052- if framework_name and framework_name != "pytorch" :
1053- # We need to validate only for PyTorch framework
1054- return
1055-
1056- pytorch_ddp_enabled = False
1057- if "pytorchddp" in distribution :
1058- pytorch_ddp_enabled = distribution .get ("pytorchddp" ).get ("enabled" , False )
1059- if not pytorch_ddp_enabled :
1060- # Distribution strategy other than pytorchddp is selected
1061- return
1062-
1063- err_msg = ""
1064- if not image_uri :
1065- # ignore framework_version and py_version if image_uri is set
1066- # in case image_uri is not set, then both are mandatory
1067- if framework_version not in PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS :
1068- err_msg += (
1069- f"Provided framework_version { framework_version } is not supported by"
1070- " pytorchddp.\n "
1071- "Please specify one of the supported framework versions:"
1072- f" { PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS } \n "
1073- )
1074- if "py3" not in py_version :
1075- err_msg += (
1076- f"Provided py_version { py_version } is not supported by pytorchddp.\n "
1077- "Please specify py_version>=py3"
1078- )
1079- if err_msg :
1080- raise ValueError (err_msg )
1081-
1082-
1083988def validate_torch_distributed_distribution (
1084989 instance_type ,
1085990 distribution ,
0 commit comments