Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix incorrect checkpoint removal logic #9192

Merged
merged 2 commits into from
May 15, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 59 additions & 44 deletions nemo/collections/nlp/parts/nlp_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@
def init_model_parallel(
sharp: bool, nccl_communicator_config_path: str = None, distributed_timeout_minutes: int = 30
) -> None:
""" Initializes Megatron-LM model parallel if using model parallelism.
"""Initializes Megatron-LM model parallel if using model parallelism.

Args:
sharp: Apply SHARP to NCCL data-parallel communication.
Expand Down Expand Up @@ -164,7 +164,7 @@ def init_model_parallel(


class NLPDDPStrategy(DDPStrategy):
""" DDP plugin for Pytorch Lightning. Needed to customize DDP for model parallel models.
"""DDP plugin for Pytorch Lightning. Needed to customize DDP for model parallel models.

Args:
no_ddp_communication_hook: Disable DDP communication hook when using AMP-O2
Expand Down Expand Up @@ -231,8 +231,8 @@ def setup_distributed(self, global_rank: int = None, world_size: int = None) ->
)

def configure_ddp(self):
""" Override LightningModule ddp if using model parallel.
Sets find_unused_parameters to False to use activation-checkpoint-recomputation.
"""Override LightningModule ddp if using model parallel.
Sets find_unused_parameters to False to use activation-checkpoint-recomputation.
"""

if (hasattr(self.model, 'megatron_amp_O2') and self.model.megatron_amp_O2) or (
Expand Down Expand Up @@ -406,7 +406,7 @@ def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = Tr
self.lightning_module.load_state_dict(checkpoint["state_dict"], strict=strict)

def _fix_tensors_device(self, ckpt: Dict) -> Dict:
""" Ensure checkpoint tensors are on the correct device."""
"""Ensure checkpoint tensors are on the correct device."""
assert torch.cuda.is_initialized(), (torch.cuda.is_available(), torch.cuda.is_initialized())
cur_dev = torch.device("cuda", index=torch.cuda.current_device())

Expand All @@ -418,10 +418,10 @@ def _fix_device(t):
return dict_list_map_outplace(_fix_device, ckpt)

def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
""" PTL method which we override to integrate distributed checkpoints for model parallel models.
In order to load distributed checkpoints we need to provide the sharded_state_dict to
the distributed load function. We get the sharded_state_dict from self.lightning_module
which makes it convenient to have the loading logic happen at the strategy level.
"""PTL method which we override to integrate distributed checkpoints for model parallel models.
In order to load distributed checkpoints we need to provide the sharded_state_dict to
the distributed load function. We get the sharded_state_dict from self.lightning_module
which makes it convenient to have the loading logic happen at the strategy level.
"""

fs = get_filesystem(checkpoint_path)
Expand Down Expand Up @@ -452,8 +452,9 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:

def remove_checkpoint(self, filepath: Union[str, Path]) -> None:
# check if filepath is a distributed checkpoint
if self.use_distributed_checkpointing and self.is_global_zero:
self.checkpoint_io.remove_checkpoint(ckpt_to_dir(filepath))
if self.use_distributed_checkpointing:
if self.is_global_zero:
self.checkpoint_io.remove_checkpoint(ckpt_to_dir(filepath))

# legacy checkpoint logic, does not use megatron core
else:
Expand Down Expand Up @@ -500,15 +501,15 @@ def distributed_sampler_kwargs(self):

@property
def restore_checkpoint_after_setup(self) -> bool:
""" This needs to be True for distributed checkpointing because
we require the model to have configured the optimizer before
deserializing the checkpoint.
"""This needs to be True for distributed checkpointing because
we require the model to have configured the optimizer before
deserializing the checkpoint.
"""
return True


class NLPDDPStrategyNotebook(NLPDDPStrategy):
""" Version of NLPDDPStrategy to be used in a Jupyter Notebook
"""Version of NLPDDPStrategy to be used in a Jupyter Notebook
A large portion of Megatron code has DDP dependency, so it has been necessary to use NLPDDPStrategy even for
single-GPU training (e.g. in a Jupyter notebook)
A PTL 2.0 change has prevented DDPStrategy from being used in a notebook.
Expand Down Expand Up @@ -546,7 +547,7 @@ def _get_full_state_dict_context(module: torch.nn.Module, rank0_only: bool = Fal


class NLPFSDPStrategy(FSDPStrategy):
""" FSDP plugin for Pytorch Lightning with the support for tensor-parallelism.
"""FSDP plugin for Pytorch Lightning with the support for tensor-parallelism.

Args:
sharding_strategy: FSDP parameter sharding strategy.
Expand Down Expand Up @@ -639,7 +640,11 @@ def _set_mixed_precision_recipe(
reduce_dtype = utils_funcs.torch_dtype_from_precision(grad_reduce_dtype, None)
if set_buffer_dtype is not None:
buffer_dtype = utils_funcs.torch_dtype_from_precision(buffer_dtype, None)
return MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype,)
return MixedPrecision(
param_dtype=param_dtype,
reduce_dtype=reduce_dtype,
buffer_dtype=buffer_dtype,
)

def setup_environment(self) -> None:
"""
Expand Down Expand Up @@ -750,15 +755,19 @@ def _get_osd(opt_state):
with FSDP.summon_full_params(self.model, writeback=True, rank0_only=False):
# rekey the osd stored from non-FSDP model
rekeyed_osd = FSDP.rekey_optim_state_dict(
temp_osd, OptimStateKeyType.PARAM_NAME, self.model,
temp_osd,
OptimStateKeyType.PARAM_NAME,
self.model,
)
temp_osd = FSDP.shard_full_optim_state_dict(rekeyed_osd, self.model)
except Exception as e:
print(f"Failed to load optimzier state dicts. Errored with {e}")
exit(1)
# Shard optimizer state dict
sharded_osd = FSDP.optim_state_dict_to_load(
optim_state_dict=temp_osd, model=self.model, optim=optimizer,
optim_state_dict=temp_osd,
model=self.model,
optim=optimizer,
)

optimizer.load_state_dict(sharded_osd)
Expand All @@ -767,9 +776,9 @@ def _get_osd(opt_state):
def save_checkpoint(
self, checkpoint: Dict[str, Any], filepath: Union[str, Path], storage_options: Optional[Any] = None
) -> None:
""" Store checkpoints
1. In case of sharded checkpoint, all ranks store unique checkpoints.
2. In case of non-sharded checkpoint, all data-parallel rank 0 store checkpoints.
"""Store checkpoints
1. In case of sharded checkpoint, all ranks store unique checkpoints.
2. In case of non-sharded checkpoint, all data-parallel rank 0 store checkpoints.
"""
app_state = AppState()
filepath = inject_model_parallel_rank(filepath, fsdp_sharded_ckpt=self.sharded_checkpoint)
Expand All @@ -780,8 +789,7 @@ def save_checkpoint(
self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options)

def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
""" Load checkpoints
"""
"""Load checkpoints"""
# 1. Load normal or FSDP-sharded checkpoints.
fs = get_filesystem(checkpoint_path)

Expand All @@ -798,8 +806,7 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
return checkpoint

def remove_checkpoint(self, filepath: Union[str, Path]) -> None:
""" Remove checkpoints
"""
"""Remove checkpoints"""
# legacy checkpoint logic, does not use megatron core
app_state = AppState()
# PTL override to accommodate model parallel checkpoints
Expand All @@ -814,9 +821,9 @@ def remove_checkpoint(self, filepath: Union[str, Path]) -> None:

@property
def restore_checkpoint_after_setup(self) -> bool:
""" When loading FSDP-sharded checkpoint, need to restore checkpoint after configuring
FSDP sharding to match FSDP-sharded format between the checkpoint and the current
model and optimizer.
"""When loading FSDP-sharded checkpoint, need to restore checkpoint after configuring
FSDP sharding to match FSDP-sharded format between the checkpoint and the current
model and optimizer.
"""
return True

Expand Down Expand Up @@ -915,7 +922,8 @@ def dummy():
else:
# move weights to the tmpdir
for tp_rank, pp_rank in itertools.product(
range(app_state.tensor_model_parallel_size), range(app_state.pipeline_model_parallel_size),
range(app_state.tensor_model_parallel_size),
range(app_state.pipeline_model_parallel_size),
):
os.makedirs(os.path.join(tmpdir, f'tp_rank_{tp_rank:02d}_pp_rank_{pp_rank:03d}'))
mp_model_weights = os.path.join(
Expand Down Expand Up @@ -1000,6 +1008,7 @@ def modify_state_dict(self, conf, state_dict):
loaded_keys = state_dict.keys()
if 'model.model.diffusion_model.input_blocks.1.0.in_layers.2.weight' in loaded_keys:
new_state_dict = {}

# GroupNormOpt fuses the activation function into one layer, so the weight indices of the following layers are shifted
def should_process(key):
base_str = "model.model.diffusion_model."
Expand Down Expand Up @@ -1110,7 +1119,13 @@ def restore_from(
# Get path where the command is executed - the artifacts will be "retrieved" there
# (original .nemo behavior)
loaded_params = super().load_config_and_state_dict(
calling_cls, restore_path, override_config_path, map_location, strict, return_config, trainer,
calling_cls,
restore_path,
override_config_path,
map_location,
strict,
return_config,
trainer,
)
if not isinstance(loaded_params, tuple) or return_config is True:
return loaded_params
Expand Down Expand Up @@ -1165,12 +1180,12 @@ def dummy():


class PipelineMixedPrecisionPlugin(MixedPrecisionPlugin):
""" Overrides PTL autocasting to not wrap training/val/test_step.
We do this because we have the megatron-core fwd/bwd functions in training_step.
This means .backward is being called in training_step so we do not want the whole
step wrapped in autocast.
"""Overrides PTL autocasting to not wrap training/val/test_step.
We do this because we have the megatron-core fwd/bwd functions in training_step.
This means .backward is being called in training_step so we do not want the whole
step wrapped in autocast.

We instead wrap the fwd_output_and_loss_func that is passed to the megatron-core fwd/bwd functions.
We instead wrap the fwd_output_and_loss_func that is passed to the megatron-core fwd/bwd functions.
"""

def __init__(
Expand Down Expand Up @@ -1206,12 +1221,12 @@ def forward_context(self) -> Generator[None, None, None]:


class FSDPMixedPrecisionPlugin(FSDPPrecision):
""" Overrides PTL autocasting to not wrap training/val/test_step.
We do this because we have the megatron-core fwd/bwd functions in training_step.
This means .backward is being called in training_step so we do not want the whole
step wrapped in autocast.
"""Overrides PTL autocasting to not wrap training/val/test_step.
We do this because we have the megatron-core fwd/bwd functions in training_step.
This means .backward is being called in training_step so we do not want the whole
step wrapped in autocast.

We instead wrap the fwd_output_and_loss_func that is passed to the megatron-core fwd/bwd functions.
We instead wrap the fwd_output_and_loss_func that is passed to the megatron-core fwd/bwd functions.
"""

def __init__(
Expand Down Expand Up @@ -1246,7 +1261,7 @@ class GradScaler(torch.cuda.amp.GradScaler):

def __init__(
self,
init_scale=2.0 ** 16,
init_scale=2.0**16,
growth_factor=2.0,
backoff_factor=0.5,
growth_interval=2000,
Expand Down Expand Up @@ -1500,15 +1515,15 @@ def optimizer_step(

@contextmanager
def forward_context(self) -> Generator[None, None, None]:
""" No explicit precision casting. Inputs are supposed to be manually casted """
"""No explicit precision casting. Inputs are supposed to be manually casted"""
try:
yield
finally:
pass


class GlobalBatchDataFetcher(_DataFetcher):
""" Overrides PTL DataFetcher. Used to fetch global batches."""
"""Overrides PTL DataFetcher. Used to fetch global batches."""

def __init__(self, prefetch_batches: int = 0, store_on_device: bool = False) -> None:

Expand Down
Loading