2121
2222from __future__ import absolute_import
2323
24- from typing import Optional , Union , Dict , Any , List
24+ from typing import Optional , Union
2525from pydantic import BaseModel , model_validator
2626
2727import sagemaker_core .shapes as shapes
5454 CheckpointConfig ,
5555)
5656
57- from sagemaker .modules import logger
5857from sagemaker .modules .utils import convert_unassigned_to_none
5958
6059__all__ = [
61- "SourceCodeConfig" ,
62- "TorchDistributionConfig" ,
63- "MPIDistributionConfig" ,
64- "SMDistributedSettings" ,
65- "DistributionConfig" ,
60+ "SourceCode" ,
6661 "StoppingCondition" ,
6762 "RetryStrategy" ,
6863 "OutputDataConfig" ,
8782 "InstanceGroup" ,
8883 "TensorBoardOutputConfig" ,
8984 "CheckpointConfig" ,
90- "ComputeConfig " ,
91- "NetworkingConfig " ,
85+ "Compute " ,
86+ "Networking " ,
9287 "InputData" ,
9388]
9489
9590
96- class SMDistributedSettings (BaseModel ):
97- """SMDistributedSettings .
91+ class SourceCode (BaseModel ):
92+ """SourceCode .
9893
99- The SMDistributedSettings is used to configure distributed training when
100- using the smdistributed library.
101-
102- Attributes:
103- enable_dataparallel (Optional[bool]):
104- Whether to enable data parallelism.
105- enable_modelparallel (Optional[bool]):
106- Whether to enable model parallelism.
107- modelparallel_parameters (Optional[Dict[str, Any]]):
108- The parameters for model parallelism.
109- """
110-
111- enable_dataparallel : Optional [bool ] = False
112- enable_modelparallel : Optional [bool ] = False
113- modelparallel_parameters : Optional [Dict [str , Any ]] = None
114-
115-
116- class DistributionConfig (BaseModel ):
117- """Base class for distribution configurations."""
118-
119- _distribution_type : str
120-
121-
122- class TorchDistributionConfig (DistributionConfig ):
123- """TorchDistributionConfig.
124-
125- The TorchDistributionConfig uses `torchrun` or `torch.distributed.launch` in the backend to
126- launch distributed training.
127-
128- SMDistributed Library Information:
129- - `TorchDistributionConfig` can be used for SMModelParallel V2.
130- - For SMDataParallel or SMModelParallel V1, it is recommended to use the
131- `MPIDistributionConfig.`
132-
133-
134- Attributes:
135- smdistributed_settings (Optional[SMDistributedSettings]):
136- The settings for smdistributed library.
137- process_count_per_node (int):
138- The number of processes to run on each node in the training job.
139- Will default to the number of CPUs or GPUs available in the container.
140- """
141-
142- _distribution_type : str = "torch_distributed"
143-
144- smdistributed_settings : Optional [SMDistributedSettings ] = None
145- process_count_per_node : Optional [int ] = None
146-
147- @model_validator (mode = "after" )
148- def _validate_model (cls , model ): # pylint: disable=E0213
149- """Validate the model."""
150- if (
151- getattr (model , "smddistributed_settings" , None )
152- and model .smddistributed_settings .enable_dataparallel
153- ):
154- logger .warning (
155- "For smdistributed data parallelism, it is recommended to use "
156- + "MPIDistributionConfig."
157- )
158- return model
159-
160-
161- class MPIDistributionConfig (DistributionConfig ):
162- """MPIDistributionConfig.
163-
164- The MPIDistributionConfig uses `mpirun` in the backend to launch distributed training.
165-
166- SMDistributed Library Information:
167- - `MPIDistributionConfig` can be used for SMDataParallel and SMModelParallel V1.
168- - For SMModelParallel V2, it is recommended to use the `TorchDistributionConfig`.
169-
170- Attributes:
171- smdistributed_settings (Optional[SMDistributedSettings]):
172- The settings for smdistributed library.
173- process_count_per_node (int):
174- The number of processes to run on each node in the training job.
175- Will default to the number of CPUs or GPUs available in the container.
176- mpi_additional_options (Optional[str]):
177- The custom MPI options to use for the training job.
178- """
179-
180- _distribution_type : str = "mpi"
181-
182- smdistributed_settings : Optional [SMDistributedSettings ] = None
183- process_count_per_node : Optional [int ] = None
184- mpi_additional_options : Optional [List [str ]] = None
185-
186-
187- class SourceCodeConfig (BaseModel ):
188- """SourceCodeConfig.
189-
190- This config allows the user to specify the source code location, dependencies,
94+ The SourceCode class allows the user to specify the source code location, dependencies,
19195 entry script, or commands to be executed in the training job container.
19296
19397 Attributes:
@@ -210,10 +114,10 @@ class SourceCodeConfig(BaseModel):
210114 command : Optional [str ] = None
211115
212116
213- class ComputeConfig (shapes .ResourceConfig ):
214- """ComputeConfig .
117+ class Compute (shapes .ResourceConfig ):
118+ """Compute .
215119
216- The ComputeConfig is a subclass of `sagemaker_core.shapes.ResourceConfig`
120+ The Compute class is a subclass of `sagemaker_core.shapes.ResourceConfig`
217121 and allows the user to specify the compute resources for the training job.
218122
219123 Attributes:
@@ -245,7 +149,7 @@ class ComputeConfig(shapes.ResourceConfig):
245149 enable_managed_spot_training : Optional [bool ] = None
246150
247151 @model_validator (mode = "after" )
248- def _model_validator (self ) -> "ComputeConfig " :
152+ def _model_validator (self ) -> "Compute " :
249153 """Convert Unassigned values to None."""
250154 return convert_unassigned_to_none (self )
251155
@@ -259,10 +163,10 @@ def _to_resource_config(self) -> shapes.ResourceConfig:
259163 return shapes .ResourceConfig (** filtered_dict )
260164
261165
262- class NetworkingConfig (shapes .VpcConfig ):
263- """NetworkingConfig .
166+ class Networking (shapes .VpcConfig ):
167+ """Networking .
264168
265- The NetworkingConifg is a subclass of `sagemaker_core.shapes.VpcConfig ` and
169+ The Networking class is a subclass of `sagemaker_core.shapes.VpcConfig ` and
266170 allows the user to specify the networking configuration for the training job.
267171
268172 Attributes:
@@ -290,7 +194,7 @@ class NetworkingConfig(shapes.VpcConfig):
290194 enable_inter_container_traffic_encryption : Optional [bool ] = None
291195
292196 @model_validator (mode = "after" )
293- def _model_validator (self ) -> "NetworkingConfig " :
197+ def _model_validator (self ) -> "Networking " :
294198 """Convert Unassigned values to None."""
295199 return convert_unassigned_to_none (self )
296200
0 commit comments