Call ckpt_to_weights_subdir from MegatronCheckpointIO #10897

Merged: 36 commits, merged Nov 5, 2024
Changes from 9 commits
Commits (36):
dcb38f3: locate weights path within MegatronCheckpointIO (ashors1, Oct 15, 2024)
6010377: small refactor (ashors1, Oct 15, 2024)
2417023: remove another instance of ckpt_to_weights_subdir (ashors1, Oct 15, 2024)
eed4bad: move ckpt_to_weights_subdir (ashors1, Oct 16, 2024)
52c0ad3: Apply isort and black reformatting (ashors1, Oct 16, 2024)
e5dbd61: Apply isort and black reformatting (artbataev, Oct 16, 2024)
45df47d: add weights path in save_checkpoint (ashors1, Oct 16, 2024)
c49e2a6: fix circular import (ashors1, Oct 16, 2024)
d3ffd5d: Apply isort and black reformatting (ashors1, Oct 16, 2024)
ea49e20: handle saving in ckpt_to_weights_subdir (ashors1, Oct 16, 2024)
c4c3fd5: fix minor typo (ashors1, Oct 16, 2024)
3ae933e: bug fixes (ashors1, Oct 16, 2024)
f1fbec5: fix undefined variable (ashors1, Oct 17, 2024)
8161076: move function (ashors1, Oct 17, 2024)
994719e: Apply isort and black reformatting (ashors1, Oct 17, 2024)
ea51ab2: fix adapter meta file path (cuichenx, Oct 17, 2024)
871ac85: Apply isort and black reformatting (cuichenx, Oct 17, 2024)
f5889ca: Merge branch 'refs/heads/main' into ashors/ckpt-subdirs (cuichenx, Oct 17, 2024)
df2c4b1: Merge remote-tracking branch 'origin/ashors/ckpt-subdirs' into ashors… (cuichenx, Oct 17, 2024)
5aec05b: fix mixtral test (ashors1, Oct 18, 2024)
2df54e3: fix mixtral test (ashors1, Oct 18, 2024)
440a244: use function for weights subdir (cuichenx, Oct 18, 2024)
b2883a1: address comments (ashors1, Oct 18, 2024)
26a8d8d: move asserts (ashors1, Oct 18, 2024)
ac1779b: fix undefined vars (ashors1, Oct 21, 2024)
f380df7: bug fix (ashors1, Oct 21, 2024)
e15cafa: fix mixtral test (ashors1, Oct 22, 2024)
3271f5c: Merge branch 'main' of github.com:NVIDIA/NeMo into ashors/ckpt-subdirs (ashors1, Oct 25, 2024)
a14151c: fix fabric (ashors1, Oct 28, 2024)
ce5da6f: fix typo (ashors1, Oct 29, 2024)
2cdd672: Merge branch 'main' of github.com:NVIDIA/NeMo into ashors/ckpt-subdirs (ashors1, Oct 29, 2024)
0954f6b: Merge branch 'main' of github.com:NVIDIA/NeMo into ashors/ckpt-subdirs (ashors1, Oct 30, 2024)
e997e24: revert some unnecessary changes (ashors1, Oct 30, 2024)
7ef6ab7: Apply isort and black reformatting (ashors1, Oct 30, 2024)
d20d4a8: Merge branch 'main' into ashors/ckpt-subdirs (cuichenx, Oct 31, 2024)
9c3fa3c: fix for peft (cuichenx, Nov 4, 2024)
18 changes: 8 additions & 10 deletions nemo/collections/common/parts/run_utils.py
@@ -12,25 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import json
import subprocess
import os
import shlex

from pathlib import Path
from functools import lru_cache
from omegaconf import OmegaConf, DictConfig
import subprocess
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path

import nemo_run as run
from nemo_run.core.tunnel import LocalTunnel, SSHTunnel
from nemo_run.config import NEMORUN_HOME
from nemo_run.core.execution.docker import DockerExecutor
from nemo_run.core.execution.slurm import SlurmJobDetails
from nemo_run.core.serialization.zlib_json import ZlibJSONSerializer
from nemo_run.core.tunnel import LocalTunnel, SSHTunnel
from omegaconf import DictConfig, OmegaConf

from nemo.utils import logging


@lru_cache(maxsize=2)
def get_tunnel(**ssh_tunnel):
return SSHTunnel(**ssh_tunnel)
@@ -89,6 +89,7 @@ def get_mounts_from_config(cluster_config: dict, env_vars: dict = None):

return mounts


def check_if_mounted(cluster_config, path_to_check):
"""Will check that path_to_check is referenced inside one of the mounts."""
for mount in get_mounts_from_config(cluster_config) + ['/nemo_run/code:/nemo_run/code']:
@@ -330,8 +331,6 @@ def get_mounted_filepath(cluster_config: dict, filepath: str):
return filepath




def get_env_variables(cluster_config):
"""
Will get the environment variables from the cluster config and the user environment.
@@ -570,7 +569,6 @@ def add_task(
)



def run_exp(exp, cluster_config, sequential=False):
if cluster_config['executor'] == 'local':
# locally we are always running sequentially - does that need to be changed?
2 changes: 0 additions & 2 deletions nemo/collections/llm/t5/model/t5.py
@@ -29,8 +29,6 @@

def t5_data_step(dataloader_iter) -> Dict[str, torch.Tensor]:
from megatron.core import parallel_state
from nemo.collections.nlp.modules.common.megatron.token_level_encoder_decoder import AttnMaskType
from nemo.collections.nlp.modules.common.megatron.utils import build_attention_mask_3d

from nemo.collections.nlp.modules.common.megatron.token_level_encoder_decoder import AttnMaskType
from nemo.collections.nlp.modules.common.megatron.utils import build_attention_mask_3d
7 changes: 0 additions & 7 deletions nemo/lightning/ckpt_utils.py
@@ -3,7 +3,6 @@

# NeMo2 checkpoint structure is a checkpoint directory, with a WEIGHTS_PATH and CONTEXT_PATH subdirectory structure.
# WEIGHTS_PATH stores the weights while CONTEXT_PATH stores the hyper-parameters.
WEIGHTS_PATH: str = "weights"
CONTEXT_PATH: str = "context"


@@ -18,12 +17,6 @@ def idempotent_path_append(base_dir: Union[str, Path], suffix) -> Path:
return base_dir


def ckpt_to_weights_subdir(filepath: Union[str, Path]) -> Path:
"""Given an input checkpoint filepath, clean it using `ckpt_to_dir` and then return the weights subdirectory."""
base_dir = ckpt_to_dir(filepath=filepath)
return idempotent_path_append(base_dir, WEIGHTS_PATH)


def ckpt_to_context_subdir(filepath: Union[str, Path]) -> Path:
"""Given an input checkpoint filepath, clean it using `ckpt_to_dir` and then return the context subdirectory."""
base_dir = ckpt_to_dir(filepath=filepath)
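For orientation, the layout described by the comment above (a checkpoint directory with `weights` and `context` subdirectories) can be sketched as follows; the run and step names are invented for illustration:

```python
from pathlib import Path

# Hypothetical NeMo2 checkpoint directory; only the "weights" and "context"
# subdirectory names come from the comment in ckpt_utils.py above.
ckpt_dir = Path("results/my_run/checkpoints/step=1000")

weights_dir = ckpt_dir / "weights"   # stores the weights
context_dir = ckpt_dir / "context"   # stores the hyper-parameters

print(weights_dir)  # results/my_run/checkpoints/step=1000/weights
print(context_dir)  # results/my_run/checkpoints/step=1000/context
```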
4 changes: 2 additions & 2 deletions nemo/lightning/io/connector.py
@@ -8,7 +8,7 @@
from filelock import FileLock, Timeout
from pytorch_lightning.trainer.states import TrainerFn

from nemo.lightning.ckpt_utils import ckpt_to_context_subdir, ckpt_to_weights_subdir
from nemo.lightning.ckpt_utils import ckpt_to_context_subdir

Check notice (Code scanning / CodeQL): Cyclic import. Import of module nemo.lightning.ckpt_utils begins an import cycle.

# Dynamically inherit from the correct Path subclass based on the operating system.
if os.name == 'nt':
@@ -182,7 +182,7 @@
trainer.strategy.setup(trainer)
output_path = Path(output_path)
output_path.mkdir(parents=True, exist_ok=True)
trainer.save_checkpoint(ckpt_to_weights_subdir(output_path))
trainer.save_checkpoint(output_path)

from nemo.lightning.io.pl import TrainerContext
from nemo.utils.get_rank import is_global_rank_zero
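After this change the connector passes the bare export directory to `trainer.save_checkpoint`; appending the `weights` subdirectory is handled one layer lower, in the checkpoint IO's `save_checkpoint` (see the `nemo/lightning/io/pl.py` diff below). A minimal sketch of the resulting target path, using a made-up output directory:

```python
from pathlib import Path

WEIGHTS_PATH = "weights"  # constant now defined in nemo/lightning/io/pl.py (below)

def save_target_dir(output_path: Path) -> Path:
    """Sketch only: the directory the weights end up in once the checkpoint IO
    appends the weights subdirectory during save."""
    return Path(output_path) / WEIGHTS_PATH

# The connector now calls trainer.save_checkpoint(output_path); the weights are
# still written under output_path/weights, just by the checkpoint IO instead.
print(save_target_dir(Path("exported_model")))  # exported_model/weights
```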
26 changes: 25 additions & 1 deletion nemo/lightning/io/pl.py
@@ -27,6 +27,7 @@
from nemo.lightning.io.capture import IOProtocol
from nemo.lightning.io.mixin import IOMixin


try:
from nemo.utils.callbacks.dist_ckpt_io import AsyncCompatibleCheckpointIO
except ImportError:
@@ -39,6 +40,10 @@
LightningModuleT = TypeVar("LightningModuleT", bound=pl.LightningModule)
ModuleT = TypeVar("ModuleT", bound=nn.Module)

# NeMo2 checkpoint structure is a checkpoint directory, with a WEIGHTS_PATH and CONTEXT_PATH subdirectory structure.
# WEIGHTS_PATH stores the weights while CONTEXT_PATH stores the hyper-parameters.
WEIGHTS_PATH: str = "weights"


@dataclass
class TrainerContext(IOMixin, Generic[LightningModuleT]):
@@ -95,6 +100,20 @@
self._save_sharded_strategy = None
self.validated_consistency = False

def ckpt_to_weights_subdir(self, filepath: Union[str, Path]) -> Path:
"""Given an input checkpoint filepath, clean it using `ckpt_to_dir` and then return the weights subdirectory, if it exists."""
from nemo.lightning.resume import AdapterPath

base_dir = ckpt_to_dir(filepath=filepath)
assert isinstance(base_dir, Path)
if base_dir.parts[-1] != WEIGHTS_PATH:
maybe_base_dir = base_dir / WEIGHTS_PATH
if maybe_base_dir.is_dir():
base_dir = maybe_base_dir
if isinstance(base_dir, AdapterPath) and base_dir.base_model_path.parts[-1] != WEIGHTS_PATH:
base_dir.base_model_path = base_dir.base_model_path / WEIGHTS_PATH
return base_dir

@override
def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
"""Save model/training states as a checkpoint file through state-dump and file-write.
@@ -118,7 +137,7 @@
f" storage_options, but {storage_options=} was provided."
f" Ignoring given storage_options"
)
checkpoint_dir = ckpt_to_dir(path)
checkpoint_dir = ckpt_to_dir(path) / WEIGHTS_PATH
fs = get_filesystem(checkpoint_dir)
if fs.isdir(checkpoint_dir) and dist_checkpointing.check_is_distributed_checkpoint(checkpoint_dir):
logging.info(f'Distributed checkpoint at path {checkpoint_dir} already exists, skipping saving')
@@ -174,6 +193,11 @@
if not fs.isdir(path):
raise ValueError(f"Distributed checkpoints should be a directory. Found: {path}.")

# Load from ckpt_path/weights (new format) if it exists
from nemo.lightning.resume import AdapterPath  # local import, mirrors ckpt_to_weights_subdir above

path = self.ckpt_to_weights_subdir(path)
if isinstance(path, AdapterPath) and not path.base_model_path.exists():
path.base_model_path = path.base_model_path.parent

if self.save_ckpt_format == 'zarr' and self.load_directly_on_device:
from megatron.core.dist_checkpointing.strategies.tensorstore import TensorStoreLoadShardedStrategy

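A standalone, simplified sketch of the load-time resolution that the new `ckpt_to_weights_subdir` method performs: descend into `weights` only when that subdirectory exists (new format), otherwise keep the checkpoint directory itself (legacy format). The `AdapterPath` handling is omitted and the directory names are invented:

```python
import tempfile
from pathlib import Path

WEIGHTS_PATH = "weights"

def resolve_weights_dir(ckpt_dir: Path) -> Path:
    """Append the "weights" subdirectory only if it exists; otherwise return
    the checkpoint directory unchanged (legacy layout)."""
    if ckpt_dir.name != WEIGHTS_PATH:
        candidate = ckpt_dir / WEIGHTS_PATH
        if candidate.is_dir():
            return candidate
    return ckpt_dir

with tempfile.TemporaryDirectory() as tmp:
    # New-format checkpoint: resolves to its "weights" subdirectory.
    new_style = Path(tmp) / "step=1000"
    (new_style / WEIGHTS_PATH).mkdir(parents=True)
    assert resolve_weights_dir(new_style) == new_style / WEIGHTS_PATH

    # Legacy checkpoint: no "weights" subdirectory, returned unchanged.
    legacy = Path(tmp) / "legacy_ckpt"
    legacy.mkdir()
    assert resolve_weights_dir(legacy) == legacy
```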
2 changes: 0 additions & 2 deletions nemo/lightning/pytorch/callbacks/model_checkpoint.py
@@ -58,7 +58,6 @@ class ModelCheckpoint(PTLModelCheckpoint):
"""

UNFINISHED_CHECKPOINT_SUFFIX = "-unfinished"
WEIGHTS_PATH = "weights"

def __init__(
self,
@@ -436,7 +435,6 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str)

# barrier_after=True, so all ranks continue after the unfinished checkpoint marker is placed.
# if anything goes wrong during checkpointing, we should be able to detect that data is incomplete.
ckpt_filepath = ckpt_to_dir(filepath) / ModelCheckpoint.WEIGHTS_PATH
self.set_checkpoint_unfinished_marker(filepath, barrier_after=True)
ema_callback = self._ema_callback(trainer)

9 changes: 1 addition & 8 deletions nemo/lightning/pytorch/strategies/megatron_strategy.py
@@ -57,7 +57,6 @@

from nemo.core.optim.mcore_optim import McoreDistributedOptimizer
from nemo.lightning import _strategy_lib, io
from nemo.lightning.ckpt_utils import ckpt_to_weights_subdir
from nemo.lightning.megatron_parallel import (
CallbackConnector,
MegatronParallel,
@@ -694,13 +693,7 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path], selective_restore:
if self.lightning_module.optimizers(use_pl_optimizer=False):
sharded_state_dict["optimizer"] = [self.optimizer_sharded_state_dict(is_loading=True)]

# Load from ckpt_path/weights (new format) if it exists, otherwise load from ckpt_path (legacy format)
load_dir = ckpt_to_weights_subdir(checkpoint_path)
if not load_dir.exists():
load_dir = checkpoint_path
if isinstance(load_dir, AdapterPath) and not load_dir.base_model_path.exists():
load_dir.base_model_path = load_dir.base_model_path.parent
checkpoint = self.checkpoint_io.load_checkpoint(load_dir, sharded_state_dict=sharded_state_dict)
checkpoint = self.checkpoint_io.load_checkpoint(checkpoint_path, sharded_state_dict=sharded_state_dict)

return checkpoint

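With the fallback removed from `MegatronStrategy.load_checkpoint`, the strategy simply forwards the user-supplied path and the checkpoint IO decides whether to descend into `weights`. A toy illustration of that division of responsibility; the class and method names below are stand-ins, not the real NeMo API:

```python
from pathlib import Path

WEIGHTS_PATH = "weights"

class ToyCheckpointIO:
    """Stand-in for MegatronCheckpointIO: owns the weights-subdir resolution."""

    def load_checkpoint(self, path: Path) -> dict:
        weights = path / WEIGHTS_PATH
        load_dir = weights if weights.is_dir() else path  # new vs. legacy layout
        return {"loaded_from": str(load_dir)}

class ToyStrategy:
    """Stand-in for MegatronStrategy: no longer inspects the path itself."""

    def __init__(self, checkpoint_io: ToyCheckpointIO):
        self.checkpoint_io = checkpoint_io

    def load_checkpoint(self, checkpoint_path: Path) -> dict:
        return self.checkpoint_io.load_checkpoint(Path(checkpoint_path))

print(ToyStrategy(ToyCheckpointIO()).load_checkpoint(Path("/tmp/some_ckpt")))
```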