cleanup
Signed-off-by: Sangkug Lym <[email protected]>
erhoo82 committed Nov 30, 2023
1 parent a2bee19 commit 419ad62
Showing 5 changed files with 5 additions and 22 deletions.
@@ -168,7 +168,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True):
global_batch_size=cfg.get('global_batch_size'),
rampup_batch_size=cfg.get('rampup_batch_size', None),
use_fp8=cfg.get('fp8', False),
- use_sharp=cfg.get('sharp', False),
init_mpi_proc_group=cfg.get('ub_tp_comm_overlap', False),
seed=self.cfg.get('seed', 1234),
apex_transformer_log_level=self.cfg.get('apex_transformer_log_level', 30),
2 changes: 0 additions & 2 deletions nemo/collections/nlp/modules/common/megatron/megatron_init.py
@@ -67,7 +67,6 @@ def initialize_model_parallel_for_nemo(
global_batch_size=None,
rampup_batch_size=None,
use_fp8=False,
- use_sharp=False,
init_mpi_proc_group=False,
seed=1234,
apex_transformer_log_level=30,
@@ -85,7 +84,6 @@ def initialize_model_parallel_for_nemo(
app_state.pipeline_model_parallel_size = pipeline_model_parallel_size
app_state.virtual_pipeline_model_parallel_size = virtual_pipeline_model_parallel_size
app_state.use_fp8 = use_fp8
- app_state.use_sharp = use_sharp
app_state.init_mpi_proc_group = init_mpi_proc_group
(
app_state.tensor_model_parallel_rank,
1 change: 1 addition & 0 deletions nemo/collections/nlp/parts/megatron_trainer_builder.py
@@ -53,6 +53,7 @@ def _training_strategy(self) -> NLPDDPStrategy:
no_ddp_communication_hook=True,
gradient_as_bucket_view=self.cfg.model.gradient_as_bucket_view,
find_unused_parameters=False,
+ sharp=self.cfg.model.get('sharp', False),
)

def _grad_scaler(self) -> GradScaler:
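With this change the SHARP flag is read straight from the model config when the trainer builder constructs the strategy, instead of being routed through AppState. Below is a minimal sketch of that lookup, assuming the rest of MegatronTrainerBuilder around this hunk; the OmegaConf config is hypothetical, and NLPDDPStrategy itself is not instantiated because it pulls in the full NeMo/Apex stack.

# Sketch only: derives the new `sharp` kwarg the same way the hunk above does.
from omegaconf import OmegaConf

cfg = OmegaConf.create({"model": {"gradient_as_bucket_view": True, "sharp": True}})

strategy_kwargs = dict(
    no_ddp_communication_hook=True,
    gradient_as_bucket_view=cfg.model.gradient_as_bucket_view,
    find_unused_parameters=False,
    sharp=cfg.model.get("sharp", False),  # False whenever the key is absent
)
print(strategy_kwargs["sharp"])  # True for the config above

These are the keyword arguments that _training_strategy passes on to NLPDDPStrategy(...).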
6 changes: 4 additions & 2 deletions nemo/collections/nlp/parts/nlp_overrides.py
@@ -81,6 +81,7 @@ class NLPDDPStrategy(DDPStrategy):
Args:
no_ddp_communication_hook: Disable DDP communication hook when using AMP-O2
with FP32 gradient accumulation.
+ sharp: Apply SHARP to data-parallel proc groups.
"""

def __init__(
@@ -89,6 +90,7 @@ def __init__(
cluster_environment: ClusterEnvironment = None,
checkpoint_io: Optional[CheckpointIO] = None,
no_ddp_communication_hook: bool = False,
+ sharp: bool = False,
**kwargs: Union[Any, Dict[str, Any]],
) -> None:
if not HAVE_APEX:
@@ -103,6 +105,7 @@ def __init__(
super().__init__(parallel_devices, cluster_environment, checkpoint_io, **kwargs)

self.no_ddp_communication_hook = no_ddp_communication_hook
+ self.sharp = sharp

def setup(self, trainer: "pl.Trainer") -> None:
"""
@@ -180,7 +183,6 @@ def init_model_parallel(self, global_rank: int, world_size: int) -> None:
Args:
global_rank (int): the global process index.
world_size (int): the total number of GPUs, num_nodes * num_devices
- is_slurm_managing_tasks (bool, optional): is the cluster managed by SLURM.
"""
app_state = AppState()

@@ -196,7 +198,7 @@ def init_model_parallel(self, global_rank: int, world_size: int) -> None:
pipeline_model_parallel_size=app_state.pipeline_model_parallel_size,
virtual_pipeline_model_parallel_size=app_state.virtual_pipeline_model_parallel_size,
pipeline_model_parallel_split_rank=app_state.pipeline_model_parallel_split_rank,
- use_sharp=app_state.use_sharp,
+ use_sharp=self.sharp,
)

# assert that fake tp and pp rank match after model parallel init
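Inside the strategy the flag is now an ordinary constructor argument that init_model_parallel forwards as use_sharp, with AppState no longer in the path. Below is a condensed, self-contained sketch of that flow; the stub stands in for the Megatron-core initialization call made in the hunk above (assumed to accept a use_sharp keyword, as the diff shows), and the tensor/pipeline-parallel arguments are omitted.

# Stub standing in for the real model-parallel initialization; it only mirrors
# the `use_sharp` keyword so the new wiring can be shown in isolation.
def initialize_model_parallel(use_sharp: bool = False, **parallel_sizes) -> None:
    # The real call builds the model- and data-parallel process groups and,
    # when use_sharp is True, applies SHARP to the data-parallel ones.
    print(f"initializing model-parallel groups, use_sharp={use_sharp}")


class StrategySketch:
    """Stripped-down stand-in for NLPDDPStrategy, keeping only the SHARP plumbing."""

    def __init__(self, sharp: bool = False):
        self.sharp = sharp  # stored on the strategy instead of AppState

    def init_model_parallel(self, global_rank: int, world_size: int) -> None:
        initialize_model_parallel(use_sharp=self.sharp)


StrategySketch(sharp=True).init_model_parallel(global_rank=0, world_size=8)  # prints use_sharp=True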
17 changes: 0 additions & 17 deletions nemo/utils/app_state.py
@@ -55,7 +55,6 @@ def __init__(self):
self._data_parallel_group = None
self._megatron_checkpoint_version = None
self._use_fp8 = False
- self._use_sharp = False
self._init_mpi_proc_gruop = False

self._random_seed = None
@@ -365,22 +364,6 @@ def use_fp8(self, use_fp8):
"""
self._use_fp8 = use_fp8

- @property
- def use_sharp(self):
- """ Property returns the use of sharp.
- Returns:
- Use of sharp.
- """
- return self._use_sharp
-
- @use_sharp.setter
- def use_sharp(self, use_sharp):
- """ Property sets the use of sharp.
- Args:
- use_sharp: Use of sharp.
- """
- self._use_sharp = use_sharp

@property
def init_mpi_proc_group(self):
""" Property sets the initialization of mpi process group.
