Skip to content

Commit

Permalink
Call ckpt_to_weights_subdir from MegatronCheckpointIO (#10897)
Browse files Browse the repository at this point in the history
* locate weights path within MegatronCheckpointIO

Signed-off-by: ashors1 <[email protected]>

* small refactor

Signed-off-by: ashors1 <[email protected]>

* remove another instance of ckpt_to_weights_subdir

Signed-off-by: ashors1 <[email protected]>

* move ckpt_to_weights_subdir

Signed-off-by: ashors1 <[email protected]>

* Apply isort and black reformatting

Signed-off-by: ashors1 <[email protected]>

* Apply isort and black reformatting

Signed-off-by: artbataev <[email protected]>

* add weights path in save_checkpoint

Signed-off-by: ashors1 <[email protected]>

* fix circular import

Signed-off-by: ashors1 <[email protected]>

* Apply isort and black reformatting

Signed-off-by: ashors1 <[email protected]>

* handle saving in ckpt_to_weights_subdir

Signed-off-by: ashors1 <[email protected]>

* fix minor typo

Signed-off-by: ashors1 <[email protected]>

* bug fixes

Signed-off-by: ashors1 <[email protected]>

* fix undefined variable

Signed-off-by: ashors1 <[email protected]>

* move function

Signed-off-by: ashors1 <[email protected]>

* Apply isort and black reformatting

Signed-off-by: ashors1 <[email protected]>

* fix adapter meta file path

Signed-off-by: Chen Cui <[email protected]>

* Apply isort and black reformatting

Signed-off-by: cuichenx <[email protected]>

* fix mixtral test

Signed-off-by: ashors1 <[email protected]>

* fix mixtral test

Signed-off-by: ashors1 <[email protected]>

* use function for weights subdir

Signed-off-by: Chen Cui <[email protected]>

* address comments

Signed-off-by: ashors1 <[email protected]>

* move asserts

Signed-off-by: ashors1 <[email protected]>

* fix undefined vars

Signed-off-by: ashors1 <[email protected]>

* bug fix

Signed-off-by: ashors1 <[email protected]>

---------

Signed-off-by: ashors1 <[email protected]>
Signed-off-by: ashors1 <[email protected]>
Signed-off-by: artbataev <[email protected]>
Signed-off-by: Chen Cui <[email protected]>
Signed-off-by: cuichenx <[email protected]>
Co-authored-by: ashors1 <[email protected]>
Co-authored-by: artbataev <[email protected]>
Co-authored-by: Chen Cui <[email protected]>
Co-authored-by: cuichenx <[email protected]>
  • Loading branch information
5 people authored and yashaswikarnati committed Nov 21, 2024
1 parent 452f8f7 commit 84cfaab
Show file tree
Hide file tree
Showing 9 changed files with 47 additions and 36 deletions.
6 changes: 0 additions & 6 deletions nemo/lightning/ckpt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,6 @@ def idempotent_path_append(base_dir: Union[str, Path], suffix) -> Path:
return base_dir


def ckpt_to_weights_subdir(filepath: Union[str, Path]) -> Path:
"""Given an input checkpoint filepath, clean it using `ckpt_to_dir` and then return the weights subdirectory."""
base_dir = ckpt_to_dir(filepath=filepath)
return idempotent_path_append(base_dir, WEIGHTS_PATH)


def ckpt_to_context_subdir(filepath: Union[str, Path]) -> Path:
"""Given an input checkpoint filepath, clean it using `ckpt_to_dir` and then return the context subdirectory."""
base_dir = ckpt_to_dir(filepath=filepath)
Expand Down
4 changes: 2 additions & 2 deletions nemo/lightning/fabric/fabric.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from torch import nn
from typing_extensions import Self, override

from nemo.lightning.ckpt_utils import ckpt_to_context_subdir, ckpt_to_weights_subdir
from nemo.lightning.ckpt_utils import ckpt_to_context_subdir
from nemo.lightning.io.mixin import IOMixin, serialization, track_io

if TYPE_CHECKING:
Expand Down Expand Up @@ -83,7 +83,7 @@ def load_model(
model = context.model

dist_model = self.setup_module(model)
self.load(ckpt_to_weights_subdir(path), {"state_dict": dist_model})
self.load(path, {"state_dict": dist_model})

return dist_model

Expand Down
4 changes: 2 additions & 2 deletions nemo/lightning/io/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from filelock import FileLock, Timeout
from pytorch_lightning.trainer.states import TrainerFn

from nemo.lightning.ckpt_utils import ckpt_to_context_subdir, ckpt_to_weights_subdir
from nemo.lightning.ckpt_utils import ckpt_to_context_subdir

# Dynamically inherit from the correct Path subclass based on the operating system.
if os.name == 'nt':
Expand Down Expand Up @@ -198,7 +198,7 @@ def nemo_save(self, output_path: Path, trainer: pl.Trainer, dump_io: bool = True
trainer.strategy.setup(trainer)
output_path = Path(output_path)
output_path.mkdir(parents=True, exist_ok=True)
trainer.save_checkpoint(ckpt_to_weights_subdir(output_path))
trainer.save_checkpoint(output_path)
if getattr(trainer.strategy, "async_save", False):
trainer.strategy.checkpoint_io.maybe_finalize_save_checkpoint(blocking=True)

Expand Down
30 changes: 28 additions & 2 deletions nemo/lightning/io/pl.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from torch import nn
from typing_extensions import Self, override

from nemo.lightning.ckpt_utils import ckpt_to_dir
from nemo.lightning.ckpt_utils import WEIGHTS_PATH, ckpt_to_dir
from nemo.lightning.io.capture import IOProtocol
from nemo.lightning.io.mixin import IOMixin

Expand Down Expand Up @@ -78,6 +78,26 @@ def construct_extra(cls, trainer: pl.Trainer) -> Dict[str, Any]:
return extra


def ckpt_to_weights_subdir(filepath: Union[str, Path], is_saving) -> Path:
"""Given an input checkpoint filepath, clean it using `ckpt_to_dir` and then return the weights subdirectory, if it exists."""
filepath = ckpt_to_dir(filepath=filepath)
base_dir = filepath
assert isinstance(base_dir, Path)
if base_dir.parts[-1] != WEIGHTS_PATH:
maybe_base_dir = base_dir / WEIGHTS_PATH
if maybe_base_dir.is_dir() or is_saving:
base_dir = maybe_base_dir
## handle adapter paths
if hasattr(base_dir, "base_model_path") and base_dir.base_model_path.parts[-1] != WEIGHTS_PATH:
maybe_base_model_path = base_dir.base_model_path / WEIGHTS_PATH
if maybe_base_model_path.is_dir() or is_saving:
base_dir.base_model_path = base_dir.base_model_path / WEIGHTS_PATH
if is_saving:
assert base_dir.parts[-1] == WEIGHTS_PATH
assert base_dir.parent == Path(filepath)
return base_dir


class MegatronCheckpointIO(AsyncCompatibleCheckpointIO, IOMixin):
"""CheckpointIO that utilizes :func:`torch.save` and :func:`torch.load` to save and load checkpoints respectively,
common for most use cases.
Expand Down Expand Up @@ -132,7 +152,8 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_optio
f" storage_options, but {storage_options=} was provided."
f" Ignoring given storage_options"
)
checkpoint_dir = ckpt_to_dir(path)
checkpoint_dir = ckpt_to_weights_subdir(path, is_saving=True)

fs = get_filesystem(checkpoint_dir)
if fs.isdir(checkpoint_dir) and dist_checkpointing.check_is_distributed_checkpoint(checkpoint_dir):
logging.info(f'Distributed checkpoint at path {checkpoint_dir} already exists, skipping saving')
Expand Down Expand Up @@ -180,6 +201,11 @@ def load_checkpoint(
if not fs.isdir(path):
raise ValueError(f"Distributed checkpoints should be a directory. Found: {path}.")

# Load from ckpt_path/weights (new format) if it exists
path = ckpt_to_weights_subdir(path, is_saving=False)
if hasattr(path, "base_model_path") and not path.base_model_path.exists():
path.base_model_path = path.base_model_path.parent

if self.save_ckpt_format == 'zarr' and self.load_directly_on_device:
from megatron.core.dist_checkpointing.strategies.tensorstore import TensorStoreLoadShardedStrategy

Expand Down
22 changes: 10 additions & 12 deletions nemo/lightning/pytorch/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ class ModelCheckpoint(PTLModelCheckpoint):
"""

UNFINISHED_CHECKPOINT_SUFFIX = "-unfinished"
WEIGHTS_PATH = "weights"

def __init__(
self,
Expand Down Expand Up @@ -438,7 +437,6 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str)

# barrier_after=True, so all ranks continue after the unfinished checkpoint marker is placed.
# if anything goes wrong during checkpointing, we should be able to detect that data is incomplete.
ckpt_filepath = ckpt_to_dir(filepath) / ModelCheckpoint.WEIGHTS_PATH
self.set_checkpoint_unfinished_marker(filepath, barrier_after=True)
ema_callback = self._ema_callback(trainer)

Expand All @@ -455,15 +453,15 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str)
if self.async_save:
raise ValueError('async_save with EMA not supported')
with ema_callback.save_original_optimizer_state(trainer):
super()._save_checkpoint(trainer, ckpt_filepath)
super()._save_checkpoint(trainer, filepath)

# save EMA copy of the model as well.
with ema_callback.save_ema_model(trainer):
rank_zero_info(f"Saving EMA weights to separate checkpoint {ckpt_filepath}")
ckpt_filepath = self._ema_format_filepath(ckpt_filepath)
rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}")
filepath = self._ema_format_filepath(filepath)
if self.verbose:
rank_zero_info(f"Saving EMA weights to separate checkpoint {ckpt_filepath}")
super()._save_checkpoint(trainer, ckpt_filepath)
rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}")
super()._save_checkpoint(trainer, filepath)
self.remove_checkpoint_unfinished_marker(filepath, barrier_before=True)
else:
## Determine whether to include optimizer states in the checkpoint
Expand All @@ -489,7 +487,7 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str)
self.deferred_ckpts_to_remove.append([])
else:
storage_options = None
trainer.save_checkpoint(ckpt_filepath, save_weights_only, storage_options=storage_options)
trainer.save_checkpoint(filepath, save_weights_only, storage_options=storage_options)

if self.always_save_context and is_global_rank_zero():
TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(filepath) / "context", yaml_attrs=["model"])
Expand Down Expand Up @@ -598,11 +596,11 @@ def _remove_unfinished_checkpoints(checkpoint_dir: Union[Path, str]) -> None:
}

checkpoint_filepaths = {f.resolve() for f in checkpoint_dir.rglob("*.ckpt")}
for ckpt_filepath in checkpoint_filepaths:
possible_marker_path = ModelCheckpoint.format_checkpoint_unfinished_marker_path(ckpt_filepath)
for filepath in checkpoint_filepaths:
possible_marker_path = ModelCheckpoint.format_checkpoint_unfinished_marker_path(filepath)
if possible_marker_path in existing_marker_filepaths:
logging.warning(f'Removing unfinished checkpoint: {ckpt_filepath}')
os.remove(ckpt_filepath)
logging.warning(f'Removing unfinished checkpoint: {filepath}')
os.remove(filepath)

# some directories might be distributed checkpoints, we remove these if they have a unfinished marker
all_dirpaths = {d.resolve() for d in checkpoint_dir.glob("*") if d.is_dir()}
Expand Down
4 changes: 2 additions & 2 deletions nemo/lightning/pytorch/callbacks/peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

from nemo.lightning.ckpt_utils import ADAPTER_META_FILENAME
from nemo.lightning.io.mixin import IOMixin
from nemo.lightning.io.pl import ckpt_to_dir
from nemo.lightning.io.pl import ckpt_to_dir, ckpt_to_weights_subdir
from nemo.lightning.megatron_parallel import MegatronParallel
from nemo.lightning.pytorch.callbacks.model_transform import ModelTransform
from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule
Expand Down Expand Up @@ -346,7 +346,7 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_optio

if is_global_rank_zero():
metadata = {"model_ckpt_path": str(self.model_ckpt_path)}
base_dir = ckpt_to_dir(path)
base_dir = ckpt_to_weights_subdir(path, is_saving=True)
base_dir.mkdir(parents=True, exist_ok=True)
adapter_meta_path = base_dir / ADAPTER_META_FILENAME
with open(adapter_meta_path, "w") as f:
Expand Down
9 changes: 1 addition & 8 deletions nemo/lightning/pytorch/strategies/megatron_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@

from nemo.core.optim.mcore_optim import McoreDistributedOptimizer
from nemo.lightning import _strategy_lib, io
from nemo.lightning.ckpt_utils import ckpt_to_weights_subdir
from nemo.lightning.megatron_parallel import (
CallbackConnector,
MegatronParallel,
Expand Down Expand Up @@ -703,13 +702,7 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path], selective_restore:
if self.lightning_module.optimizers(use_pl_optimizer=False):
sharded_state_dict["optimizer"] = [self.optimizer_sharded_state_dict(is_loading=True)]

# Load from ckpt_path/weights (new format) if it exists, otherwise load from ckpt_path (legacy format)
load_dir = ckpt_to_weights_subdir(checkpoint_path)
if not load_dir.exists():
load_dir = checkpoint_path
if isinstance(load_dir, AdapterPath) and not load_dir.base_model_path.exists():
load_dir.base_model_path = load_dir.base_model_path.parent
checkpoint = self.checkpoint_io.load_checkpoint(load_dir, sharded_state_dict=sharded_state_dict)
checkpoint = self.checkpoint_io.load_checkpoint(checkpoint_path, sharded_state_dict=sharded_state_dict)

return checkpoint

Expand Down
2 changes: 1 addition & 1 deletion tests/collections/llm/bitexact/mixtral/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,4 @@ python3 /workspace/tests/collections/llm/bitexact/mixtral/pretrain_mini_mixtral.

# Compare outputs
python3 /workspace/tests/collections/llm/bitexact/mixtral/compare_ckpts.py \
"$NEMO_OUTPUT_PATH/checkpoints/--None=0.0000-epoch=0/" "$MCORE_OUTPUT_PATH/iter_0000010/"
"$NEMO_OUTPUT_PATH/checkpoints/--None=0.0000-epoch=0/weights" "$MCORE_OUTPUT_PATH/iter_0000010/"
2 changes: 1 addition & 1 deletion tests/collections/llm/megatron_mixtral_pretraining.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def main(args):
)

# Confirm checkpoint directory structure
output_path = Path(args.experiment_dir) / "checkpoints/--None=0.0000-epoch=0/"
output_path = Path(args.experiment_dir) / "checkpoints/--None=0.0000-epoch=0/weights"
assert output_path.exists(), f"Expected {output_path} to exist"
assert output_path.is_dir(), f"Expected {output_path} to be a directory"
output_files = ['__0_0.distcp', '__0_1.distcp', 'common.pt', 'metadata.json', '.metadata']
Expand Down

0 comments on commit 84cfaab

Please sign in to comment.