Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Call ckpt_to_weights_subdir from MegatronCheckpointIO #10897

Merged
merged 36 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from 34 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
dcb38f3
locate weights path within MegatronCheckpointIO
ashors1 Oct 15, 2024
6010377
small refactor
ashors1 Oct 15, 2024
2417023
remove another instance of ckpt_to_weights_subdir
ashors1 Oct 15, 2024
eed4bad
move ckpt_to_weights_subdir
ashors1 Oct 16, 2024
52c0ad3
Apply isort and black reformatting
ashors1 Oct 16, 2024
e5dbd61
Apply isort and black reformatting
artbataev Oct 16, 2024
45df47d
add weights path in save_checkpoint
ashors1 Oct 16, 2024
c49e2a6
fix circular import
ashors1 Oct 16, 2024
d3ffd5d
Apply isort and black reformatting
ashors1 Oct 16, 2024
ea49e20
handle saving in ckpt_to_weights_subdir
ashors1 Oct 16, 2024
c4c3fd5
fix minor typo
ashors1 Oct 16, 2024
3ae933e
bug fixes
ashors1 Oct 16, 2024
f1fbec5
fix undefined variable
ashors1 Oct 17, 2024
8161076
move function
ashors1 Oct 17, 2024
994719e
Apply isort and black reformatting
ashors1 Oct 17, 2024
ea51ab2
fix adapter meta file path
cuichenx Oct 17, 2024
871ac85
Apply isort and black reformatting
cuichenx Oct 17, 2024
f5889ca
Merge branch 'refs/heads/main' into ashors/ckpt-subdirs
cuichenx Oct 17, 2024
df2c4b1
Merge remote-tracking branch 'origin/ashors/ckpt-subdirs' into ashors…
cuichenx Oct 17, 2024
5aec05b
fix mixtral test
ashors1 Oct 18, 2024
2df54e3
fix mixtral test
ashors1 Oct 18, 2024
440a244
use function for weights subdir
cuichenx Oct 18, 2024
b2883a1
address comments
ashors1 Oct 18, 2024
26a8d8d
move asserts
ashors1 Oct 18, 2024
ac1779b
fix undefined vars
ashors1 Oct 21, 2024
f380df7
bug fix
ashors1 Oct 21, 2024
e15cafa
fix mixtral test
ashors1 Oct 22, 2024
3271f5c
Merge branch 'main' of github.com:NVIDIA/NeMo into ashors/ckpt-subdirs
ashors1 Oct 25, 2024
a14151c
fix fabric
ashors1 Oct 28, 2024
ce5da6f
fix typo
ashors1 Oct 29, 2024
2cdd672
Merge branch 'main' of github.com:NVIDIA/NeMo into ashors/ckpt-subdirs
ashors1 Oct 29, 2024
0954f6b
Merge branch 'main' of github.com:NVIDIA/NeMo into ashors/ckpt-subdirs
ashors1 Oct 30, 2024
e997e24
revert some unnecessary changes
ashors1 Oct 30, 2024
7ef6ab7
Apply isort and black reformatting
ashors1 Oct 30, 2024
d20d4a8
Merge branch 'main' into ashors/ckpt-subdirs
cuichenx Oct 31, 2024
9c3fa3c
fix for peft
cuichenx Nov 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions nemo/lightning/ckpt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,6 @@ def idempotent_path_append(base_dir: Union[str, Path], suffix) -> Path:
return base_dir


def ckpt_to_weights_subdir(filepath: Union[str, Path]) -> Path:
"""Given an input checkpoint filepath, clean it using `ckpt_to_dir` and then return the weights subdirectory."""
base_dir = ckpt_to_dir(filepath=filepath)
return idempotent_path_append(base_dir, WEIGHTS_PATH)


def ckpt_to_context_subdir(filepath: Union[str, Path]) -> Path:
"""Given an input checkpoint filepath, clean it using `ckpt_to_dir` and then return the context subdirectory."""
base_dir = ckpt_to_dir(filepath=filepath)
Expand Down
4 changes: 2 additions & 2 deletions nemo/lightning/fabric/fabric.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from torch import nn
from typing_extensions import Self, override

from nemo.lightning.ckpt_utils import ckpt_to_context_subdir, ckpt_to_weights_subdir
from nemo.lightning.ckpt_utils import ckpt_to_context_subdir
from nemo.lightning.io.mixin import IOMixin, serialization, track_io

if TYPE_CHECKING:
Expand Down Expand Up @@ -83,7 +83,7 @@ def load_model(
model = context.model

dist_model = self.setup_module(model)
self.load(ckpt_to_weights_subdir(path), {"state_dict": dist_model})
self.load(path, {"state_dict": dist_model})

return dist_model

Expand Down
4 changes: 2 additions & 2 deletions nemo/lightning/io/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from filelock import FileLock, Timeout
from pytorch_lightning.trainer.states import TrainerFn

from nemo.lightning.ckpt_utils import ckpt_to_context_subdir, ckpt_to_weights_subdir
from nemo.lightning.ckpt_utils import ckpt_to_context_subdir

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
nemo.lightning.ckpt_utils
begins an import cycle.

# Dynamically inherit from the correct Path subclass based on the operating system.
if os.name == 'nt':
Expand Down Expand Up @@ -198,7 +198,7 @@
trainer.strategy.setup(trainer)
output_path = Path(output_path)
output_path.mkdir(parents=True, exist_ok=True)
trainer.save_checkpoint(ckpt_to_weights_subdir(output_path))
trainer.save_checkpoint(output_path)
if getattr(trainer.strategy, "async_save", False):
trainer.strategy.checkpoint_io.maybe_finalize_save_checkpoint(blocking=True)

Expand Down
30 changes: 28 additions & 2 deletions nemo/lightning/io/pl.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from torch import nn
from typing_extensions import Self, override

from nemo.lightning.ckpt_utils import ckpt_to_dir
from nemo.lightning.ckpt_utils import WEIGHTS_PATH, ckpt_to_dir

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
nemo.lightning.ckpt_utils
begins an import cycle.
from nemo.lightning.io.capture import IOProtocol
from nemo.lightning.io.mixin import IOMixin

Expand Down Expand Up @@ -78,6 +78,26 @@
return extra


def ckpt_to_weights_subdir(filepath: Union[str, Path], is_saving) -> Path:
"""Given an input checkpoint filepath, clean it using `ckpt_to_dir` and then return the weights subdirectory, if it exists."""
filepath = ckpt_to_dir(filepath=filepath)
base_dir = filepath
assert isinstance(base_dir, Path)
if base_dir.parts[-1] != WEIGHTS_PATH:
maybe_base_dir = base_dir / WEIGHTS_PATH
if maybe_base_dir.is_dir() or is_saving:
base_dir = maybe_base_dir
## handle adapter paths
if hasattr(base_dir, "base_model_path") and base_dir.base_model_path.parts[-1] != WEIGHTS_PATH:
maybe_base_model_path = base_dir.base_model_path / WEIGHTS_PATH
if maybe_base_model_path.is_dir() or is_saving:
base_dir.base_model_path = base_dir.base_model_path / WEIGHTS_PATH
if is_saving:
assert base_dir.parts[-1] == WEIGHTS_PATH
assert base_dir.parent == Path(filepath)
return base_dir


class MegatronCheckpointIO(AsyncCompatibleCheckpointIO, IOMixin):
"""CheckpointIO that utilizes :func:`torch.save` and :func:`torch.load` to save and load checkpoints respectively,
common for most use cases.
Expand Down Expand Up @@ -132,7 +152,8 @@
f" storage_options, but {storage_options=} was provided."
f" Ignoring given storage_options"
)
checkpoint_dir = ckpt_to_dir(path)
checkpoint_dir = ckpt_to_weights_subdir(path, is_saving=True)
akoumpa marked this conversation as resolved.
Show resolved Hide resolved

fs = get_filesystem(checkpoint_dir)
if fs.isdir(checkpoint_dir) and dist_checkpointing.check_is_distributed_checkpoint(checkpoint_dir):
logging.info(f'Distributed checkpoint at path {checkpoint_dir} already exists, skipping saving')
Expand Down Expand Up @@ -188,6 +209,11 @@
if not fs.isdir(path):
raise ValueError(f"Distributed checkpoints should be a directory. Found: {path}.")

# Load from ckpt_path/weights (new format) if it exists
path = ckpt_to_weights_subdir(path, is_saving=False)
if hasattr(path, "base_model_path") and not path.base_model_path.exists():
akoumpa marked this conversation as resolved.
Show resolved Hide resolved
path.base_model_path = path.base_model_path.parent

if self.save_ckpt_format == 'zarr' and self.load_directly_on_device:
from megatron.core.dist_checkpointing.strategies.tensorstore import TensorStoreLoadShardedStrategy

Expand Down
22 changes: 10 additions & 12 deletions nemo/lightning/pytorch/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ class ModelCheckpoint(PTLModelCheckpoint):
"""

UNFINISHED_CHECKPOINT_SUFFIX = "-unfinished"
WEIGHTS_PATH = "weights"

def __init__(
self,
Expand Down Expand Up @@ -436,7 +435,6 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str)

# barrier_after=True, so all ranks continue after the unfinished checkpoint marker is placed.
# if anything goes wrong during checkpointing, we should be able to detect that data is incomplete.
ckpt_filepath = ckpt_to_dir(filepath) / ModelCheckpoint.WEIGHTS_PATH
self.set_checkpoint_unfinished_marker(filepath, barrier_after=True)
ema_callback = self._ema_callback(trainer)

Expand All @@ -453,15 +451,15 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str)
if self.async_save:
raise ValueError('async_save with EMA not supported')
with ema_callback.save_original_optimizer_state(trainer):
super()._save_checkpoint(trainer, ckpt_filepath)
super()._save_checkpoint(trainer, filepath)

# save EMA copy of the model as well.
with ema_callback.save_ema_model(trainer):
rank_zero_info(f"Saving EMA weights to separate checkpoint {ckpt_filepath}")
ckpt_filepath = self._ema_format_filepath(ckpt_filepath)
rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}")
filepath = self._ema_format_filepath(filepath)
if self.verbose:
rank_zero_info(f"Saving EMA weights to separate checkpoint {ckpt_filepath}")
super()._save_checkpoint(trainer, ckpt_filepath)
rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}")
super()._save_checkpoint(trainer, filepath)
self.remove_checkpoint_unfinished_marker(filepath, barrier_before=True)
else:
## Determine whether to include optimizer states in the checkpoint
Expand All @@ -487,7 +485,7 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str)
self.deferred_ckpts_to_remove.append([])
else:
storage_options = None
trainer.save_checkpoint(ckpt_filepath, save_weights_only, storage_options=storage_options)
trainer.save_checkpoint(filepath, save_weights_only, storage_options=storage_options)

if self.always_save_context and is_global_rank_zero():
TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(filepath) / "context", yaml_attrs=["model"])
Expand Down Expand Up @@ -596,11 +594,11 @@ def _remove_unfinished_checkpoints(checkpoint_dir: Union[Path, str]) -> None:
}

checkpoint_filepaths = {f.resolve() for f in checkpoint_dir.rglob("*.ckpt")}
for ckpt_filepath in checkpoint_filepaths:
possible_marker_path = ModelCheckpoint.format_checkpoint_unfinished_marker_path(ckpt_filepath)
for filepath in checkpoint_filepaths:
possible_marker_path = ModelCheckpoint.format_checkpoint_unfinished_marker_path(filepath)
if possible_marker_path in existing_marker_filepaths:
logging.warning(f'Removing unfinished checkpoint: {ckpt_filepath}')
os.remove(ckpt_filepath)
logging.warning(f'Removing unfinished checkpoint: {filepath}')
os.remove(filepath)

# some directories might be distributed checkpoints, we remove these if they have a unfinished marker
all_dirpaths = {d.resolve() for d in checkpoint_dir.glob("*") if d.is_dir()}
Expand Down
9 changes: 1 addition & 8 deletions nemo/lightning/pytorch/strategies/megatron_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@

from nemo.core.optim.mcore_optim import McoreDistributedOptimizer
from nemo.lightning import _strategy_lib, io
from nemo.lightning.ckpt_utils import ckpt_to_weights_subdir
from nemo.lightning.megatron_parallel import (
CallbackConnector,
MegatronParallel,
Expand Down Expand Up @@ -703,13 +702,7 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path], selective_restore:
if self.lightning_module.optimizers(use_pl_optimizer=False):
sharded_state_dict["optimizer"] = [self.optimizer_sharded_state_dict(is_loading=True)]

# Load from ckpt_path/weights (new format) if it exists, otherwise load from ckpt_path (legacy format)
load_dir = ckpt_to_weights_subdir(checkpoint_path)
if not load_dir.exists():
load_dir = checkpoint_path
if isinstance(load_dir, AdapterPath) and not load_dir.base_model_path.exists():
load_dir.base_model_path = load_dir.base_model_path.parent
checkpoint = self.checkpoint_io.load_checkpoint(load_dir, sharded_state_dict=sharded_state_dict)
checkpoint = self.checkpoint_io.load_checkpoint(checkpoint_path, sharded_state_dict=sharded_state_dict)

return checkpoint

Expand Down
2 changes: 1 addition & 1 deletion tests/collections/llm/bitexact/mixtral/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,4 @@ python3 /workspace/tests/collections/llm/bitexact/mixtral/pretrain_mini_mixtral.

# Compare outputs
python3 /workspace/tests/collections/llm/bitexact/mixtral/compare_ckpts.py \
"$NEMO_OUTPUT_PATH/checkpoints/--None=0.0000-epoch=0/" "$MCORE_OUTPUT_PATH/iter_0000010/"
"$NEMO_OUTPUT_PATH/checkpoints/--None=0.0000-epoch=0/weights" "$MCORE_OUTPUT_PATH/iter_0000010/"
2 changes: 1 addition & 1 deletion tests/collections/llm/megatron_mixtral_pretraining.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def main(args):
)

# Confirm checkpoint directory structure
output_path = Path(args.experiment_dir) / "checkpoints/--None=0.0000-epoch=0/"
output_path = Path(args.experiment_dir) / "checkpoints/--None=0.0000-epoch=0/weights"
assert output_path.exists(), f"Expected {output_path} to exist"
assert output_path.is_dir(), f"Expected {output_path} to be a directory"
output_files = ['__0_0.distcp', '__0_1.distcp', 'common.pt', 'metadata.json', '.metadata']
Expand Down
Loading