From bfd07b9dc71ba6a463f8d92e302154193ee41a52 Mon Sep 17 00:00:00 2001
From: Marc Romeyn
Date: Tue, 18 Jun 2024 00:59:47 +0200
Subject: [PATCH] [NeMo-UX] Integrate tokenizer import into model.import_ckpt (#9485)

* Integrate tokenizer import into model.import_ckpt

* Apply isort and black reformatting

Signed-off-by: marcromeyn

* Apply isort and black reformatting

Signed-off-by: marcromeyn

* Fixing bug in ModelConnector.nemo_save

* Apply isort and black reformatting

Signed-off-by: marcromeyn

* Default to ddp=pytorch inside ModelConnector

* Apply isort and black reformatting

Signed-off-by: marcromeyn

---------

Signed-off-by: marcromeyn
Co-authored-by: marcromeyn
---
 nemo/collections/llm/gpt/model/mistral_7b.py |   8 +-
 nemo/lightning/experiment.py                 | 122 -------------------
 nemo/lightning/io/connector.py               |  16 ++-
 nemo/lightning/io/mixin.py                   |   2 +
 nemo/lightning/pytorch/strategies.py         |  22 ++--
 5 files changed, 31 insertions(+), 139 deletions(-)
 delete mode 100644 nemo/lightning/experiment.py

diff --git a/nemo/collections/llm/gpt/model/mistral_7b.py b/nemo/collections/llm/gpt/model/mistral_7b.py
index 6d895925352a..56dd0090346b 100644
--- a/nemo/collections/llm/gpt/model/mistral_7b.py
+++ b/nemo/collections/llm/gpt/model/mistral_7b.py
@@ -2,6 +2,7 @@
 from pathlib import Path
 from typing import TYPE_CHECKING, Callable, List, Optional
 
+import pytorch_lightning as pl
 import torch
 import torch.nn.functional as F
 from typing_extensions import Annotated
@@ -46,9 +47,7 @@ def __init__(
         optim: Optional[OptimizerModule] = None,
         tokenizer: Optional["TokenizerSpec"] = None,
     ):
-        _tokenizer = tokenizer or HFMistral7BImporter("mistralai/Mistral-7B-v0.1").tokenizer
-
-        super().__init__(config or Mistral7BConfig(), optim=optim, tokenizer=_tokenizer)
+        super().__init__(config or Mistral7BConfig(), optim=optim, tokenizer=tokenizer)
 
 
 @io.model_importer(Mistral7BModel, "hf")
@@ -72,6 +71,9 @@ def apply(self, output_path: Path) -> Path:
 
         return output_path
 
+    def on_import_ckpt(self, model: pl.LightningModule):
+        model.tokenizer = self.tokenizer
+
     def convert_state(self, source, target):
         mapping = {
             "model.embed_tokens.weight": "embedding.word_embeddings.weight",
diff --git a/nemo/lightning/experiment.py b/nemo/lightning/experiment.py
deleted file mode 100644
index 473fb29380dd..000000000000
--- a/nemo/lightning/experiment.py
+++ /dev/null
@@ -1,122 +0,0 @@
-import os
-import sys
-import time
-from dataclasses import dataclass
-from pathlib import Path
-from typing import List, Optional, Union
-
-import lightning_fabric as fl
-import pytorch_lightning as pl
-from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint as PTLModelCheckpoint
-
-from nemo.constants import NEMO_ENV_VARNAME_TESTING, NEMO_ENV_VARNAME_VERSION
-from nemo.lightning.pytorch.callbacks import ModelCheckpoint
-from nemo.utils import logging
-from nemo.utils.app_state import AppState
-from nemo.utils.env_var_parsing import get_envbool
-from nemo.utils.exp_manager import check_explicit_log_dir
-from nemo.utils.get_rank import is_global_rank_zero
-from nemo.utils.mcore_logger import add_handlers_to_mcore_logger
-
-
-@dataclass
-class Experiment:
-    name: str
-    dir: Optional[str] = None
-    explicit_log_dir: Optional[str] = None
-    version: Optional[str] = None
-    use_datetime_version: bool = True
-    log_local_rank_0_only: bool = False
-    log_global_rank_0_only: bool = False
-    files_to_copy: Optional[List[str]] = None
-    update_logger_directory: bool = True
-
-    def __post_init__(self):
-        if self.log_local_rank_0_only is True and self.log_global_rank_0_only is True:
-            raise ValueError(
-                f"Cannot set both log_local_rank_0_only and log_global_rank_0_only to True. Please set either one or neither."
-            )
-
-    def setup(self, trainer: Union[pl.Trainer, fl.Fabric], resume_if_exists: bool = False):
-        local_rank = int(os.environ.get("LOCAL_RANK", 0))
-        global_rank = trainer.node_rank * trainer.world_size + local_rank
-        logging.rank = global_rank
-
-        if self.explicit_log_dir and isinstance(trainer, pl.Trainer):  # If explicit log_dir was passed, short circuit
-            return check_explicit_log_dir(trainer, self.explicit_log_dir, self.dir, self.name, self.version)
-
-        # Default dir to ./nemo_experiments if None was passed
-        _dir = self.dir
-        if self.dir is None:
-            _dir = str(Path.cwd() / 'nemo_experiments')
-
-        if not self.name:
-            self.name = "default"
-
-        if isinstance(trainer, pl.Trainer) and trainer.logger is not None:
-            if self.update_logger_directory:
-                logging.warning(
-                    f'"update_logger_directory" is True. Overwriting logger "save_dir" to {_dir} and "name" to {self.name}'
-                )
-                trainer.logger._root_dir = _dir
-                trainer.logger._name = self.name
-
-        version = self.version or os.environ.get(NEMO_ENV_VARNAME_VERSION, None)
-        if is_global_rank_zero():
-            if self.use_datetime_version:
-                version = time.strftime('%Y-%m-%d_%H-%M-%S')
-        if resume_if_exists:
-            logging.warning(
-                "No version folders would be created under the log folder as 'resume_if_exists' is enabled."
-            )
-            version = None
-        if version:
-            if is_global_rank_zero():
-                os.environ[NEMO_ENV_VARNAME_VERSION] = version
-
-        log_dir = Path(_dir) / Path(str(self.name)) / Path("" if version is None else str(version))
-        # update app_state with log_dir, exp_dir, etc
-        app_state = AppState()
-        app_state.log_dir = log_dir
-        app_state.exp_dir = _dir
-        app_state.name = self.name
-        app_state.version = version
-
-        os.makedirs(log_dir, exist_ok=True)  # Cannot limit creation to global zero as all ranks write to own log file
-        logging.info(f'Experiments will be logged at {log_dir}')
-
-        if isinstance(trainer, pl.Trainer):
-            for callback in trainer.callbacks:
-                if isinstance(callback, PTLModelCheckpoint):
-                    ## TODO: make configurable
-                    callback.dirpath = Path(log_dir / "checkpoints")  # app_state.exp_dir
-                    if callback.filename is None:
-                        callback.filename = f'{name}--{{{callback.monitor}:.4f}}-{{epoch}}'
-                    if callback.prefix is None:
-                        callback.prefix = name
-                    ModelCheckpoint.CHECKPOINT_NAME_LAST = callback.filename + '-last'
-
-        # This is set if the env var NEMO_TESTING is set to True.
-        nemo_testing = get_envbool(NEMO_ENV_VARNAME_TESTING, False)
-
-        # Handle logging to file
-        log_file = log_dir / f'nemo_log_globalrank-{global_rank}_localrank-{local_rank}.txt'
-        if self.log_local_rank_0_only is True and not nemo_testing:
-            if local_rank == 0:
-                logging.add_file_handler(log_file)
-        elif self.log_global_rank_0_only is True and not nemo_testing:
-            if global_rank == 0:
-                logging.add_file_handler(log_file)
-        else:
-            # Logs on all ranks.
-            logging.add_file_handler(log_file)
-
-        add_handlers_to_mcore_logger()
-
-        app_state.files_to_copy = self.files_to_copy
-        app_state.cmd_args = sys.argv
-
-        return app_state
-
-    def teardown(self):
-        pass
diff --git a/nemo/lightning/io/connector.py b/nemo/lightning/io/connector.py
index e90e507fe0a7..a6ab4afd6d1b 100644
--- a/nemo/lightning/io/connector.py
+++ b/nemo/lightning/io/connector.py
@@ -1,3 +1,4 @@
+import inspect
 import logging
 import os
 import shutil
@@ -138,7 +139,7 @@ def nemo_setup(self, model: pl.LightningModule, trainer: Optional[pl.Trainer] =
         from nemo.lightning import MegatronStrategy, Trainer
 
         _trainer = trainer or Trainer(
-            devices=1, accelerator="cpu", strategy=MegatronStrategy(store_optimizer_states=False)
+            devices=1, accelerator="cpu", strategy=MegatronStrategy(store_optimizer_states=False, ddp="pytorch")
         )
 
         _trainer.strategy.connect(model)
@@ -159,7 +160,12 @@ def nemo_save(self, output_path: Path, trainer: pl.Trainer) -> None:
             output_path (Path): The path where the model checkpoint will be saved.
             trainer (pl.Trainer): The trainer with the strategy to save the model.
         """
-        trainer.strategy.setup(trainer)
+        _setup_kwargs = {}
+        setup_signature = inspect.signature(trainer.strategy.setup)
+        if 'setup_optimizers' in setup_signature.parameters:
+            _setup_kwargs["setup_optimizers"] = False
+
+        trainer.strategy.setup(trainer, **_setup_kwargs)
         trainer.save_checkpoint(output_path)
 
     def nemo_load(
@@ -181,7 +187,9 @@ def nemo_load(
         from nemo.lightning.io.api import load_ckpt
 
         model = load_ckpt(path).model
-        _trainer = trainer or Trainer(devices=1, accelerator="cpu" if cpu else "gpu", strategy=MegatronStrategy())
+        _trainer = trainer or Trainer(
+            devices=1, accelerator="cpu" if cpu else "gpu", strategy=MegatronStrategy(ddp="pytorch")
+        )
 
         _trainer.strategy.connect(model)
         _trainer.strategy.setup_environment()
@@ -208,3 +216,5 @@ def local_path(self, base_path: Optional[Path] = None) -> Path:
             _base = Path(NEMO_MODELS_CACHE)
 
         return _base / str(self).replace("://", "/")
+
+    def on_import_ckpt(self, model: pl.LightningModule): ...
diff --git a/nemo/lightning/io/mixin.py b/nemo/lightning/io/mixin.py
index b5ee76a2fe03..62b9a165c542 100644
--- a/nemo/lightning/io/mixin.py
+++ b/nemo/lightning/io/mixin.py
@@ -280,6 +280,8 @@ def import_ckpt(self, path: str, overwrite: bool = False, base_path: Optional[Pa
         ckpt_path: Path = connector.local_path(base_path=base_path)
         ckpt_path = connector(ckpt_path, overwrite=overwrite)
 
+        connector.on_import_ckpt(self)
+
         return ckpt_path
 
     @classmethod
diff --git a/nemo/lightning/pytorch/strategies.py b/nemo/lightning/pytorch/strategies.py
index b9b24ec01c9d..833a1be3905a 100644
--- a/nemo/lightning/pytorch/strategies.py
+++ b/nemo/lightning/pytorch/strategies.py
@@ -126,7 +126,7 @@ def connect(self, model: pl.LightningModule) -> None:
             self._mcore_config = config
 
     @override
-    def setup(self, trainer: pl.Trainer) -> None:
+    def setup(self, trainer: pl.Trainer, setup_optimizers: bool = True) -> None:
         assert self.accelerator is not None
         self.accelerator.setup(trainer)
         self.trainer = trainer
@@ -150,7 +150,7 @@ def setup(self, trainer: pl.Trainer) -> None:
             self.data_sampler.connect(trainer)
 
         self._fix_progress_bar(trainer)
-        self.setup_megatron_parallel(trainer)
+        self.setup_megatron_parallel(trainer, setup_optimizers=setup_optimizers)
         self.setup_precision_plugin()
 
         if trainer.num_sanity_val_steps > 1 and self.pipeline_model_parallel_size > 1:
@@ -205,7 +205,7 @@ def process_dataloader(self, dataloader: DataLoader) -> DataLoader:
 
         return dataloader
 
-    def setup_megatron_parallel(self, trainer: pl.Trainer) -> None:
+    def setup_megatron_parallel(self, trainer: pl.Trainer, setup_optimizers: bool = True) -> None:
         assert self.model is not None, "Model is not set"
 
         self.megatron_parallel = MegatronParallel(
@@ -224,16 +224,16 @@ def setup_megatron_parallel(self, trainer: pl.Trainer) -> None:
             self.model.configure_optimizers, megatron_parallel=self.megatron_parallel
         )
 
-        self.setup_optimizers(trainer)
+        if setup_optimizers:
+            self.setup_optimizers(trainer)
 
-        # TODO: Throw an execption if we have a mcore optimizer and no ddp_config
+            # TODO: Throw an execption if we have a mcore optimizer and no ddp_config
+            if hasattr(self.precision_plugin, "convert_optimizer"):
+                _optimizers = [*self.optimizers]
+                _optimizers[0] = self.precision_plugin.convert_optimizer(self.optimizers[0])
+                self.optimizers = _optimizers
 
-        if hasattr(self.precision_plugin, "convert_optimizer"):
-            _optimizers = [*self.optimizers]
-            _optimizers[0] = self.precision_plugin.convert_optimizer(self.optimizers[0])
-            self.optimizers = _optimizers
-
-        _optimizers_to_device(self.optimizers, self.root_device)
+            _optimizers_to_device(self.optimizers, self.root_device)
 
         self.model = self.megatron_parallel
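
Note: the sketch below is illustrative only and is not part of the patch. The names SimpleConnector, SimpleImporter, import_ckpt_with_hook and call_setup_compatibly are hypothetical stand-ins, not NeMo APIs; only the two patterns mirror the diff above: a connector-level on_import_ckpt hook that an importer overrides to attach its tokenizer after import_ckpt finishes, and the inspect.signature guard from ModelConnector.nemo_save that passes setup_optimizers=False only to strategies whose setup() accepts it.

# Illustrative sketch only: simplified stand-ins, not NeMo classes.
import inspect


class SimpleConnector:
    # Mirrors ModelConnector: on_import_ckpt is a no-op unless an importer overrides it.
    def on_import_ckpt(self, model): ...

    def __call__(self, ckpt_path, overwrite=False):
        # Stand-in for the real convert-and-save step performed by the connector.
        return ckpt_path


class SimpleImporter(SimpleConnector):
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def on_import_ckpt(self, model):
        # Same idea as HFMistral7BImporter.on_import_ckpt: hand the importer's
        # tokenizer to the model once the checkpoint has been converted.
        model.tokenizer = self.tokenizer


def import_ckpt_with_hook(model, connector, ckpt_path, overwrite=False):
    # Mirrors ConnectorMixin.import_ckpt: run the connector, then fire the hook.
    ckpt_path = connector(ckpt_path, overwrite=overwrite)
    connector.on_import_ckpt(model)
    return ckpt_path


def call_setup_compatibly(strategy, trainer):
    # Mirrors the guard added to ModelConnector.nemo_save: pass setup_optimizers=False
    # only when the strategy's setup() signature accepts it, so older strategies that
    # lack the parameter keep working unchanged.
    kwargs = {}
    if "setup_optimizers" in inspect.signature(strategy.setup).parameters:
        kwargs["setup_optimizers"] = False
    strategy.setup(trainer, **kwargs)

The net effect of the hook is that Mistral7BModel no longer constructs a tokenizer eagerly in __init__; the tokenizer is attached only when a checkpoint is actually imported.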