
Commit

SFT profile start and end step fix
Signed-off-by: Sangkug Lym <[email protected]>
erhoo82 committed Dec 2, 2023
1 parent acf1d9b commit 2ecf71a
Showing 2 changed files with 11 additions and 6 deletions.
@@ -94,8 +94,12 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
else:
base_module = self.model

# Set the profile start and end steps in units of global batches
if hasattr(self, '_nsys_profile_enabled'):
self._nsys_profile_start_step = self.cfg.nsys_profile.get('start_step', 0)
self._nsys_profile_end_step = self.cfg.nsys_profile.get('end_step', 0)

self._reset_activation_checkpointing_args()
self._reset_sequence_parallelism_args()
self.virtual_tokens = 0

def setup_metric(self, data_cfg):
@@ -587,7 +591,6 @@ def inference_epoch_end(self, outputs, mode, data_cfg):
# Merge the functionality of previous on_inference_epoch_end() within inference_epoch_end() func here
app_state = AppState()
self._restore_activation_checkpointing_args()
self._restore_sequence_parallelism_args()
if hasattr(self, "_train_ds"):
_reconfigure_microbatch_calculator(
rank=app_state.global_rank,
@@ -810,7 +813,6 @@ def setup_eval_dataloader(self, datasets, data_cfg):

def on_validation_epoch_start(self):
self._reset_activation_checkpointing_args()
self._reset_sequence_parallelism_args()
app_state = AppState()
_reconfigure_microbatch_calculator(
rank=app_state.global_rank,
@@ -823,7 +825,6 @@ def on_validation_epoch_start(self):

def on_test_epoch_start(self):
self._reset_activation_checkpointing_args()
self._reset_sequence_parallelism_args()
app_state = AppState()
_reconfigure_microbatch_calculator(
rank=app_state.global_rank,
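The first hunk above reads the profiling window from the model config in units of global batches. Below is a minimal sketch of such a config fragment and lookup, assuming an OmegaConf-based config; only start_step and end_step (with a default of 0) appear in this diff, while the enabled and ranks keys are illustrative assumptions about the surrounding nsys setup.

from omegaconf import OmegaConf

# Hypothetical nsys_profile section of a model config; keys other than
# start_step/end_step are assumptions, not taken from this commit.
cfg = OmegaConf.create(
    {
        "nsys_profile": {
            "enabled": True,   # assumed switch behind self._nsys_profile_enabled
            "start_step": 10,  # global batch index at which profiling starts
            "end_step": 10,    # global batch index at which profiling stops
            "ranks": [0],      # assumed list of ranks that emit a profile
        }
    }
)

start_step = cfg.nsys_profile.get("start_step", 0)  # same lookup as the added lines
end_step = cfg.nsys_profile.get("end_step", 0)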
nemo/core/classes/modelPT.py (8 changes: 6 additions & 2 deletions)
@@ -190,6 +190,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
# Setup nsys profiling if it has been enabled in the model config
self._setup_nsys_profiling()

# A flag for the profile generation
self._profile_complete = False

def __init_subclass__(cls) -> None:
cls._save_restore_connector = SaveRestoreConnector()

@@ -1720,7 +1723,7 @@ def on_train_batch_start(self, batch: Any, batch_idx: int, unused: int = 0) -> O
# nsys profiling
if self.device.type == 'cuda':
if hasattr(self, '_nsys_profile_enabled'):
if self._nsys_profile_enabled:
if self._nsys_profile_enabled and not self._profile_complete:
if batch_idx == self._nsys_profile_start_step and get_rank() in self._nsys_profile_ranks:
logging.info("====== Start nsys profiling ======")
torch.cuda.cudart().cudaProfilerStart()
@@ -1757,10 +1760,11 @@ def on_train_batch_end(self, outputs, batch: Any, batch_idx: int, unused: int =

if self.device.type == 'cuda':
if hasattr(self, '_nsys_profile_enabled'):
if self._nsys_profile_enabled:
if self._nsys_profile_enabled and not self._profile_complete:
if batch_idx == self._nsys_profile_end_step and get_rank() in self._nsys_profile_ranks:
logging.info("====== End nsys profiling ======")
torch.cuda.cudart().cudaProfilerStop()
self._profile_complete = True

def _cleanup_on_execution_end(self):
"""
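Taken together, the modelPT.py changes arm the profiler at most once per run: on_train_batch_start calls cudaProfilerStart() at the start step on the configured ranks, and on_train_batch_end calls cudaProfilerStop() at the end step and then sets _profile_complete so the window is never re-opened, even if batch_idx later matches the start step again (for example, because batch indices restart at a new epoch). Below is a minimal, self-contained sketch of that gating, not the NeMo implementation; the ProfilerWindow class and its method names are illustrative.

import torch

class ProfilerWindow:
    """One-shot CUDA profiling window keyed on batch index and rank."""

    def __init__(self, start_step, end_step, ranks):
        self.start_step = start_step
        self.end_step = end_step
        self.ranks = ranks
        self.complete = False  # plays the role of self._profile_complete

    def on_batch_start(self, batch_idx, rank):
        # Start profiling once, on the selected ranks, at the start step.
        if not self.complete and batch_idx == self.start_step and rank in self.ranks:
            torch.cuda.cudart().cudaProfilerStart()

    def on_batch_end(self, batch_idx, rank):
        # Stop profiling at the end step and never re-arm for this run.
        if not self.complete and batch_idx == self.end_step and rank in self.ranks:
            torch.cuda.cudart().cudaProfilerStop()
            self.complete = True

For the start/stop markers to bound the captured region, the job is typically launched under Nsight Systems with a cudaProfilerApi capture range (for example, nsys profile --capture-range=cudaProfilerApi ...); exact flags depend on the nsys version.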
